TensorFlow equivalent of torch.gather

花落未央 2021-01-16 05:44

I have a tensor of shape (16, 4096, 3). I have another tensor of indices of shape (16, 32768, 3). I am trying to collect the values along dim

3 answers
  •  甜味超标
    2021-01-16 06:08

    For gathering along the last axis, the general N-D case can be reduced to 2D with a reshape, after which @LiShaoyuan's 2D code above applies:

            import tensorflow as tf

            # Last-axis gathering only: reshape to 2D, gather, then reshape back
            # (Torch-style N-D gathering).
            def torch_gather(param, id_tensor):

                # 2D gather (torch equivalent) from @LiShaoyuan above:
                # picks target[i, id_tensor[i, 0]] for every row i.
                def gather2d(target, id_tensor):
                    idx = tf.stack([tf.range(tf.shape(id_tensor)[0]), id_tensor[:, 0]], axis=-1)
                    result = tf.gather_nd(target, idx)
                    return tf.expand_dims(result, axis=-1)

                target = tf.reshape(param, (-1, param.shape[-1]))  # flatten leading axes to 2D
                target_shape = id_tensor.shape

                # Row i of the flattened index is paired with row i of `target`, so this
                # expects one index per row, e.g. id_tensor of shape param.shape[:-1] + (1,).
                id_tensor = tf.reshape(id_tensor, (-1, 1))  # 2D index as well
                result = gather2d(target, id_tensor)
                return tf.reshape(result, target_shape)
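
To make the reshape trick concrete, here is a minimal pure-Python sketch of last-axis gathering on nested lists. It is only a reference for the intended result, not the answer's TensorFlow code: the helper names (`flatten_rows`, `unflatten_rows`, `leading_shape`, `gather_last_axis`) are hypothetical, and it assumes rectangular, non-empty inputs whose leading axes match between `param` and `index`.

```python
def flatten_rows(nested):
    """Collapse all leading axes of a rectangular nested list into a list of last-axis rows."""
    if not isinstance(nested[0], list):
        return [nested]
    rows = []
    for sub in nested:
        rows.extend(flatten_rows(sub))
    return rows


def leading_shape(nested):
    """Return the sizes of every axis except the last one."""
    shape = []
    while isinstance(nested[0], list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


def unflatten_rows(rows, shape):
    """Inverse of flatten_rows: regroup rows into the given leading shape."""
    if not shape:
        return rows[0]
    step = len(rows) // shape[0]
    return [unflatten_rows(rows[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


def gather_last_axis(param, index):
    """Reference result: out[..., k] = param[..., index[..., k]]."""
    param_rows = flatten_rows(param)   # 2D view of param
    index_rows = flatten_rows(index)   # 2D view of index
    # Assumption of this sketch: leading axes of param and index agree.
    assert len(param_rows) == len(index_rows)
    out_rows = [[row[j] for j in idx] for row, idx in zip(param_rows, index_rows)]
    return unflatten_rows(out_rows, leading_shape(index))


param = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]   # shape (2, 2, 3)
index = [[[2, 0], [1, 1]], [[0, 0], [2, 1]]]                  # shape (2, 2, 2)
result = gather_last_axis(param, index)
# result == [[[3, 1], [5, 5]], [[7, 7], [12, 11]]]
```

On small inputs this can serve as a check for `torch_gather` above. TensorFlow's own `tf.gather` also takes a `batch_dims` argument that may cover the same case without the reshape; it is worth comparing its output against this reference before relying on it.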
    
