TensorFlow getting elements of every row for specific columns

再見小時候 2021-02-08 02:57

If A is a TensorFlow variable like so

A = tf.Variable([[1, 2], [3, 4]])

and index is another variable



        
5 answers
  •  不要未来只要你来
    2021-02-08 03:59

    I couldn't get the accepted answer to work in TensorFlow 2 when I used it inside a loss function; GradientTape seemed to have trouble with it. My solution is a modified version of the accepted answer:

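    import tensorflow as tf

    # Example inputs (hypothetical): a 2-D tensor and one int32 column index per row
    arr = tf.constant([[1, 2], [3, 4]])
    ind = tf.constant([1, 0])

    def get_rows(arr):
      # Runs eagerly inside tf.py_function, so arr.shape holds concrete values
      N, _ = arr.shape
      return N

    num_rows = tf.py_function(get_rows, [arr], [tf.int32])[0]
    rng = tf.range(0, num_rows)
    ind = tf.stack([rng, ind], axis=1)  # one [row, column] pair per row
    tf.gather_nd(arr, ind)  # selects arr[i, ind[i]] for every row i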
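
A possible alternative, offered as a sketch rather than the answerer's code: reading the row count with `tf.shape(arr)[0]` keeps the whole computation in TensorFlow ops, so `tf.py_function` is not needed. The helper names `select_per_row` and `select_per_row_batch_dims` are hypothetical, the `tf.gather(..., batch_dims=1)` variant is a swapped-in technique not mentioned in the original answer, and the example assumes `index` holds one integer column index per row. The values of `A` are floats because `GradientTape` does not compute gradients with respect to integer variables.

```python
import tensorflow as tf


def select_per_row(arr, cols):
    """Return arr[i, cols[i]] for every row i (hypothetical helper name)."""
    # tf.shape returns the row count as a tensor, so no py_function is needed
    # and the selection stays inside ordinary, differentiable TensorFlow ops.
    num_rows = tf.shape(arr)[0]
    rows = tf.range(num_rows)
    # Build one [row, column] index pair per row: [[0, c0], [1, c1], ...]
    pairs = tf.stack([rows, tf.cast(cols, rows.dtype)], axis=1)
    return tf.gather_nd(arr, pairs)


def select_per_row_batch_dims(arr, cols):
    """Alternative (hypothetical name): tf.gather with batch_dims=1 picks one column per row."""
    return tf.gather(arr, cols, batch_dims=1)


# Same layout as A in the question, but float so gradients can flow
A = tf.Variable([[1.0, 2.0], [3.0, 4.0]])
index = tf.constant([1, 0])  # hypothetical column index for each row

with tf.GradientTape() as tape:
    picked = select_per_row(A, index)
    loss = tf.reduce_sum(picked)

grad = tape.gradient(loss, A)
print(picked.numpy())
print(select_per_row_batch_dims(A, index).numpy())
print(tf.convert_to_tensor(grad).numpy())
```

For these inputs both helpers select `A[0, 1]` and `A[1, 0]`, giving `[2. 3.]`, and the gradient of their sum with respect to `A` is 1 at the two selected positions and 0 elsewhere.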
    
