Best way to mimic PyTorch sliced assignment with Keras/Tensorflow

南旧 2021-01-21 22:00

I am trying to mimic the operation done in PyTorch below:

vol = Variable(torch.FloatTensor(A, B*2, C, D, E).zero_()).cuda()
for i in range(C):
  if i > 0 :
           


        
2 Answers
  •  粉色の甜心
    2021-01-21 22:51

    You need to construct the tensor "by hand". Assuming both input0 and input1 have shape (A, D, E, B), you can do something like this:

    import numpy as np
    import tensorflow as tf

    # Make the indexing mask with TensorFlow
    in_shape = tf.shape(input0)
    in_dims = 4
    idx = tf.meshgrid(*[tf.range(in_shape[i]) for i in range(in_dims)], indexing='ij')[2]
    idx = tf.expand_dims(idx, axis=3)
    r = tf.range(C)[tf.newaxis, tf.newaxis, tf.newaxis, :, tf.newaxis]
    mask = idx >= r
    
    # If all dimensions are known at graph construction time, you can instead
    # make the mask with NumPy like this to save graph computation time
    idx = np.meshgrid(*[np.arange(d) for d in (A, D, E, B)], indexing='ij')[2]
    idx = np.expand_dims(idx, 3)
    r = np.arange(C)[np.newaxis, np.newaxis, np.newaxis, :, np.newaxis]
    mask = idx >= r
    
    # Make the tensor
    input0_tile = tf.tile(tf.expand_dims(input0, 3), (1, 1, 1, C, 1))
    input1_tile = tf.tile(tf.expand_dims(input1, 3), (1, 1, 1, C, 1))
    zero_tile = tf.zeros_like(input0_tile)
    # Use tf.where (not np.where) so the selection stays a TensorFlow op
    vol0 = tf.where(mask, input0_tile, zero_tile)
    vol1 = tf.where(mask, input1_tile, zero_tile)
    vol = tf.concat([vol0, vol1], axis=-1)
    

    Note that you need either the first block or the second block to build the mask, followed by the third block; do not run all three (see comments). The code builds a boolean mask from tf.meshgrid and a tf.range of indices, then uses tf.where to select values from the inputs or zeros.
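
To see what the mask keeps, here is a minimal NumPy-only sketch of the same idea on tiny shapes. The sizes and the contents of input0 are made up purely for illustration, and NumPy stands in for TensorFlow so the values can be inspected directly; it is not a replacement for the graph code above.

```python
import numpy as np

# Hypothetical small sizes, chosen only for illustration.
A, D, E, B, C = 1, 1, 4, 2, 3

# Stand-in for input0 with shape (A, D, E, B). Values start at 1 so that
# any zero in the result can only come from the mask.
input0 = np.arange(1, A * D * E * B + 1, dtype=np.float32).reshape(A, D, E, B)

# Index of each element along axis 2 (the E axis), then add an axis for C.
idx = np.meshgrid(*[np.arange(d) for d in (A, D, E, B)], indexing='ij')[2]
idx = np.expand_dims(idx, 3)                      # shape (A, D, E, 1, B)
r = np.arange(C)[np.newaxis, np.newaxis, np.newaxis, :, np.newaxis]
mask = idx >= r                                   # shape (A, D, E, C, B)

# Repeat the input C times along a new axis 3, then zero out masked entries.
input0_tile = np.tile(np.expand_dims(input0, 3), (1, 1, 1, C, 1))
vol0 = np.where(mask, input0_tile, np.zeros_like(input0_tile))

# For c = 1, the first position along E is zeroed and the rest are kept.
print(vol0[0, 0, :, 1, 0])  # values along E: 0, 3, 5, 7
```

For each index c along the new axis, entries whose position along axis 2 is smaller than c are replaced by zeros and the others are copied from the input, which is the selection that tf.where performs in the answer's third block.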
