This question is about accessing individual elements in a tensor, say [[1,2,3]]. I need to access the inner element [1,2,3] (this can be performed using .eval() or
There are two main ways to access subsets of the elements in a tensor, either of which should work for your example.
Use the indexing operator (based on tf.slice()) to extract a contiguous slice from the tensor.
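import tensorflow as tf
tf.compat.v1.disable_eager_execution()  # graph mode, so sess.run() also works under TF 2.x
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = input[0, :]  # row 0, all columns
with tf.compat.v1.Session() as sess:
    print(sess.run(output))  # ==> [1 2 3]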
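To see how the indexing operator relates to tf.slice(), the same row can be computed both ways. This is a small sketch assuming TensorFlow 2.x with eager execution, so values are read with .numpy() instead of sess.run(); the name x stands in for input. tf.slice() takes a begin offset and a size for every dimension (a size of -1 means "all remaining") and keeps the sliced dimension, so a tf.squeeze() is needed to get the rank-1 result that x[0, :] returns.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Indexing: the integer index 0 drops that dimension, giving shape (3,).
row_by_index = x[0, :]

# tf.slice(input, begin, size): start at (0, 0), take 1 row and all
# remaining columns. The result keeps the row dimension: shape (1, 3).
row_block = tf.slice(x, begin=[0, 0], size=[1, -1])

# Remove the size-1 row dimension to match the indexing result.
row_by_slice = tf.squeeze(row_block, axis=0)

print(row_block.numpy().tolist())     # [[1, 2, 3]]
print(row_by_index.numpy().tolist())  # [1, 2, 3]
print(row_by_slice.numpy().tolist())  # [1, 2, 3]
```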
The indexing operator supports many of the same slice specifications as NumPy does.
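A few of those NumPy-style specifications are shown below, as a sketch assuming TensorFlow 2.x with eager execution (values read with .numpy()). Strides and negative steps assume a release whose indexing is backed by tf.strided_slice(), as current releases are; x stands in for input, and q is a hypothetical tensor matching the [[1,2,3]] case from the question.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

examples = {
    "x[0]": x[0],                # first row, same as x[0, :]
    "x[-1]": x[-1],              # negative index counts from the end: last row
    "x[:, 1]": x[:, 1],          # second column
    "x[1:, :2]": x[1:, :2],      # rows 1 onward, columns 0 and 1
    "x[::2, ::2]": x[::2, ::2],  # every other row and column (stride 2)
    "x[::-1, 0]": x[::-1, 0],    # first column with rows reversed
    "x[..., 2]": x[..., 2],      # Ellipsis fills in full slices: here x[:, 2]
}
for expr, t in examples.items():
    print(expr, "->", t.numpy().tolist())

# tf.newaxis (an alias for None) inserts a new dimension of size 1.
print(x[tf.newaxis, 0].shape)  # (1, 3)

# The case from the question: extract the inner row of [[1, 2, 3]].
q = tf.constant([[1, 2, 3]])
print(q[0].numpy().tolist())  # [1, 2, 3]
```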
Use the tf.gather() op to select a non-contiguous slice from the tensor.
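import tensorflow as tf
tf.compat.v1.disable_eager_execution()  # graph mode, so sess.run() also works under TF 2.x
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
output = tf.gather(input, 0)  # a single index selects one row
with tf.compat.v1.Session() as sess:
    print(sess.run(output))  # ==> [1 2 3]
    output = tf.gather(input, [0, 2])  # a list of indices selects several rows
    print(sess.run(output))  # ==> [[1 2 3] [7 8 9]]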
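The indices passed to tf.gather() need not be sorted or distinct, and the shape of the result follows the shape of the index tensor: when gathering along dimension 0, the output shape is indices.shape + input.shape[1:]. A short sketch, assuming TensorFlow 2.x with eager execution and using x in place of input:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Indices may appear in any order and may repeat.
reordered = tf.gather(x, [2, 0, 0])
print(reordered.numpy().tolist())  # [[7, 8, 9], [1, 2, 3], [1, 2, 3]]

# A 2-D index tensor of shape (2, 2) gives a result of shape (2, 2) + (3,).
paired = tf.gather(x, [[0, 1], [1, 2]])
print(paired.shape)  # (2, 2, 3)
print(paired.numpy().tolist())
```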
Note that, in the TensorFlow version this answer targets, tf.gather() only selects whole slices along the 0th dimension (whole rows, in the case of a matrix), so you may need to tf.reshape() or tf.transpose() your input to reach the elements you want.
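The transpose and reshape workarounds can be sketched as follows, assuming TensorFlow 2.x with eager execution (x stands in for input; cols and num_cols are illustrative names). Selecting columns means transposing, gathering rows, and transposing back. Selecting individual elements means flattening to 1-D, where the element at (row, col) of a matrix with n columns sits at flat index row * n + col in row-major order. The last lines use the axis argument that later TensorFlow releases added to tf.gather(), which makes the transpose unnecessary.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
cols = [0, 2]

# Columns via transpose: columns become rows, gather them, transpose back.
via_transpose = tf.transpose(tf.gather(tf.transpose(x), cols))
print(via_transpose.numpy().tolist())  # [[1, 3], [4, 6], [7, 9]]

# Individual elements via reshape: flatten to 1-D (row-major), then gather
# the flat indices row * num_cols + col for elements (0, 1) and (2, 2).
num_cols = 3
flat = tf.reshape(x, [-1])
picked = tf.gather(flat, [0 * num_cols + 1, 2 * num_cols + 2])
print(picked.numpy().tolist())  # [2, 9]

# With the later `axis` argument, columns can be gathered directly.
via_axis = tf.gather(x, cols, axis=1)
print(via_axis.numpy().tolist())  # [[1, 3], [4, 6], [7, 9]]
```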