I have a tensor logits with the dimensions [batch_size, num_rows, num_coordinates] (i.e. each logit in the batch is a matrix). In my case batch siz
mrry's answer is great, but I think the problem can be solved in far fewer lines of code with `tf.gather_nd` (this function was probably not yet available when mrry wrote that answer):
```python
import tensorflow as tf  # TF 1.x API (tf.Session)

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])
# Each innermost pair is a (batch, row) coordinate into logits.
indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])
result = tf.gather_nd(logits, indices)

with tf.Session() as sess:  # in TF 2.x (eager), use print(result.numpy()) instead
    print(sess.run(result))
```
This will print:

```
[[[ 10. 10. 20. 20.]
  [ 11. 10. 10. 30.]]

 [[ 15. 11. 11. 21.]
  [ 17. 11. 11. 21.]]]
```
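To see what `tf.gather_nd` did here without running TensorFlow, here is a minimal NumPy sketch of its semantics, covering only the default `batch_dims=0` case. The helper name `gather_nd_reference` is hypothetical, not a library function. The last axis of `indices` holds a (batch, row) coordinate. Because that coordinate covers only the first two of the three axes of `logits`, each lookup returns a whole row of 4 values, so the output shape is `indices.shape[:-1] + logits.shape[2:]`, i.e. `(2, 2, 4)`.

```python
import numpy as np


def gather_nd_reference(params, indices):
    """Hypothetical NumPy sketch of tf.gather_nd semantics (batch_dims=0 only).

    The last axis of `indices` holds a (possibly partial) coordinate into
    `params`; axes of `params` it does not cover come back as whole slices.
    """
    params = np.asarray(params)
    indices = np.asarray(indices)
    # Move the coordinate axis to the front so each coordinate component
    # becomes one advanced-index array, e.g. (batch_idx, row_idx).
    return params[tuple(np.moveaxis(indices, -1, 0))]


logits = np.array([[[10.0, 10.0, 20.0, 20.0],
                    [11.0, 10.0, 10.0, 30.0],
                    [12.0, 10.0, 10.0, 20.0],
                    [13.0, 10.0, 10.0, 20.0]],
                   [[14.0, 11.0, 21.0, 31.0],
                    [15.0, 11.0, 11.0, 21.0],
                    [16.0, 11.0, 11.0, 21.0],
                    [17.0, 11.0, 11.0, 21.0]]])
indices = np.array([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])

result = gather_nd_reference(logits, indices)
print(result.shape)  # (2, 2, 4)
```

For example, `result[1, 1]` is `logits[1, 3]`, the row `[17. 11. 11. 21.]` from the printed output.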
`tf.gather_nd` should be available as of v0.10. See this GitHub issue for further discussion.
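The `indices` in the answer spell out the batch coordinate by hand. Suppose the rows to pick instead arrive as a `[batch_size, k]` array of row numbers (a hypothetical `rows` layout, assumed here for illustration). In that case the batch index has to be paired in first. The following is a NumPy sketch of that step, using a hypothetical helper `make_gather_nd_indices`:

```python
import numpy as np


def make_gather_nd_indices(rows):
    """Hypothetical helper: pair each per-example row index with its batch index.

    rows: int array of shape [batch_size, k].
    Returns an int array of shape [batch_size, k, 2] usable as gather_nd indices.
    """
    rows = np.asarray(rows)
    batch_size, k = rows.shape
    # batch_idx[b, j] == b for every j
    batch_idx = np.broadcast_to(np.arange(batch_size)[:, None], (batch_size, k))
    # Stack along a new last axis so each entry becomes a (batch, row) pair.
    return np.stack([batch_idx, rows], axis=-1)


rows = np.array([[0, 1], [1, 3]])
indices = make_gather_nd_indices(rows)
print(indices.tolist())  # [[[0, 0], [0, 1]], [[1, 1], [1, 3]]]
```

In TensorFlow the same pairing can be built with `tf.range`, `tf.tile` and `tf.stack`. In TF 2.x, `tf.gather(logits, rows, batch_dims=1)` should also select the same rows directly, without constructing the pairs.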