Tensorflow python : Accessing individual elements in a tensor

日久生厌 2020-12-13 17:09

This question is with respect to accessing individual elements in a tensor, say [[1,2,3]]. I need to access the inner element [1,2,3] (This can be performed using .eval() or

4 Answers
  •  轮回少年
    2020-12-13 17:46

    There are two main ways to access subsets of the elements in a tensor, either of which should work for your example.

    1. Use the indexing operator (based on tf.slice()) to extract a contiguous slice from the tensor.

      import tensorflow as tf  # TensorFlow 1.x graph-mode API (uses tf.Session)
      
      input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      sess = tf.Session()
      
      output = input[0, :]
      print(sess.run(output))  # ==> [1 2 3]
      

      The indexing operator supports many of the same slice specifications as NumPy does.

    2. Use the tf.gather() op to select a non-contiguous slice from the tensor.

      input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      
      output = tf.gather(input, 0)
      print(sess.run(output))  # ==> [1 2 3]
      
      output = tf.gather(input, [0, 2])
      print(sess.run(output))  # ==> [[1 2 3] [7 8 9]]
      

      Note that tf.gather() only allows you to select whole slices in the 0th dimension (whole rows, in the case of a matrix), so you may need to tf.reshape() or tf.transpose() your input to obtain the appropriate elements.
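For readers on TensorFlow 2.x, where eager execution replaces sess.run(), a minimal sketch of the same two approaches might look like the following. It assumes TensorFlow 2.x is installed; the variable names are illustrative only. It also uses two things the answer above does not mention: the axis argument of tf.gather(), which newer releases accept so you can gather columns directly, and tf.gather_nd(), which picks individual elements by their full (row, column) index.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution is the default)

x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 1. Indexing operator: a contiguous slice (the first row).
first_row = x[0, :]
print(first_row.numpy())  # [1 2 3]

# A single scalar element can be read the same way.
middle = x[1, 1]
print(middle.numpy())  # 5

# 2. tf.gather along the 0th dimension: whole rows, possibly non-contiguous.
rows = tf.gather(x, [0, 2])
print(rows.numpy())

# Selecting columns 0 and 2 by transposing, gathering rows, then transposing back.
cols_via_transpose = tf.transpose(tf.gather(tf.transpose(x), [0, 2]))

# Newer releases let tf.gather select along another axis directly.
cols_via_axis = tf.gather(x, [0, 2], axis=1)

# tf.gather_nd picks individual elements by full (row, col) index: x[0, 1] and x[2, 2].
elements = tf.gather_nd(x, [[0, 1], [2, 2]])
print(elements.numpy())  # [2 9]
```

The transpose route and the axis=1 route should produce the same columns; the axis form simply avoids the two extra transposes.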
