Question
I have a tensor T of shape [batch_size, A] with values and a tensor S of shape [batch_size] with shift parameters.
I would like to shift the values in T[b] by S[b] positions to the right: the last S[b] elements of T[b] should be dropped and the new elements should be set to 0.
So basically I want to do something like:
for i in range(batch_size):
    T[i] = zeros[:S[i]] + T[i, :A-S[i]]  # "+" means concatenation here
Example:
For:
T = [[1, 2, 3], [4, 5, 6]]
S = [1, 2]
Return:
T' = [[0, 1, 2], [0, 0, 4]]
Is there some easy way to do it?
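For reference, the requested behaviour can be pinned down with a minimal plain-Python sketch, assuming T is given as a list of equal-length lists and S as a list of non-negative integers (the name shift_right is hypothetical, and treating a shift of A or more as an all-zero row is an assumption). It reproduces the example above and can serve as an oracle when testing a TensorFlow version.

```python
def shift_right(T, S):
    """Shift each row T[b] right by S[b], dropping overflow and zero-filling.

    Plain-Python reference for the semantics in the question; T is a list of
    equal-length lists, S a list of non-negative ints.
    """
    result = []
    for row, s in zip(T, S):
        a = len(row)
        s = min(s, a)  # assumption: shifts >= A yield an all-zero row
        # s zeros on the left, then the first a - s original elements.
        result.append([0] * s + row[:a - s])
    return result


T = [[1, 2, 3], [4, 5, 6]]
S = [1, 2]
print(shift_right(T, S))  # [[0, 1, 2], [0, 0, 4]]
```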
Answer 1:
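You can build each shifted row with tf.concat in a Python loop and then combine the rows with tf.stack:
import tensorflow as tf

# Assumes batch_size and A are Python ints and S holds non-negative shifts.
zeros = tf.zeros((batch_size, A), T.dtype)
tmp = []
for i in range(batch_size):
    # S[i] zeros on the left, then the first A - S[i] elements of row i.
    tmp.append(tf.concat([zeros[i, :S[i]], T[i, :A - S[i]]], axis=0))
T_shift = tf.stack(tmp)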
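The loop above adds one set of ops per row. A loop-free alternative is sketched below, assuming TensorFlow 2.x, a 2-D T and non-negative integer shifts (shift_right_vectorized is a hypothetical name). For every output position it computes which input column to read from, gathers per row with tf.gather(..., batch_dims=1), and uses tf.where to zero the positions whose source index would be negative.

```python
import tensorflow as tf


def shift_right_vectorized(T, S):
    """Shift row T[b] right by S[b] positions, zero-filling on the left.

    Sketch assuming T has shape [batch_size, A] and S is an integer tensor of
    shape [batch_size] with non-negative values.
    """
    A = tf.shape(T)[1]
    S = tf.cast(S, tf.int32)
    # Output column j of row b reads input column j - S[b].
    src = tf.range(A)[tf.newaxis, :] - S[:, tf.newaxis]  # [batch_size, A]
    valid = src >= 0  # False where a zero should be inserted
    # Clamp negative indices so tf.gather stays in range; they are masked below.
    gathered = tf.gather(T, tf.maximum(src, 0), batch_dims=1)
    return tf.where(valid, gathered, tf.zeros_like(T))


T = tf.constant([[1, 2, 3], [4, 5, 6]])
S = tf.constant([1, 2])
print(shift_right_vectorized(T, S).numpy().tolist())  # [[0, 1, 2], [0, 0, 4]]
```

Clamping before the gather keeps every index valid, so the masking step alone decides which entries become zero.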
Answer 2:
If you are working in TensorFlow 2, you can use tf.roll for that purpose. From its documentation:
"The elements are shifted positively (towards larger indices) by the offset of shift along the dimension of axis. Negative shift values will shift elements in the opposite direction. Elements that roll past the last position will wrap around to the first and vice versa. Multiple shifts along multiple axes may be specified."
tf.roll(
input, shift, axis, name=None
)
import tensorflow as tf

t = tf.constant([0, 1, 2, 3, 4])
tf.roll(t, shift=2, axis=0)  # ==> [3, 4, 0, 1, 2]
# shifting along multiple dimensions
t = tf.constant([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
tf.roll(t, shift=[1, -2], axis=[0, 1])  # ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]
# shifting along the same axis multiple times
tf.roll(t, shift=[2, -3], axis=[1, 1])  # ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]
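tf.roll wraps elements around instead of dropping them, and it takes one shift per axis rather than one per row, so on its own it does not produce the zero-filled, per-row result from the question. A sketch that adapts it, assuming TensorFlow 2.3 or later (for the fn_output_signature argument of tf.map_fn) and non-negative shifts (roll_and_mask is a hypothetical name): roll each row with tf.map_fn, then zero the first S[b] positions of each row, which are exactly the wrapped-around elements, using tf.sequence_mask.

```python
import tensorflow as tf


def roll_and_mask(T, S):
    """Shift each row of T right by S[b] with tf.roll, then zero the wrapped part.

    Sketch assuming T has shape [batch_size, A] and S[b] >= 0.
    """
    S = tf.cast(S, tf.int32)
    # tf.roll takes a single shift per axis, so apply it row by row.
    rolled = tf.map_fn(
        lambda args: tf.roll(args[0], shift=args[1], axis=0),
        (T, S),
        fn_output_signature=T.dtype,
    )
    # True for the first S[b] positions of row b: the elements that wrapped around.
    wrapped = tf.sequence_mask(S, maxlen=tf.shape(T)[1])
    return tf.where(wrapped, tf.zeros_like(rolled), rolled)


T = tf.constant([[1, 2, 3], [4, 5, 6]])
S = tf.constant([1, 2])
print(roll_and_mask(T, S).numpy().tolist())  # [[0, 1, 2], [0, 0, 4]]
```

tf.map_fn handles the rows one at a time, so for large batches a fully vectorized formulation may be preferable.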
Source: https://stackoverflow.com/questions/48215077/how-to-shift-values-in-tensor