TensorFlow Python: Accessing individual elements in a tensor

日久生厌 2020-12-13 17:09

This question is with respect to accessing individual elements in a tensor, say [[1,2,3]]. I need to access the inner element [1,2,3] (This can be performed using .eval() or

4 answers
  • 2020-12-13 17:46

    There are two main ways to access subsets of the elements in a tensor, either of which should work for your example.

    1. Use the indexing operator (based on tf.slice()) to extract a contiguous slice from the tensor.

      import tensorflow as tf

      sess = tf.Session()  # TensorFlow 1.x graph-mode session
      input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

      output = input[0, :]
      print(sess.run(output))  # ==> [1 2 3]
      

      The indexing operator supports many of the same slice specifications as NumPy does.

    2. Use the tf.gather() op to select a non-contiguous slice from the tensor.

      import tensorflow as tf

      sess = tf.Session()  # TensorFlow 1.x graph-mode session
      input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

      output = tf.gather(input, 0)
      print(sess.run(output))  # ==> [1 2 3]

      output = tf.gather(input, [0, 2])
      print(sess.run(output))  # ==> [[1 2 3] [7 8 9]]
      

      Note that tf.gather() selects whole slices along the 0th dimension by default (whole rows in the example of a matrix). To pick elements along another dimension, pass the axis argument if your TensorFlow version supports it, or tf.reshape() or tf.transpose() your input first to obtain the appropriate elements.
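
    For readers on TensorFlow 2.x, the two approaches above can be sketched without a session. This is a minimal illustration that assumes eager execution is enabled (the 2.x default) and that tf.gather() accepts an axis argument in the installed version; the variable names are made up for the example.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x with eager execution (the default).
x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 1. Indexing operator: a contiguous slice (the first row).
first_row = x[0, :]

# 2. tf.gather(): a non-contiguous selection of whole rows.
rows_0_and_2 = tf.gather(x, [0, 2])

# Same op along axis 1: a non-contiguous selection of whole columns.
cols_0_and_2 = tf.gather(x, [0, 2], axis=1)

print(first_row.numpy())
print(rows_0_and_2.numpy())
print(cols_0_and_2.numpy())
```

    In eager mode each result is a concrete tensor, so .numpy() can be called on it directly instead of going through sess.run().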

  • 2020-12-13 17:53

    I suspect it's the rest of the computation that takes time, rather than accessing one element.

    Also, the result might need to be copied out of whatever memory it is stored in: if it lives on the graphics card, it has to be copied back to RAM before you can access your element. If that is the case, you may be able to avoid the full copy by adding a TensorFlow operation that takes the first element and returning only that.
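
    A rough sketch of that idea, assuming TensorFlow 2.x with the tf.compat.v1 graph API available. The tensor named big is a hypothetical stand-in for whatever large result the real computation produces; the point is only that the slicing op is part of the graph, so the fetch returns a single scalar rather than the whole array.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, using the tf.compat.v1 session API in graph mode.
with tf.Graph().as_default():
    # Hypothetical stand-in for a large intermediate result.
    big = tf.random.uniform([1000, 1000])

    # Add the element selection to the graph itself.
    first = big[0, 0]

    with tf.compat.v1.Session() as sess:
        # Fetch only the selected scalar, not the full 1000x1000 tensor.
        value = sess.run(first)

print(value)
```

    Whether this saves noticeable time depends on where the tensor lives and how large it is, so it is worth measuring before relying on it.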

  • 2020-12-13 18:02

    You simply can't get the value of the 0th element of [[1,2,3]] without run()-ning or eval()-ing an operation that fetches it. Before you run or eval, you only have a description of how to get this inner element, because TF uses symbolic graphs. So even if you use tf.gather/tf.slice, you still have to obtain the values of these operations via eval/run. See @mrry's answer.
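
    To illustrate the difference between the description and the value, here is a small sketch. It assumes TensorFlow 2.x and builds the graph explicitly through tf.Graph() and tf.compat.v1.Session() to reproduce the graph-mode behaviour described above.

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, with graph mode entered explicitly.
with tf.Graph().as_default():
    t = tf.constant([[1, 2, 3]])
    inner = t[0]  # symbolic op: describes how to get the inner row

    # Printing the op shows a Tensor description (name, shape, dtype),
    # not the values themselves.
    print(inner)

    with tf.compat.v1.Session() as sess:
        # Only running the op produces concrete values.
        result = sess.run(inner)

print(result)
```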

  • 2020-12-13 18:07

    I hope I understood your question well. You can access elements in a tensor in TensorFlow 2 via .numpy().

    import tensorflow as tf
    t = tf.constant([[1,2,3]])
    
    print(t.numpy()[0][1])  # prints 2
    

    >>> 2
    