TensorFlow, batchwise indexing (first dimension) and sorting

和自甴很熟 submitted on 2020-01-02 06:10:14

Question


I've got a params tensor with shape (?,368,5), as well as a query tensor with shape (?,368). The query tensor stores indices for sorting the first tensor.

The required output has shape (?,368,5). Since I need it for a loss function in a neural network, the operations used must stay differentiable. At runtime, the size of the first axis (?) equals the batch size.

So far I have experimented with tf.gather and tf.gather_nd. However, tf.gather(params, query) results in a tensor with shape (?,368,368,5).
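
The extra dimension comes from the shape rule of tf.gather: with the default axis=0, the output shape is query.shape + params.shape[1:], i.e. (?,368) + (368,5). Each entry of query is treated as an index into the batch axis, so whole batch elements are gathered instead of rows. A minimal sketch of the same effect, assuming NumPy's np.take (which follows the same shape rule) as a stand-in for tf.gather and small hypothetical sizes; np.argsort stands in for the question's top_k call:

```python
import numpy as np

# Hypothetical stand-in sizes: batch 4, 3 rows, 5 features. The batch is
# larger than the row count only so that every row index is also a valid
# batch index and the mis-indexing runs instead of raising an error.
b, n, f = 4, 3, 5
rng = np.random.default_rng(0)
params = rng.random((b, n, f))
query = np.argsort(-params[:, :, 0], axis=1)  # (b, n) row permutations

# np.take with axis=0 follows the same shape rule as tf.gather(params, query):
#   out.shape == query.shape + params.shape[1:]
wrong = np.take(params, query, axis=0)
print(wrong.shape)  # (4, 3, 3, 5), the analogue of (?, 368, 368, 5)

# Each entry of query selected a whole batch element, not a row of it.
i, j = 1, 2
print(np.array_equal(wrong[i, j], params[query[i, j]]))  # True
```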

The query tensor is obtained with:

query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices

Overall, I'm trying to sort the params tensor by the first element along the third axis (for a kind of chamfer distance). Finally, I should mention that I'm working with the Keras framework.
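
To pin down the target behaviour, here is a NumPy reference that can be used to check a TensorFlow implementation against. It is not differentiable and is not meant for the loss itself; the function name sort_rows_by_first_feature is hypothetical. np.argsort on the negated first feature gives the descending order that tf.nn.top_k(..., sorted=True) produces (the stable sort keeps ties in index order, intended to mirror top_k's lower-index-first tie rule), and np.take_along_axis does the per-batch gather that yields shape (?,368,5):

```python
import numpy as np


def sort_rows_by_first_feature(params):
    """Reorder the rows of each params[b] by params[b, :, 0], descending.

    NumPy reference only: it is not part of any graph and not differentiable.
    """
    # Descending order per batch element, shape (batch, rows).
    query = np.argsort(-params[:, :, 0], axis=1, kind="stable")
    # Per-batch gather along axis 1; the trailing axis of size 1 broadcasts
    # over the feature axis, so every feature of a row moves together.
    return np.take_along_axis(params, query[:, :, np.newaxis], axis=1)


rng = np.random.default_rng(100)
x = rng.random((10, 368, 5))
y = sort_rows_by_first_feature(x)
print(y.shape)  # (10, 368, 5)
print(np.all(np.diff(y[:, :, 0], axis=1) <= 0))  # True
```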


Answer 1:


You need to add the indices of the first dimension to query in order to use it with tf.gather_nd. Here is a way to do it:

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session)
import numpy as np

np.random.seed(100)

with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.placeholder(tf.float32, [None, 368, 5])
    query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
    n = tf.shape(params)[0]
    # Make tensor of indices for the first dimension
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
    # Stack indices
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check that the first feature is in descending order within each batch element
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True
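
The code above relies on tf.placeholder and tf.Session, which are TensorFlow 1.x APIs and not available by default in TensorFlow 2.x. As a sketch assuming TensorFlow 2.x with eager execution, the same index construction can be ported directly, and tf.gather with batch_dims=1 is shown as an alternative that pairs each batch element with its own row of query without building the batch indices by hand. The helper names sort_with_gather_nd and sort_with_batch_gather are hypothetical. The gradient check relates to the differentiability requirement in the question: gradients reach params through the gather, while the integer indices from top_k carry no gradient.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def sort_with_gather_nd(params):
    """Eager port of the answer: stack batch indices with the top_k indices."""
    n = tf.shape(params)[0]
    k = tf.shape(params)[1]
    query = tf.math.top_k(params[:, :, 0], k=k, sorted=True).indices  # (n, k)
    # Batch index of every entry, e.g. [[0, 0, ...], [1, 1, ...], ...]
    ii = tf.tile(tf.range(n)[:, tf.newaxis], tf.stack([1, k]))
    idx = tf.stack([ii, query], axis=-1)  # (n, k, 2) pairs of (batch, row)
    return tf.gather_nd(params, idx)


def sort_with_batch_gather(params):
    """Same reordering using tf.gather's batch_dims argument."""
    k = tf.shape(params)[1]
    query = tf.math.top_k(params[:, :, 0], k=k, sorted=True).indices
    # batch_dims=1: query[b] indexes axis 1 of params[b] only -> (n, k, features)
    return tf.gather(params, query, axis=1, batch_dims=1)


x = tf.constant(np.random.default_rng(100).random((10, 368, 5)).astype(np.float32))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = sort_with_batch_gather(x)
    loss = tf.reduce_sum(y)
grad = tf.convert_to_tensor(tape.gradient(loss, x))

print(y.shape)  # (10, 368, 5)
print(bool(tf.reduce_all(y == sort_with_gather_nd(x))))  # True
print(bool(tf.reduce_all(y[:, 1:, 0] <= y[:, :-1, 0])))  # True
# The reordering is a permutation, so every entry of x receives gradient 1.
print(bool(tf.reduce_all(grad == 1.0)))  # True
```

Both versions compute the same reordering; batch_dims simply moves the bookkeeping of the batch indices into tf.gather itself.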


Source: https://stackoverflow.com/questions/50605059/tensorflow-batchwise-indexing-first-dimension-and-sorting
