In TensorFlow, how to use tf.gather() for the last dimension?

余生分开走 2021-01-04 02:54

I am trying to gather slices of a tensor along the last dimension, for a partial connection between layers. Because the output tensor's shape is [batch_size, h, w,

8 answers
  •  误落风尘
    2021-01-04 03:21

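    A Tensor here has no shape attribute; use its get_shape() method instead. The code below runs on Python 2.7. It picks one element per row by flattening the tensor and offsetting each column index by the start of its row.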

    import tensorflow as tf

    x = tf.constant([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])
    idx = tf.constant([1, 0, 2])  # column to pick from each row
    # Offset each column index by its row's start position in the flattened tensor.
    idx_flattened = tf.range(0, x.get_shape()[0]) * x.get_shape()[1] + idx
    y = tf.gather(tf.reshape(x, [-1]),  # flatten input
                  idx_flattened)  # use flattened indices

    with tf.Session():
        print(y.eval())  # [2 4 9]
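
    The index arithmetic above can be checked without a session. The following is a minimal NumPy sketch, not the TensorFlow code itself: the helper name gather_per_row, the array acts and the channel subset are made up for illustration. The second half assumes the question wants the same channels at every position of a [batch_size, h, w, channels] tensor, which the truncated question does not confirm.

```python
import numpy as np

def gather_per_row(x, idx):
    """Return x[i, idx[i]] for each row i of a 2-D array via flattened indices."""
    rows, cols = x.shape
    flat_idx = np.arange(rows) * cols + idx  # row offset plus column index
    return x.reshape(-1)[flat_idx]

x = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])
idx = np.array([1, 0, 2])
y = gather_per_row(x, idx)
print(y)  # [2 4 9]

# Hypothetical 4-D activations shaped [batch_size, h, w, channels].
acts = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)
channels = np.array([0, 2])  # assumed channel subset, for illustration only
sub = np.take(acts, channels, axis=-1)  # same channels at every position
print(sub.shape)  # (2, 4, 4, 2)
```

    Later TensorFlow releases also accept an axis argument in tf.gather, so the flattening step may be unnecessary there; check this against the version you have installed.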
    
