Question
I have a params tensor with shape (?,368,5) and a query tensor with shape (?,368). For each batch element, query stores the indices that sort the 368 rows of params along the second axis.
The required output has shape (?,368,5). Since I need it for a loss function in a neural network, the operations used must be differentiable. Also, at runtime the size of the first axis (?) corresponds to the batch size.
So far I have experimented with tf.gather and tf.gather_nd; however, tf.gather(params,query) results in a tensor with shape (?,368,368,5).
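The extra axis appears because tf.gather gathers along axis 0 by default, which here is the batch axis, so every index in query selects a whole (368, 5) example rather than a row. The output shape is query.shape + params.shape[1:]. The following NumPy analogue is a minimal sketch of that behaviour, assuming np.take with axis=0 mirrors tf.gather's default; the tiny sizes are hypothetical stand-ins for (?, 368, 5), with the batch chosen larger than the row count only so that every index is also a valid batch index.

```python
import numpy as np

# Hypothetical small sizes standing in for (?, 368, 5): batch b, rows n, features d.
# b > n only so that every value in query (in [0, n)) is also a valid batch index.
b, n, d = 4, 3, 2
rng = np.random.default_rng(0)
params = rng.random((b, n, d))
# Descending row order by feature 0, standing in for tf.nn.top_k(...).indices
query = np.argsort(-params[:, :, 0], axis=1)  # shape (b, n)

# Gathering along axis 0 (the batch axis), like tf.gather's default:
# result shape = query.shape + params.shape[1:] = (b, n) + (n, d)
gathered = np.take(params, query, axis=0)
print(gathered.shape)  # (4, 3, 3, 2)
# Each entry is a whole example, not a row: gathered[i, j] == params[query[i, j]]
print(np.array_equal(gathered[1, 2], params[query[1, 2]]))  # True
```

With the real sizes the values in query range up to 367 while the batch axis is usually much shorter, so besides having the wrong shape, the indices would point at nonexistent batch elements.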
The query tensor is obtained with:
query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
Overall, I am trying to sort the params tensor by the first element along the third axis (for a kind of chamfer distance). Finally, I should mention that I work with the Keras framework.
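For testing a TensorFlow implementation, it can help to have a plain NumPy reference of the intended result. The sketch below is non-differentiable and meant only for checking outputs; the helper name sort_rows_by_first_feature is hypothetical, and np.argsort on the negated first feature stands in for tf.nn.top_k(..., sorted=True), which gives the same descending order except possibly for ties.

```python
import numpy as np

def sort_rows_by_first_feature(params):
    """Hypothetical NumPy reference: reorder axis 1 of a (batch, n, d) array
    by params[:, :, 0] in descending order, keeping each row intact."""
    # Per-example descending row order, shape (batch, n)
    query = np.argsort(-params[:, :, 0], axis=1)
    # Trailing axis so one row index selects all d features of that row
    return np.take_along_axis(params, query[:, :, np.newaxis], axis=1), query

rng = np.random.default_rng(0)
params = rng.random((10, 368, 5))
sorted_params, query = sort_rows_by_first_feature(params)
print(sorted_params.shape)  # (10, 368, 5)
# Feature 0 is non-increasing along axis 1 in every example
print(np.all(np.diff(sorted_params[:, :, 0], axis=1) <= 0))  # True
# Rows move as whole units: output row j of example i is input row query[i, j]
print(np.array_equal(sorted_params[3, 7], params[3, query[3, 7]]))  # True
```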
Answer 1:
To use query with tf.gather_nd, you need to pair each of its entries with the index of its example along the first (batch) dimension, so that tf.gather_nd receives [batch_index, row_index] pairs. Here is a way to do it:
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session)
import numpy as np

np.random.seed(100)
with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.placeholder(tf.float32, [None, 368, 5])
    # Indices that sort each example by feature 0, descending; shape (batch, 368)
    query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
    # Dynamic batch size
    n = tf.shape(params)[0]
    # Make tensor of indices for the first dimension: ii[i, j] == i
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
    # Stack indices into [batch_index, row_index] pairs, shape (batch, 368, 2)
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor, shape (batch, 368, 5)
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check the order is correct (feature 0 non-increasing along axis 1)
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True
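On TensorFlow 2.x, where tf.placeholder and tf.Session are gone, the same reordering can likely be written with the batch_dims argument of tf.gather, which pairs query[i] with params[i] and so replaces the manual tile/stack of batch indices. The sketch below assumes a reasonably recent TensorFlow 2.x in eager mode; the function name sort_by_first_feature is hypothetical. It also uses tf.GradientTape to check the differentiability requirement: the integer indices from top_k carry no gradient, but the gather passes gradients back to params, and because the output is a permutation of the input rows, the gradient of sum(out**2) should equal 2 * params.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution, tf.gather with batch_dims)

def sort_by_first_feature(params):
    """Hypothetical helper: reorder axis 1 of a (batch, n, d) tensor by params[:, :, 0], descending."""
    n = tf.shape(params)[1]
    # Descending row order per example, shape (batch, n); these int32 indices carry no gradient
    query = tf.math.top_k(params[:, :, 0], k=n, sorted=True).indices
    # batch_dims=1 gathers rows of params[i] with query[i], giving shape (batch, n, d)
    return tf.gather(params, query, axis=1, batch_dims=1)

rng = np.random.default_rng(0)
x = tf.Variable(rng.random((10, 368, 5)).astype(np.float32))
with tf.GradientTape() as tape:
    out = sort_by_first_feature(x)
    loss = tf.reduce_sum(out ** 2)
# Densify in case the gradient comes back as IndexedSlices
grad = tf.convert_to_tensor(tape.gradient(loss, x))

print(tuple(out.shape))  # (10, 368, 5)
print(bool(np.all(np.diff(out.numpy()[:, :, 0], axis=1) <= 0)))  # True
print(np.allclose(grad.numpy(), 2 * x.numpy()))  # True
```

Inside a Keras loss the same function would be applied to the prediction or target tensors directly; the gradient check above only illustrates that the reordering itself does not block backpropagation.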
Source: https://stackoverflow.com/questions/50605059/tensorflow-batchwise-indexing-first-dimension-and-sorting