I have a tensor of shape (16, 4096, 3). I have another tensor of indices of shape (16, 32768, 3). I am trying to collect the values along dim
This "should" be a general solution using tf.gather_nd (I've only tested for rank 2 and 3 tensors along the last axis):
import tensorflow as tf

def torch_gather(x, indices, gather_axis):
    # if pytorch gather indices are
    # [[[0, 10, 20], [0, 10, 20], [0, 10, 20]],
    #  [[0, 10, 20], [0, 10, 20], [0, 10, 20]]]
    # tf.gather_nd needs to be
    # [[0,0,0], [0,0,10], [0,0,20], [0,1,0], [0,1,10], [0,1,20], [0,2,0], [0,2,10], [0,2,20],
    #  [1,0,0], [1,0,10], [1,0,20], [1,1,0], [1,1,10], [1,1,20], [1,2,0], [1,2,10], [1,2,20]]

    # allow negative axes such as -1
    if gather_axis < 0:
        gather_axis += len(indices.shape)

    # create a tensor containing indices of each element (tf.where returns int64)
    all_indices = tf.where(tf.fill(indices.shape, True))
    # flatten the pytorch style indices and cast to int64 so tf.stack accepts them
    gather_locations = tf.cast(tf.reshape(indices, [-1]), all_indices.dtype)

    # splice in our pytorch style index at the correct axis
    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(gather_locations)
        else:
            gather_indices.append(all_indices[:, axis])

    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped