How to select rows from a 3-D Tensor in TensorFlow?

执念已碎 2020-11-30 13:18

I have a tensor logits with the dimensions [batch_size, num_rows, num_coordinates] (i.e. each logit in the batch is a matrix). In my case batch siz…

2 Answers
  • 2020-11-30 13:39

    This is possible in TensorFlow, but slightly inconvenient, because tf.gather() currently only works with one-dimensional indices, and only selects slices from the 0th dimension of a tensor. However, it is still possible to solve your problem efficiently, by transforming the arguments so that they can be passed to tf.gather():

    import tensorflow as tf
    
    # Example [2 x 4 x 4] tensor; substitute your own logits.
    logits = tf.reshape(tf.range(32, dtype=tf.float32), [2, 4, 4])
    indices = tf.constant([[0, 1], [1, 3]])  # rows to select from each batch element
    
    # Use tf.shape() to make this work with dynamic shapes.
    batch_size = tf.shape(logits)[0]
    rows_per_batch = tf.shape(logits)[1]
    indices_per_batch = tf.shape(indices)[1]
    
    # Offset to add to each row in indices. We use `tf.expand_dims()` to make
    # this broadcast appropriately.
    offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)
    
    # Convert indices and logits into appropriate form for `tf.gather()`.
    # Since TF 1.0 the signature is tf.concat(values, axis), and tf.pack is tf.stack.
    flattened_indices = tf.reshape(indices + offset, [-1])
    flattened_logits = tf.reshape(logits, tf.concat([[-1], tf.shape(logits)[2:]], 0))
    
    selected_rows = tf.gather(flattened_logits, flattened_indices)
    
    result = tf.reshape(selected_rows,
                        tf.concat([tf.stack([batch_size, indices_per_batch]),
                                   tf.shape(logits)[2:]], 0))
    

    Note that because this uses tf.reshape() rather than tf.transpose(), it doesn't need to copy or rearrange the (potentially large) data in the logits tensor, so it should be fairly efficient.
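    To make the index arithmetic concrete, here is the same flatten-and-offset trick written with NumPy, used here only because its intermediate values are easy to print (the example values are made up). Row r of batch element b ends up at flat row b * rows_per_batch + r, so with 4 rows per batch the indices [[0, 1], [1, 3]] become [0, 1, 5, 7]. The last lines cross-check the result against NumPy's own advanced indexing.

```python
import numpy as np

# Made-up example: 2 batch elements, 4 rows, 4 coordinates.
logits = np.arange(32, dtype=np.float32).reshape(2, 4, 4)
indices = np.array([[0, 1], [1, 3]])  # rows to keep for each batch element

batch_size, rows_per_batch = logits.shape[:2]

# Row r of batch element b lives at flat row b * rows_per_batch + r.
offset = (np.arange(batch_size) * rows_per_batch)[:, None]  # shape [2, 1]
flattened_indices = (indices + offset).reshape(-1)          # [0, 1, 5, 7]

flattened_logits = logits.reshape(-1, logits.shape[2])      # shape [8, 4]
selected_rows = flattened_logits[flattened_indices]         # shape [4, 4]
result = selected_rows.reshape(batch_size, indices.shape[1], logits.shape[2])

# Cross-check: advanced indexing pairs each batch index with its row indices.
expected = logits[np.arange(batch_size)[:, None], indices]
print(flattened_indices)                 # [0 1 5 7]
print(np.array_equal(result, expected))  # True
```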

  • 2020-11-30 13:50

    mrry's answer is great, but I think the function tf.gather_nd solves the problem in far fewer lines of code (this function was probably not yet available when mrry wrote that answer):

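    import tensorflow as tf
    
    logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                           [11.0, 10.0, 10.0, 30.0],
                           [12.0, 10.0, 10.0, 20.0],
                           [13.0, 10.0, 10.0, 20.0]],
                          [[14.0, 11.0, 21.0, 31.0],
                           [15.0, 11.0, 11.0, 21.0],
                           [16.0, 11.0, 11.0, 21.0],
                           [17.0, 11.0, 11.0, 21.0]]])
    
    # Each innermost pair is a (batch, row) coordinate into logits.
    indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])
    
    result = tf.gather_nd(logits, indices)
    # TF 2.x runs eagerly, so the result can be printed directly.
    # (In TF 1.x, use: with tf.Session() as sess: print(sess.run(result)))
    print(result.numpy())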
    

    This will print

    [[[ 10.  10.  20.  20.]
      [ 11.  10.  10.  30.]]
    
     [[ 15.  11.  11.  21.]
      [ 17.  11.  11.  21.]]]
    

    tf.gather_nd should be available as of v0.10. See this GitHub issue for more discussion.
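    The indices above are full (batch, row) coordinate pairs. If you start from per-batch row indices such as [[0, 1], [1, 3]] (the layout used in the first answer), a minimal sketch, assuming TensorFlow 2.x eager execution, is to pair each row index with its batch index before calling tf.gather_nd. It also shows tf.gather with batch_dims=1, a newer alternative that is not part of either original answer but should produce the same result.

```python
import tensorflow as tf

# Same values as the example above: shape [2, 4, 4].
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                       [11.0, 10.0, 10.0, 30.0],
                       [12.0, 10.0, 10.0, 20.0],
                       [13.0, 10.0, 10.0, 20.0]],
                      [[14.0, 11.0, 21.0, 31.0],
                       [15.0, 11.0, 11.0, 21.0],
                       [16.0, 11.0, 11.0, 21.0],
                       [17.0, 11.0, 11.0, 21.0]]])

# Per-batch row indices, shape [batch_size, k].
row_indices = tf.constant([[0, 1], [1, 3]])

# Batch index for every entry of row_indices: [[0, 0], [1, 1]].
batch_size = tf.shape(row_indices)[0]
batch_ids = tf.broadcast_to(tf.expand_dims(tf.range(batch_size), 1),
                            tf.shape(row_indices))

# Stack into [batch_size, k, 2] coordinates: [[[0, 0], [0, 1]], [[1, 1], [1, 3]]].
nd_indices = tf.stack([batch_ids, row_indices], axis=-1)
result_nd = tf.gather_nd(logits, nd_indices)

# Alternative (TF 2.x): gather along axis 1 separately for each batch element.
result_batched = tf.gather(logits, row_indices, batch_dims=1)

print(result_nd.numpy())
```

    Building the coordinates with tf.shape() rather than Python constants keeps the sketch working when the batch size is only known at run time.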
