I have a tensor of shape (16, 4096, 3). I have another tensor of indices of shape (16, 32768, 3). I am trying to collect the values along dim
For last-axis gathering, we can use the 2D-reshape trick to handle general nD cases and then apply @LiShaoyuan's 2D code above:
import tensorflow as tf

# last-axis gathering only - use the 2D-reshape trick for Torch-style nD gathering
# Note: as written, this expects exactly one index per row, i.e.
# id_tensor.shape == param.shape[:-1] + (1,), with int32 indices.
def torch_gather(param, id_tensor):

    # 2D-gather Torch equivalent from @LiShaoyuan above
    def gather2d(target, id_tensor):
        # pair each row number with the column index to pick from that row
        idx = tf.stack([tf.range(tf.shape(id_tensor)[0]), id_tensor[:, 0]], axis=-1)
        result = tf.gather_nd(target, idx)
        return tf.expand_dims(result, axis=-1)

    target = tf.reshape(param, (-1, param.shape[-1]))  # reshape to 2D
    target_shape = id_tensor.shape
    id_tensor = tf.reshape(id_tensor, (-1, 1))  # also a 2D index
    result = gather2d(target, id_tensor)
    return tf.reshape(result, target_shape)
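
As a sanity check on the index arithmetic, the same reshape trick can be sketched in NumPy and compared against `np.take_along_axis`, which performs Torch-style gathering along an axis. This is only an illustrative sketch, not the TensorFlow code from the answer: the name `gather_last_axis_2d` is made up here. Unlike the function above, it repeats each row number K times, so index tensors with more than one entry per row are also handled. It assumes the leading dimensions of `index` match those of `param`.

```python
import numpy as np

def gather_last_axis_2d(param, index):
    """Torch-style gather along the last axis via a 2D reshape (hypothetical helper)."""
    k = index.shape[-1]                          # indices per row
    target = param.reshape(-1, param.shape[-1])  # (rows, C)
    flat_idx = index.reshape(-1)                 # (rows * k,)
    # Row r of `target` is paired with the k indices that came from row r.
    rows = np.repeat(np.arange(target.shape[0]), k)
    return target[rows, flat_idx].reshape(index.shape)

rng = np.random.default_rng(0)
param = rng.normal(size=(2, 5, 4))
index = rng.integers(0, 4, size=(2, 5, 3))

out = gather_last_axis_2d(param, index)
reference = np.take_along_axis(param, index, axis=-1)
print(out.shape)                       # (2, 5, 3)
print(np.array_equal(out, reference))  # True
```

In TensorFlow, the same row/column pairing could presumably be built with `tf.repeat` and `tf.gather_nd`; treat that as a suggestion to verify rather than tested code.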