How to shift values in tensor

Submitted by ╄→гoц情女王★ on 2021-01-28 08:21:21

Question


I have tensor T of shape [batch_size, A] with values and tensor S of shape [batch_size] with shift parameters.

I would like to shift the values in T[b] by S[b] positions to the right: the last S[b] elements of T[b] should be dropped, and the new elements on the left should be set to 0.

So basically I want to do something like:

for i in range(batch_size):
  T[i] = zeros[:S[i]] + T[i, :A-S[i]]

Example:

For:
T = [[1, 2, 3], [4, 5, 6]]
S = [1, 2]

Return:
T' = [[0, 1, 2], [0, 0, 4]]

Is there some easy way to do it?
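For reference, here is the loop from the question written as runnable code. This is a minimal sketch that assumes NumPy arrays instead of tensors and non-negative shifts. The name `shift_right_reference` is hypothetical, and "+" in the pseudocode is read as concatenation. It reproduces the expected output in the example.

```python
import numpy as np


def shift_right_reference(T, S):
    """Literal version of the question's loop (hypothetical helper name)."""
    T = np.asarray(T)
    A = T.shape[1]
    out = np.empty_like(T)
    for i in range(T.shape[0]):
        # Clamp so that a shift larger than A yields an all-zero row.
        s = min(int(S[i]), A)
        # zeros[:S[i]] + T[i, :A - S[i]], with "+" meaning concatenation.
        out[i] = np.concatenate([np.zeros(s, dtype=T.dtype), T[i, :A - s]])
    return out


print(shift_right_reference([[1, 2, 3], [4, 5, 6]], [1, 2]).tolist())  # [[0, 1, 2], [0, 0, 4]]
```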


Answer 1:


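You can use tf.concat to build each shifted row and tf.stack to combine the rows. The loop below assumes batch_size is a Python int known when the graph is built: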

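import tensorflow as tf

# Source of zeros for the vacated positions; match T's dtype so tf.concat accepts it.
T_shift = tf.zeros((batch_size, A), dtype=T.dtype)
tmp = []

for i in range(batch_size):
    # S[i] zeros followed by the first A - S[i] values of row i.
    tmp.append(tf.concat([T_shift[i, :S[i]], T[i, :A - S[i]]], axis=0))
T_shift = tf.stack(tmp)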
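The per-row loop creates one concat op per example and needs a static batch size. A loop-free alternative works per output position j. It computes the source index j - S[b] and gathers from that index. Positions whose source index is negative are then zeroed. The sketch below shows the idea in NumPy, assuming non-negative shifts. The helper name `shift_right_vectorized` is hypothetical. In TensorFlow 2 the same steps would likely map to tf.range, tf.gather(..., batch_dims=1) and tf.where, but that translation is an untested assumption here.

```python
import numpy as np


def shift_right_vectorized(T, S):
    """Shift row b of T right by S[b], zero-filling on the left (hypothetical helper)."""
    T = np.asarray(T)
    S = np.asarray(S)
    A = T.shape[1]
    # Source column for every output position: j - S[b], shape [batch, A].
    src = np.arange(A)[None, :] - S[:, None]
    # Clip negative indices to 0 so the gather is valid; they are masked out below.
    gathered = np.take_along_axis(T, np.maximum(src, 0), axis=1)
    # Keep values whose source index exists, write 0 everywhere else.
    return np.where(src >= 0, gathered, np.zeros_like(gathered))


print(shift_right_vectorized([[1, 2, 3], [4, 5, 6]], [1, 2]).tolist())  # [[0, 1, 2], [0, 0, 4]]
```

Because every row is handled by the same gather-and-mask ops, this form does not depend on knowing batch_size in Python.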



Answer 2:


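If you are working in TensorFlow 2, you can use tf.roll for that purpose. Be aware that tf.roll wraps elements around instead of dropping them. It also takes one shift per axis rather than one per row. On its own, then, it does not produce the zero-filled, per-row shift asked for in the question. From the documentation: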

"The elements are shifted positively (towards larger indices) by the offset of shift along the dimension of axis. Negative shift values will shift elements in the opposite direction. Elements that roll passed the last position will wrap around to the first and vice versa. Multiple shifts along multiple axes may be specified."

tf.roll(input, shift, axis, name=None)

import tensorflow as tf

t = tf.constant([0, 1, 2, 3, 4])
tf.roll(t, shift=2, axis=0)  # ==> [3, 4, 0, 1, 2]

# shifting along multiple dimensions
t = tf.constant([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
tf.roll(t, shift=[1, -2], axis=[0, 1])  # ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]

# shifting along the same axis multiple times (same 't' as above)
tf.roll(t, shift=[2, -3], axis=[1, 1])  # ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
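You can turn a roll into the zero-filling shift from the question in two steps. First roll each row by its own S[b]. Then overwrite the first S[b] positions, which hold the wrapped-around elements, with zeros. The sketch below uses np.roll, which wraps the same way tf.roll does for a single axis. The helper name `shift_right_with_roll` is hypothetical. A TensorFlow version would need the same per-row loop plus masking, and it is not shown here.

```python
import numpy as np


def shift_right_with_roll(T, S):
    """Per-row roll, then zero the wrapped prefix (hypothetical helper name)."""
    T = np.asarray(T)
    out = np.empty_like(T)
    for i in range(T.shape[0]):
        s = int(S[i])
        # Wraps around: the last s elements of the row move to the front.
        rolled = np.roll(T[i], s)
        # Replace the wrapped-around elements with zeros.
        rolled[:s] = 0
        out[i] = rolled
    return out


print(np.roll([0, 1, 2, 3, 4], 2).tolist())  # [3, 4, 0, 1, 2]
print(shift_right_with_roll([[1, 2, 3], [4, 5, 6]], [1, 2]).tolist())  # [[0, 1, 2], [0, 0, 4]]
```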


Source: https://stackoverflow.com/questions/48215077/how-to-shift-values-in-tensor
