TensorFlow: using a tensor to index another tensor

挽巷 2020-12-17 09:21

I have a basic question about how to do indexing in TensorFlow.

In NumPy:

import numpy as np
x = np.asarray([1,2,3,3,2,5,6,7,1,3])
e = np.asarray([0,1,0,1,1,1,0,1])
#nu         


        
1 Answer
  • 2020-12-17 09:41

    Fortunately, the exact case you're asking about is supported in TensorFlow by tf.gather():

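    import tensorflow as tf  # TensorFlow 1.x API: tf.Session is not available in 2.x
    # x_t and e_t are assumed to hold the question's x and e as constant tensors
    x_t = tf.constant([1, 2, 3, 3, 2, 5, 6, 7, 1, 3])
    e_t = tf.constant([0, 1, 0, 1, 1, 1, 0, 1])
    result = x_t * tf.gather(e_t, x_t)  # multiply x_t[i] by e_t[x_t[i]]

    with tf.Session() as sess:
        print(sess.run(result))  # ==> [1 0 3 3 0 5 0 7 1 3]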
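For readers checking this against NumPy, here is a NumPy-only sketch (no TensorFlow needed) of what the TensorFlow line computes, assuming 1-D inputs as in the question: tf.gather(e_t, x_t) selects e_t[x_t[i]] for each position i, which is the same lookup NumPy performs with integer-array indexing e[x] or with np.take(e, x, axis=0). The variable name `gathered` is illustrative only.

```python
import numpy as np

# The question's arrays (values copied from the question).
x = np.asarray([1, 2, 3, 3, 2, 5, 6, 7, 1, 3])
e = np.asarray([0, 1, 0, 1, 1, 1, 0, 1])

# For 1-D inputs, gathering e at the positions in x picks e[x[i]] for every i.
# np.take(..., axis=0) and integer-array indexing e[x] give the same result here.
gathered = np.take(e, x, axis=0)
assert np.array_equal(gathered, e[x])

# Keep x[i] where e[x[i]] == 1, zero it where e[x[i]] == 0.
result = x * gathered
print(result)  # [1 0 3 3 0 5 0 7 1 3]
```

The printed values match the output shown for the TensorFlow version, which is a quick way to confirm the gather is indexing the right tensor.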
    

    The tf.gather() op is less powerful than NumPy's advanced indexing: it only supports extracting full slices of a tensor along its 0th dimension. Support for more general indexing has been requested and is being tracked in this GitHub issue.
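To make "full slices along the 0th dimension" concrete, the NumPy-only sketch below contrasts the two kinds of lookup on a small 2-D array. np.take(..., axis=0) stands in for the row-wise gather described above (one whole row per index), while the last lookup uses advanced indexing to pick individual elements, the case the answer says tf.gather could not express. The array `params` and its values are hypothetical, chosen only for illustration. (Current TensorFlow releases also give tf.gather an axis argument and provide tf.gather_nd for per-element lookups; check the documentation for the version you use.)

```python
import numpy as np

# Hypothetical 3x4 array:
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]
params = np.arange(12).reshape(3, 4)

# Whole slices along dimension 0: one full row per index (row 2, then row 0).
rows = np.take(params, [2, 0], axis=0)
# rows == [[8, 9, 10, 11], [0, 1, 2, 3]]

# Advanced indexing: individual elements at (row 2, col 1) and (row 0, col 3).
elems = params[[2, 0], [1, 3]]
print(elems)  # [9 3]
```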
