Proper usage of `tf.scatter_nd` in tensorflow-r1.2

Submitted by 魔方 西西 on 2021-02-16 19:08:09

Question


Given indices with shape [batch_size, sequence_len], updates with shape [batch_size, sequence_len, sampled_size], and to_shape with shape [batch_size, sequence_len, vocab_size], where vocab_size >> sampled_size, I'd like to use tf.scatter_nd to map the updates into a huge tensor of shape to_shape, such that to_shape[bs, indices[bs, sz]] = updates[bs, sz]. That is, I'd like to map the updates into to_shape row by row. Note that sequence_len and sampled_size are scalar tensors, while the other dimensions are fixed. I tried the following:

new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)

But I got an error:

ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]

Could you please tell me how to use scatter_nd properly? Thanks in advance!


Answer 1:


So assuming you have:

  • A tensor updates with shape [batch_size, sequence_len, sampled_size].
  • A tensor indices with shape [batch_size, sequence_len, sampled_size].

Then you do:

import tensorflow as tf

# Create updates and indices... (indices must be int32 so it can be stacked with the tf.range outputs below)

# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
                     tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)

tf.scatter_nd takes an indices tensor, an updates tensor and a shape. updates is the original tensor, and shape is just the desired output shape, so [batch_size, sequence_len, vocab_size]. indices is more complicated. Since the output has 3 dimensions (rank 3), each element of updates needs 3 indices to determine where in the output it is placed. So indices must have the same shape as updates plus an additional last dimension of size 3. Here we want the first two dimensions to stay the same as in updates, but we still have to spell out all 3 indices. So we use tf.meshgrid to generate the batch and sequence indices and tile them along the third dimension (the first and second index are the same for every element along the last dimension of updates). Finally, we stack these with the previously created mapping indices to get the full 3-dimensional indices.
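To make the meshgrid/tile/stack construction concrete, here is a pure-Python sketch (no TensorFlow, hypothetical function name, tiny made-up data) that builds the same (b, s, indices[b, s, k]) triples and applies them the way tf.scatter_nd does: positions that are never targeted stay zero, and updates that land on the same position are summed.

```python
def scatter_rows(indices, updates, vocab_size):
    """Hypothetical pure-Python model of Answer 1's tf.scatter_nd call.

    indices, updates: nested lists of shape [batch, seq, sampled]
    returns: nested list of shape [batch, seq, vocab_size] where
             out[b][s][indices[b][s][k]] += updates[b][s][k]
    """
    batch, seq = len(indices), len(indices[0])
    sampled = len(indices[0][0])
    # Equivalent of tf.stack([i1, i2, indices], axis=-1): the tiled meshgrid
    # supplies (b, s) for every k, and indices supplies the vocab position.
    idx = [[[(b, s, indices[b][s][k]) for k in range(sampled)]
            for s in range(seq)] for b in range(batch)]
    out = [[[0.0] * vocab_size for _ in range(seq)] for _ in range(batch)]
    for b in range(batch):
        for s in range(seq):
            for k in range(sampled):
                i, j, v = idx[b][s][k]
                # tf.scatter_nd accumulates (sums) duplicate indices.
                out[i][j][v] += updates[b][s][k]
    return out


# batch=1, seq=2, sampled=2, vocab_size=4; the second row hits index 2 twice.
indices = [[[0, 3], [2, 2]]]
updates = [[[1.0, 2.0], [3.0, 4.0]]]
print(scatter_rows(indices, updates, vocab_size=4))
# -> [[[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 7.0, 0.0]]]
```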




Answer 2:


I think you might be looking for this.

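import tensorflow as tf

def permute_batched_tensor(batched_x, batched_perm_ids):
    # batched_x: [B, N, D]; batched_perm_ids: [B, N] int32, a permutation of range(N) per batch element
    # Use the dynamic shape so this also works when N or D is unknown at graph-construction time
    shape = tf.shape(batched_x)
    b, n, d = shape[0], shape[1], shape[2]
    # Broadcast each target row index across the last dimension: [B, N, D]
    indices = tf.tile(tf.expand_dims(batched_perm_ids, 2), [1, 1, d])

    # Create additional indices for the batch and feature axes: [B, N, D] each
    i1, i2 = tf.meshgrid(tf.range(b), tf.range(d), indexing="ij")
    i1 = tf.tile(i1[:, tf.newaxis, :], [1, n, 1])
    i2 = tf.tile(i2[:, tf.newaxis, :], [1, n, 1])
    # Create final indices: output[b, perm[b, n], d] = batched_x[b, n, d]
    idx = tf.stack([i1, indices, i2], axis=-1)
    return tf.scatter_nd(idx, batched_x, shape)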
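This helper scatters along axis 1 rather than into a larger vocab axis: row n of each batch element is written to position batched_perm_ids[b, n], i.e. output[b, perm[b, n], :] = batched_x[b, n, :]. A pure-Python model of that rule (hypothetical name, tiny made-up data, no TensorFlow) shows the effect:

```python
def permute_rows(x, perm):
    """Pure-Python model of permute_batched_tensor.

    x: [batch][n][d] nested list; perm: [batch][n], assumed to be a
    permutation of range(n) for each batch element.
    Row n of x[b] is written to position perm[b][n] of out[b].
    """
    out = [[None] * len(rows) for rows in x]
    for b, rows in enumerate(x):
        for n, row in enumerate(rows):
            out[b][perm[b][n]] = list(row)
    return out


x = [[[1, 1], [2, 2], [3, 3]]]
perm = [[2, 0, 1]]
print(permute_rows(x, perm))
# -> [[[2, 2], [3, 3], [1, 1]]]
```

The sketch assumes perm is a true permutation. If two rows targeted the same position, tf.scatter_nd would sum them and leave untargeted rows at zero, whereas this model simply overwrites.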


Source: https://stackoverflow.com/questions/45162998/proper-usage-of-tf-scatter-nd-in-tensorflow-r1-2
