How does the “view” method work in PyTorch?

别那么骄傲 2020-12-07 07:14

I am confused about the method view() in the following code snippet.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # ... (rest of the snippet truncated in the original post)


        
8 Answers
  •  没有蜡笔的小新
    2020-12-07 07:35

    Let's do some examples, from simpler to more difficult.

    1. The view method returns a tensor with the same data as the self tensor (which means that the returned tensor has the same number of elements), but with a different shape. For example:

      a = torch.arange(1, 17)  # a's shape is (16,)
      
      a.view(4, 4) # output below
        1   2   3   4
        5   6   7   8
        9  10  11  12
       13  14  15  16
      [torch.FloatTensor of size 4x4]
      
      a.view(2, 2, 4) # output below
      (0 ,.,.) = 
      1   2   3   4
      5   6   7   8
      
      (1 ,.,.) = 
       9  10  11  12
      13  14  15  16
      [torch.FloatTensor of size 2x2x4]
      
    2. Assuming that -1 is not one of the parameters, the product of the sizes you pass must equal the number of elements in the tensor. If you do a.view(3, 3), it will raise a RuntimeError because shape (3 x 3) is invalid for an input with 16 elements: 3 x 3 = 9, not 16.
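A quick way to see this size check in action (a minimal sketch; the exact wording of the error message may vary between PyTorch versions):

```python
import torch

a = torch.arange(1, 17)  # 16 elements

try:
    a.view(3, 3)  # 3 * 3 = 9, which does not match 16 elements
except RuntimeError as err:
    print("view failed:", err)

# Any factorization of 16 works, e.g. (8, 2) or (16,)
print(a.view(8, 2).shape)  # torch.Size([8, 2])
```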

    3. You can use -1 as one of the parameters that you pass to the function, but only once. All that happens is that the method will do the math for you on how to fill that dimension. For example a.view(2, -1, 4) is equivalent to a.view(2, 2, 4). [16 / (2 x 4) = 2]
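The inferred dimension can be checked directly, and passing -1 twice is rejected (a minimal sketch):

```python
import torch

a = torch.arange(1, 17)  # 16 elements

# -1 tells view() to infer this dimension: 16 / (2 * 4) = 2
assert a.view(2, -1, 4).shape == a.view(2, 2, 4).shape

# Only one -1 is allowed; two would make the shape ambiguous
try:
    a.view(-1, -1)
except RuntimeError as err:
    print("two -1s rejected:", err)
```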

    4. Notice that the returned tensor shares the same data. If you make a change in the "view" you are changing the original tensor's data:

      b = a.view(4, 4)
      b[0, 2] = 2
      a[2] == 3.0
      False
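One way to confirm the sharing is to compare storage pointers with data_ptr(), which returns the address of the first element (sketch):

```python
import torch

a = torch.arange(1, 17)
b = a.view(4, 4)

# The view points at the same underlying storage as the original
assert b.data_ptr() == a.data_ptr()

b[0, 2] = 99        # write through the view...
print(a[2].item())  # ...and the original sees it: 99
```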
      
    5. Now, for a more complex use case. The documentation says that each new view dimension must either be a subspace of an original dimension, or only span original dimensions d, d + 1, ..., d + k that satisfy the following contiguity-like condition: for all i = d, ..., d + k - 1, stride[i] = stride[i + 1] x size[i + 1]. Otherwise, contiguous() needs to be called before the tensor can be viewed. For example:

      a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
      a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
      
      # The commented line below will raise a RuntimeError, because one dimension
      # spans across two contiguous subspaces
      # a_t.view(-1, 4)
      
      # instead do:
      a_t.contiguous().view(-1, 4)
      
      # To see why the first one does not work and the second does,
      # compare a.stride() and a_t.stride()
      a.stride() # (24, 6, 2, 1)
      a_t.stride() # (24, 2, 1, 6)
      

      Notice that for a_t, stride[0] != stride[1] x size[1], since 24 != 2 x 3 = 6.
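The condition can also be checked programmatically with is_contiguous(). As a sketch, this also shows reshape(), which behaves like view() but silently falls back to a copy (equivalent to contiguous().view()) when the strides do not allow a view:

```python
import torch

a = torch.rand(5, 4, 3, 2)
a_t = a.permute(0, 2, 3, 1)  # shape (5, 3, 2, 4); same storage, new strides

print(a.is_contiguous())     # True
print(a_t.is_contiguous())   # False: strides no longer match a row-major layout

# view() on the non-contiguous tensor fails, but reshape() copies when it must
flat = a_t.reshape(-1, 4)    # equivalent here to a_t.contiguous().view(-1, 4)
print(flat.shape)            # torch.Size([30, 4])
```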
