Question

I'm trying to reorder the rows of each matrix in a 3D tensor, using a separate order of indices for each matrix. Here are the inputs:

    import tensorflow as tf

    matrix = tf.constant([
        [[0, 1], [2, 3], [4, 5], [6, 7]],
        [[8, 9], [10, 11], [12, 13], [14, 15]],
        [[16, 17], [18, 19], [20, 21], [22, 23]],
        [[24, 25], [26, 27], [28, 29], [30, 31]],
        [[32, 33], [34, 35], [36, 37], [38, 39]]
    ])

    indx = tf.constant([[3, 2, 1, 0], [0, 1, 2, 3], [1, 0, 3, 2], [0, 3, 1, 2], [1, 2, 3, 0]])

    # required output tensor: [[[6, 7], [4, 5], [2, 3], [0, 1]], [[8, 9], [10, 11