How to apply a custom function to specific columns in a matrix in PyTorch

人盡茶涼 提交于 2019-12-05 07:55:21

You can map the stripe function over the first dimension of your tensor using torch.unbind as

In [1]: import torch

In [2]: def strip(a):
   ...:     i, j = a.size()
   ...:     assert(i >= j)
   ...:     out = torch.zeros((i - j + 1, j))
   ...:     for diag in range(0, i - j + 1):
   ...:         out[diag] = torch.diag(a, -diag)
   ...:     return out
   ...: 
   ...: 

In [3]: a = torch.randn((182, 91)).cuda()

In [5]: output = strip(a)

In [6]: output.size()
Out[6]: torch.Size([92, 91])

In [7]: a = torch.randn((150, 182, 91))

In [8]: output = list(map(strip, torch.unbind(a, 0)))

In [9]: output = torch.stack(output, 0)

In [10]: output.size()
Out[10]: torch.Size([150, 92, 91])

Here is a way to do this without using stack and unbind, by computing the diagonal stripe directly on a batch matrix:

def batch_stripe(a):
    b, i, j = a.size()
    assert i > j
    b_s, k, l = a.stride()
    return torch.as_strided(a, (b, i - j, j), (b_s, k, k+1))

For more, refer to: https://discuss.pytorch.org/t/optimizing-diagonal-stripe-code/17777/5

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!