How do PyTorch's “Fold” and “Unfold” work?

孤城傲影 2020-12-04 01:21

I've gone through the official docs, but I'm having a hard time understanding what these functions are used for and how they work. Can someone explain them in layman's terms?

2 Answers
  •  猫巷女王i
    2020-12-04 02:17

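    One-dimensional unfolding is easy. Tensor.unfold(dimension, size, step) slides a window of length size along dimension, moving step elements at a time, and returns each window as a row: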

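    import torch

    x = torch.arange(1, 9).float()
    print(x)
    # unfold(dimension, size, step)
    print(x.unfold(0, 2, 1))  # windows of 2 elements, moving 1 step at a time
    print(x.unfold(0, 3, 2))  # windows of 3 elements, moving 2 steps at a time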
    

    Out:

    tensor([1., 2., 3., 4., 5., 6., 7., 8.])
    tensor([[1., 2.],
            [2., 3.],
            [3., 4.],
            [4., 5.],
            [5., 6.],
            [6., 7.],
            [7., 8.]])
    tensor([[1., 2., 3.],
            [3., 4., 5.],
            [5., 6., 7.]])
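    To make the sliding-window rule explicit, here is a minimal sketch that rebuilds the same windows with plain slicing, assuming a 1-D tensor and size <= len(x). The helper name unfold_1d is hypothetical and not part of PyTorch. Window i starts at i * step, and a window that would run past the end is dropped, so there are (n - size) // step + 1 windows; that is why 8 never appears in the last output above.

```python
import torch

def unfold_1d(x: torch.Tensor, size: int, step: int) -> torch.Tensor:
    """Hypothetical helper: same result as x.unfold(0, size, step) for a 1-D tensor.

    Assumes size <= len(x); otherwise there are no complete windows.
    """
    n = x.shape[0]
    num_windows = (n - size) // step + 1  # incomplete trailing windows are dropped
    # window i covers x[i * step : i * step + size]
    return torch.stack([x[i * step : i * step + size] for i in range(num_windows)])

x = torch.arange(1, 9).float()
for size, step in [(2, 1), (3, 2)]:
    manual = unfold_1d(x, size, step)
    builtin = x.unfold(0, size, step)
    print(size, step, tuple(builtin.shape), torch.equal(manual, builtin))
# 2 1 (7, 2) True
# 3 2 (3, 3) True
```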
    

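    Two-dimensional unfolding (also called patching) applies the same idea twice, once along the height and once along the width, to cut a 4x4 image into overlapping 3x3 patches: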

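    import torch

    patch = (3, 3)
    x = torch.arange(16).float()
    print(x, x.shape)
    x2d = x.reshape(1, 1, 4, 4)  # (N, C, H, W)
    print(x2d, x2d.shape)
    h, w = patch
    c = x2d.size(1)
    print(c)  # number of channels
    # unfold(dimension, size, step): slide over H (dim 2), then over W (dim 3)
    # shape after both unfolds: (N, C, rows, cols, h, w) = (1, 1, 2, 2, 3, 3)
    # transpose(1, 3) puts the patch-column index first, so patches come out column by column
    r = x2d.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).reshape(-1, c, h, w)
    print(r.shape)
    print(r)  # result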
    
    tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
            14., 15.]) torch.Size([16])
    tensor([[[[ 0.,  1.,  2.,  3.],
              [ 4.,  5.,  6.,  7.],
              [ 8.,  9., 10., 11.],
              [12., 13., 14., 15.]]]]) torch.Size([1, 1, 4, 4])
    1
    torch.Size([4, 1, 3, 3])
    
    tensor([[[[ 0.,  1.,  2.],
              [ 4.,  5.,  6.],
              [ 8.,  9., 10.]]],
    
    
            [[[ 4.,  5.,  6.],
              [ 8.,  9., 10.],
              [12., 13., 14.]]],
    
    
            [[[ 1.,  2.,  3.],
              [ 5.,  6.,  7.],
              [ 9., 10., 11.]]],
    
    
            [[[ 5.,  6.,  7.],
              [ 9., 10., 11.],
              [13., 14., 15.]]]])
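    The question itself is about torch.nn.Unfold, the module form of torch.nn.functional.unfold. Instead of a stack of patches, it flattens each patch into a column: for an (N, C, H, W) input and a kh x kw kernel it returns shape (N, C * kh * kw, L), where L is the number of patch positions. The sketch below assumes the default stride 1, no padding and no dilation, and checks that those columns hold the same four 3x3 patches as the Tensor.unfold result above, only in row-major order (top-left, top-right, bottom-left, bottom-right) rather than column by column.

```python
import torch
import torch.nn as nn

x2d = torch.arange(16).float().reshape(1, 1, 4, 4)  # (N, C, H, W)
c, h, w = 1, 3, 3

# nn.Unfold flattens every h x w patch into one column: (N, C*h*w, L)
cols = nn.Unfold(kernel_size=(h, w))(x2d)
print(cols.shape)  # torch.Size([1, 9, 4])

# Turn the columns back into a stack of patches: (N*L, C, h, w)
patches = cols.transpose(1, 2).reshape(-1, c, h, w)

# Same patches via Tensor.unfold, but ordered row by row:
# permute keeps the patch-row index ahead of the patch-column index
ref = (x2d.unfold(2, h, 1).unfold(3, w, 1)  # (N, C, rows, cols, h, w)
          .permute(0, 2, 3, 1, 4, 5)          # (N, rows, cols, C, h, w)
          .reshape(-1, c, h, w))
print(torch.equal(patches, ref))  # True
```

    This column layout is the point of Unfold: the nn.Unfold documentation notes that a convolution is equivalent to Unfold, then a matrix multiplication, then Fold (or a view to the output shape).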
    
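    torch.nn.Fold (torch.nn.functional.fold) goes the other way: it takes columns shaped like Unfold's output, (N, C * kh * kw, L), and writes each one back at its patch position in an image of size output_size. Where patches overlap, the values are summed, not averaged, so fold(unfold(x)) equals x multiplied by the number of patches covering each pixel. A minimal sketch, assuming the same 4x4 input, a 3x3 kernel and default stride: dividing by the fold of an all-ones input (the per-pixel coverage count) recovers the original image.

```python
import torch
import torch.nn.functional as F

x2d = torch.arange(16).float().reshape(1, 1, 4, 4)  # (N, C, H, W)
kernel = (3, 3)

cols = F.unfold(x2d, kernel_size=kernel)                       # (1, 9, 4)
summed = F.fold(cols, output_size=(4, 4), kernel_size=kernel)  # overlaps are added up

# How many 3x3 patches cover each pixel
coverage = F.fold(F.unfold(torch.ones_like(x2d), kernel_size=kernel),
                  output_size=(4, 4), kernel_size=kernel)
print(coverage[0, 0])
# tensor([[1., 2., 2., 1.],
#         [2., 4., 4., 2.],
#         [2., 4., 4., 2.],
#         [1., 2., 2., 1.]])

print(torch.equal(summed, x2d * coverage))     # True
print(torch.allclose(summed / coverage, x2d))  # True: the original image is recovered
```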
