Filling torch tensor with zeros after certain index

Posted by 二次信任 on 2020-04-14 08:24:12

Question


Given a 3D tensor, say: batch x sentence length x embedding dim

import torch
a = torch.rand((10, 1000, 96))

and an array (or tensor) of the actual length of each sentence

lengths = torch.randint(1000, (10,))

which outputs something like tensor([ 370., 502., 652., 859., 545., 964., 566., 576.,1000., 803.])

How can I fill tensor 'a' with zeros after a certain index along dimension 1 (sentence length), according to the tensor 'lengths'?

I want something like this:

a[:, lengths:, :] = 0

One way of doing it (slow if the batch size is big enough):

for i_batch in range(10):
    a[i_batch, lengths[i_batch]:, :] = 0

Answer 1:


You can do it using a binary mask.
Using lengths as column indices into mask, we mark where each sequence ends (note that mask is made one column longer than a.size(1) so that full-length sequences can be marked too).
Using cumsum(), we then set every entry of mask from that mark onward to 1.

mask = torch.zeros(a.shape[0], a.shape[1] + 1, dtype=a.dtype, device=a.device)
# put a 1 at column lengths[i] of row i (lengths must be an integer tensor, e.g. lengths.long())
mask[(torch.arange(a.shape[0]), lengths)] = 1
mask = mask.cumsum(dim=1)[:, :-1]  # remove the superfluous column
a = a * (1. - mask[..., None])     # zero every position at or after lengths[i] in row i
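As a quick sanity check, the masking steps above can be wrapped in a function and compared against the loop from the question. This is only a sketch: the function name zero_after_lengths and the small test shapes are made up for illustration, and lengths is assumed to hold integers in the range 0..a.shape[1].

```python
import torch

def zero_after_lengths(a: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Return a copy of `a` (batch x seq_len x dim) zeroed from index lengths[i] onward in row i."""
    batch, seq_len = a.shape[0], a.shape[1]
    # Index tensors must be integer typed; also keep them on the same device as `a`.
    lengths = lengths.to(device=a.device, dtype=torch.long)
    # One extra column so that lengths[i] == seq_len is still a valid column index.
    mask = torch.zeros(batch, seq_len + 1, dtype=a.dtype, device=a.device)
    mask[torch.arange(batch, device=a.device), lengths] = 1
    mask = mask.cumsum(dim=1)[:, :-1]
    return a * (1.0 - mask[..., None])

torch.manual_seed(0)
a = torch.rand(4, 6, 3)                  # hypothetical small example
lengths = torch.tensor([0, 2, 6, 5])     # includes an empty and a full-length row

# Reference result using the explicit loop from the question.
expected = a.clone()
for i in range(a.shape[0]):
    expected[i, int(lengths[i]):, :] = 0

result = zero_after_lengths(a, lengths)
print(torch.equal(result, expected))
```

Unlike the in-place loop, this version returns a new tensor and leaves the input untouched.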

For example, with a.shape = (10, 5, 96) and lengths = [1, 2, 1, 1, 3, 0, 4, 4, 1, 3],
assigning 1 at the column given by each row's length makes mask look like:

mask = 
tensor([[0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.]])

After cumsum (and dropping the last column) you get

mask = 
tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1.]])

Note that mask has zeros exactly where the valid sequence entries are and ones from each sequence's length onward. Multiplying a by 1 - mask therefore zeroes exactly the entries you want.
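The small example can be reproduced with a few lines. The sketch below rebuilds the mask for the lengths listed above and, as a cross-check that is not part of the answer, compares it with a mask built by a different technique: broadcasting a position index against lengths. The variable names marks, cumsum_mask and compare_mask are chosen here for illustration.

```python
import torch

lengths = torch.tensor([1, 2, 1, 1, 3, 0, 4, 4, 1, 3])
batch, seq_len = lengths.shape[0], 5

# The answer's construction: mark the end of each sequence, then accumulate.
marks = torch.zeros(batch, seq_len + 1)
marks[torch.arange(batch), lengths] = 1
cumsum_mask = marks.cumsum(dim=1)[:, :-1]

# Cross-check (not from the answer): position j is padding when j >= lengths[i].
compare_mask = (torch.arange(seq_len)[None, :] >= lengths[:, None]).float()

print(cumsum_mask[4])                          # tensor([0., 0., 0., 1., 1.])
print(torch.equal(cumsum_mask, compare_mask))  # True
```

Both constructions give the same 10 x 5 mask here; the comparison form skips the extra column and the cumsum, at the cost of departing from the approach described in the answer.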

Enjoy ;)



Source: https://stackoverflow.com/questions/57548180/filling-torch-tensor-with-zeros-after-certain-index
