If A is a TensorFlow variable like so:
A = tf.Variable([[1, 2], [3, 4]])
and index is another variable holding one column index per row, the goal is to pick A[i, index[i]] from every row i.
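To make the target concrete, here is a plain NumPy reference of what the TensorFlow code further down should return. It assumes index holds one column index per row; the values [1, 0] are hypothetical and only for illustration.

```python
import numpy as np
import tensorflow as tf

A = tf.Variable([[1, 2], [3, 4]])
index = tf.Variable([1, 0])  # hypothetical: column 1 from row 0, column 0 from row 1

# Reference semantics: result[i] = A[i, index[i]] for every row i
a, idx = A.numpy(), index.numpy()
expected = np.array([a[i, idx[i]] for i in range(a.shape[0])])
print(expected)  # [2 3]
```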
I couldn't get the accepted answer to work in TensorFlow 2 when I incorporated it into a loss function; something about tf.GradientTape didn't like it. My solution is a modified version of the accepted answer. In it, arr plays the role of A and ind the role of index:
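import tensorflow as tf

def get_rows(arr):
    # Runs eagerly inside tf.py_function, so arr.shape holds concrete ints
    N, _ = arr.shape
    return N

def gather_per_row(arr, ind):
    # ind must be int32 so it can be stacked with the int32 range below
    num_rows = tf.py_function(get_rows, [arr], [tf.int32])[0]
    rng = tf.range(0, num_rows)
    idx = tf.stack([rng, ind], axis=1)  # [[0, ind[0]], [1, ind[1]], ...]
    return tf.gather_nd(arr, idx)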
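For comparison, a sketch that swaps the tf.py_function call for tf.shape(arr)[0], which returns the row count as a tensor without a Python callback. This is a different technique from the code above, and it has not been tried against the original loss function. The helper name select_per_row, the index values, and the loss (a plain sum of the picked entries) are hypothetical, chosen only to show that gradients reach A through tf.GradientTape.

```python
import tensorflow as tf

def select_per_row(arr, ind):
    # Hypothetical helper: returns arr[i, ind[i]] for every row i.
    # tf.shape gives the dynamic row count as a tensor, so this also works
    # inside tf.function where arr.shape[0] may be None.
    num_rows = tf.shape(arr)[0]
    rows = tf.range(num_rows, dtype=ind.dtype)  # match ind's integer dtype
    idx = tf.stack([rows, ind], axis=1)         # shape (num_rows, 2)
    return tf.gather_nd(arr, idx)

A = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
index = tf.constant([1, 0])  # hypothetical column index per row

with tf.GradientTape() as tape:
    picked = select_per_row(A, index)
    loss = tf.reduce_sum(picked)  # hypothetical loss

grad = tf.convert_to_tensor(tape.gradient(loss, A))
print(picked.numpy())  # [2. 3.]
print(grad.numpy())
```

The gradient is 1 at each picked position and 0 elsewhere, which is what a sum over the gathered entries should produce.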