How do PyTorch's “Fold” and “Unfold” work?

孤城傲影 2020-12-04 01:21

I've gone through the official docs. I'm having a hard time understanding what these functions are used for and how they work. Can someone explain them in layman's terms?

2 Answers
  • 2020-12-04 02:08

    unfold and fold are used to facilitate "sliding window" operations (like convolutions).
    Suppose you want to apply a function foo to every 5x5 window in a feature map/image:

    from torch.nn import functional as f
    windows = f.unfold(x, kernel_size=5)
    

    Now windows has shape batch × (5*5*x.size(1)) × num_windows, and you can apply foo to it:

    processed = foo(windows)
    

    Now you need to "fold" processed back to the original size of x:

    out = f.fold(processed, x.shape[-2:], kernel_size=5)
    

    You need to take care of padding and kernel_size, which may affect your ability to "fold" processed back to the size of x.
    Moreover, fold sums over overlapping elements, so to undo the overlap you may want to divide the output of fold by the per-pixel overlap count, which you can get by applying fold to a tensor of ones the same shape as windows.
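    A minimal end-to-end sketch of that normalization trick (the input shape and identity "processing" here are illustrative assumptions, not part of the original answer):

    ```python
    import torch
    from torch.nn import functional as f

    x = torch.randn(1, 3, 10, 10)            # hypothetical input
    k = 5
    windows = f.unfold(x, kernel_size=k)      # shape (1, 3*5*5, 36)

    # identity "processing" stands in for foo here
    processed = windows

    # fold sums overlapping contributions...
    summed = f.fold(processed, x.shape[-2:], kernel_size=k)
    # ...so divide by the overlap count, obtained by folding all-ones
    counts = f.fold(torch.ones_like(windows), x.shape[-2:], kernel_size=k)
    out = summed / counts

    # with identity processing, the round trip recovers x
    assert torch.allclose(out, x, atol=1e-5)
    ```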

  • 2020-12-04 02:17

    One-dimensional unfolding is easy:

    import torch

    x = torch.arange(1, 9).float()
    print(x)
    # unfold(dimension, size, step)
    print(x.unfold(0, 2, 1))
    print(x.unfold(0, 3, 2))
    

    Out:

    tensor([1., 2., 3., 4., 5., 6., 7., 8.])
    tensor([[1., 2.],
            [2., 3.],
            [3., 4.],
            [4., 5.],
            [5., 6.],
            [6., 7.],
            [7., 8.]])
    tensor([[1., 2., 3.],
            [3., 4., 5.],
            [5., 6., 7.]])
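    A detail worth knowing (not in the original answer): Tensor.unfold returns a view, so the windows share storage with the source tensor and no data is copied. A quick sketch:

    ```python
    import torch

    x = torch.arange(1., 9.)
    w = x.unfold(0, 2, 1)   # overlapping windows, a view into x
    x[0] = 100.             # mutate the source tensor in place
    print(w[0, 0])          # the first window reflects the change: tensor(100.)
    ```

    This makes unfolding cheap, but it also means in-place edits to either tensor affect the other.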
    

    Two-dimensional unfolding (also called patching):

    import torch

    patch = (3, 3)
    x = torch.arange(16).float()
    print(x, x.shape)
    x2d = x.reshape(1, 1, 4, 4)
    print(x2d, x2d.shape)
    h, w = patch
    c = x2d.size(1)
    print(c)  # channels
    # Tensor.unfold(dimension, size, step)
    r = x2d.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).reshape(-1, c, h, w)
    print(r.shape)
    print(r)  # result
    
    tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
            14., 15.]) torch.Size([16])
    tensor([[[[ 0.,  1.,  2.,  3.],
              [ 4.,  5.,  6.,  7.],
              [ 8.,  9., 10., 11.],
              [12., 13., 14., 15.]]]]) torch.Size([1, 1, 4, 4])
    1
    torch.Size([4, 1, 3, 3])
    
    tensor([[[[ 0.,  1.,  2.],
              [ 4.,  5.,  6.],
              [ 8.,  9., 10.]]],
    
    
            [[[ 4.,  5.,  6.],
              [ 8.,  9., 10.],
              [12., 13., 14.]]],
    
    
            [[[ 1.,  2.,  3.],
              [ 5.,  6.,  7.],
              [ 9., 10., 11.]]],
    
    
            [[[ 5.,  6.,  7.],
              [ 9., 10., 11.],
              [13., 14., 15.]]]])
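    For comparison (my addition, not part of the original answer), nn.functional.unfold can produce the same patches without chaining Tensor.unfold calls; note that it enumerates windows in row-major order, so the patch order differs from the transpose(1, 3) version above:

    ```python
    import torch
    from torch.nn import functional as f

    x2d = torch.arange(16.).reshape(1, 1, 4, 4)

    # each 3x3 window is flattened into a column: shape (1, 1*3*3, 4)
    cols = f.unfold(x2d, kernel_size=3)
    # move windows to the batch dimension and restore the 3x3 layout
    patches = cols.transpose(1, 2).reshape(-1, 1, 3, 3)
    print(patches.shape)   # torch.Size([4, 1, 3, 3])
    ```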
    
