I am trying to mimic the operation done in PyTorch below:
vol = Variable(torch.FloatTensor(A, B*2, C, D, E).zero_()).cuda()
for i in range(C):
    if i > 0:
You need to construct the tensor "by hand". Assuming both input0 and input1 have shape (A, D, E, B), you can do something like this:
import numpy as np
import tensorflow as tf

# Make the indexing mask with TensorFlow
in_shape = tf.shape(input0)
in_dims = 4
idx = tf.meshgrid(*[tf.range(in_shape[i]) for i in range(in_dims)], indexing='ij')[2]
idx = tf.expand_dims(idx, axis=3)
r = tf.range(C)[tf.newaxis, tf.newaxis, tf.newaxis, :, tf.newaxis]
mask = idx >= r
# If all dimensions are known at graph construction time, you can instead
# make the mask with NumPy like this to save graph computation time
idx = np.meshgrid(*[np.arange(d) for d in (A, D, E, B)], indexing='ij')[2]
idx = np.expand_dims(idx, 3)
r = np.arange(C)[np.newaxis, np.newaxis, np.newaxis, :, np.newaxis]
mask = idx >= r
# Make the tensor
input0_tile = tf.tile(tf.expand_dims(input0, 3), (1, 1, 1, C, 1))
input1_tile = tf.tile(tf.expand_dims(input1, 3), (1, 1, 1, C, 1))
zero_tile = tf.zeros_like(input0_tile)
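# Keep input values where the mask is True and zeros elsewhere
# (tf.where, not np.where, because the operands are TensorFlow tensors)
vol0 = tf.where(mask, input0_tile, zero_tile)
vol1 = tf.where(mask, input1_tile, zero_tile)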
vol = tf.concat([vol0, vol1], axis=-1)
Note that you need either the first or the second block, followed by the third block, not all three (see the comments). The code builds a boolean mask from tf.meshgrid and a tf.range of indices, then uses tf.where to select values from the inputs or zeros. The resulting vol has shape (A, D, E, C, 2*B).
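To make the masking rule concrete, here is a minimal NumPy-only sketch of the same steps, checked against explicit loops. The sizes and the random inputs are made up purely for illustration, and the loop reference only restates what the mask above selects (entries whose index along the E axis is at least the index along the new C axis); it is not the original PyTorch loop.

```python
import numpy as np

# Hypothetical small sizes, chosen only for illustration.
A, D, E, B, C = 2, 3, 5, 4, 3

rng = np.random.default_rng(0)
input0 = rng.standard_normal((A, D, E, B)).astype(np.float32)
input1 = rng.standard_normal((A, D, E, B)).astype(np.float32)

# Mask built as in the NumPy block: index along axis 2 (size E) vs. range(C).
idx = np.meshgrid(*[np.arange(d) for d in (A, D, E, B)], indexing='ij')[2]
idx = np.expand_dims(idx, 3)  # shape (A, D, E, 1, B)
r = np.arange(C)[np.newaxis, np.newaxis, np.newaxis, :, np.newaxis]  # (1, 1, 1, C, 1)
mask = idx >= r  # broadcasts to shape (A, D, E, C, B)

# Tile the inputs along the new axis, then zero out the masked-off entries.
input0_tile = np.tile(np.expand_dims(input0, 3), (1, 1, 1, C, 1))
input1_tile = np.tile(np.expand_dims(input1, 3), (1, 1, 1, C, 1))
zero_tile = np.zeros_like(input0_tile)
vol0 = np.where(mask, input0_tile, zero_tile)
vol1 = np.where(mask, input1_tile, zero_tile)
vol = np.concatenate([vol0, vol1], axis=-1)  # shape (A, D, E, C, 2*B)

# Reference: the same rule written element by element with explicit loops.
ref = np.zeros((A, D, E, C, 2 * B), dtype=np.float32)
for c in range(C):
    for e in range(c, E):  # only positions with e >= c are filled
        ref[:, :, e, c, :B] = input0[:, :, e, :]
        ref[:, :, e, c, B:] = input1[:, :, e, :]

matches = np.array_equal(vol, ref)
```

Note that the result is laid out as (A, D, E, C, 2*B), with the concatenated features on the last axis, rather than the (A, B*2, C, D, E) layout of the PyTorch tensor in the question.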