How does the “view” method work in PyTorch?

别那么骄傲 2020-12-07 07:14

I am confused about the view() method in the following code snippet.

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ...  # (remainder of the snippet not shown)

8 answers
  • 2020-12-07 07:40

    What is the meaning of the parameter -1?

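    You can read -1 as "infer this dimension": PyTorch computes its size from the total number of elements and the sizes of the other dimensions. Because only one unknown size can be solved for, at most one argument to view() may be -1.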

    If you call x.view(-1, 1), the result has shape [N, 1], where N is the number of elements in x. For example:

    import torch
    x = torch.tensor([1, 2, 3, 4])
    print(x,x.shape)
    print("...")
    print(x.view(-1,1), x.view(-1,1).shape)
    print(x.view(1,-1), x.view(1,-1).shape)
    

    This will output:

    tensor([1, 2, 3, 4]) torch.Size([4])
    ...
    tensor([[1],
            [2],
            [3],
            [4]]) torch.Size([4, 1])
    tensor([[1, 2, 3, 4]]) torch.Size([1, 4])
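The rules above can be checked directly. The sketch below uses the same four-element x (the names `inferred` and `failed` are just illustrative): a single -1 is inferred, while view() raises a RuntimeError both for two -1s and for a known size that does not divide the number of elements.

```python
import torch

x = torch.tensor([1, 2, 3, 4])

# A single -1 is inferred from the element count: 4 elements / 2 columns = 2 rows.
inferred = x.view(-1, 2)
print(inferred.shape)  # torch.Size([2, 2])

# Two -1s are ambiguous, and -1 cannot rescue a size that does not divide
# the element count (4 is not divisible by 3); both raise a RuntimeError.
failed = []
for shape in [(-1, -1), (-1, 3)]:
    try:
        x.view(*shape)
    except RuntimeError as err:
        failed.append(shape)
        print(shape, "->", err)
```

The exact error messages differ between PyTorch versions, so the sketch only relies on the exception type.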
    
  • 2020-12-07 07:41

    Let's try to understand view() through the following examples:

    import torch
    a = torch.arange(1, 17, dtype=torch.float32)  # same values as the deprecated torch.range(1, 16)
    
    print(a)
    
        tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
                15., 16.])
    
    print(a.view(-1,2))
    
        tensor([[ 1.,  2.],
                [ 3.,  4.],
                [ 5.,  6.],
                [ 7.,  8.],
                [ 9., 10.],
                [11., 12.],
                [13., 14.],
                [15., 16.]])
    
    print(a.view(2, -1, 4))  # 3-D tensor
    
        tensor([[[ 1.,  2.,  3.,  4.],
                 [ 5.,  6.,  7.,  8.]],
    
                [[ 9., 10., 11., 12.],
                 [13., 14., 15., 16.]]])
    print(a.view(2,-1,2))
    
        tensor([[[ 1.,  2.],
                 [ 3.,  4.],
                 [ 5.,  6.],
                 [ 7.,  8.]],
    
                [[ 9., 10.],
                 [11., 12.],
                 [13., 14.],
                 [15., 16.]]])
    
    print(a.view(4,-1,2))
    
        tensor([[[ 1.,  2.],
                 [ 3.,  4.]],
    
                [[ 5.,  6.],
                 [ 7.,  8.]],
    
                [[ 9., 10.],
                 [11., 12.]],
    
                [[13., 14.],
                 [15., 16.]]])
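All of these outputs follow one pattern: view() lays the elements into the new shape in row-major order, so the last dimension varies fastest. A small sketch that checks the corresponding index formula for the a.view(2, -1, 4) case (torch.arange stands in for the deprecated torch.range, and `b` is an illustrative name):

```python
import torch

a = torch.arange(1, 17, dtype=torch.float32)  # 1., 2., ..., 16.
b = a.view(2, -1, 4)  # -1 is inferred as 16 / (2 * 4) = 2
print(b.shape)  # torch.Size([2, 2, 4])

# Row-major order: for shape (2, 2, 4), element b[i, j, k] comes from
# flat position i * (2 * 4) + j * 4 + k of the original tensor.
for i in range(2):
    for j in range(2):
        for k in range(4):
            assert b[i, j, k] == a[i * 8 + j * 4 + k]
print("b[1, 0, 2] =", b[1, 0, 2].item())  # 11.0
```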
    

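    Passing -1 for one dimension is an easy way to let PyTorch compute that dimension from the others. For a 3-D view of shape (x, y, z), if you know y and z, writing a.view(-1, y, z) makes PyTorch infer x = a.numel() / (y * z); the same works for whichever single dimension you replace with -1. For a 2-D view (x, y), a.view(-1, y) infers x = a.numel() / y, and vice versa.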
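To make the inference concrete, here is a rough plain-Python sketch of how the missing size can be computed. `infer_view_shape` is a hypothetical helper written for illustration, not a PyTorch API; it only handles the -1 case, and it checks its answers against the shapes view() actually produces for the examples above.

```python
import math
import torch


def infer_view_shape(numel, shape):
    """Hypothetical helper: resolve a single -1 in `shape` for `numel` elements."""
    if list(shape).count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    # Product of all sizes that were given explicitly.
    known = math.prod(d for d in shape if d != -1)
    if -1 not in shape:
        if known != numel:
            raise ValueError(f"shape {shape} does not match {numel} elements")
        return tuple(shape)
    # The missing size must make the total element count come out exactly.
    if known == 0 or numel % known != 0:
        raise ValueError(f"shape {shape} is invalid for {numel} elements")
    return tuple(numel // known if d == -1 else d for d in shape)


a = torch.arange(1, 17, dtype=torch.float32)
for shape in [(-1, 2), (2, -1, 4), (2, -1, 2), (4, -1, 2)]:
    ours = infer_view_shape(a.numel(), shape)
    # Compare against the shape PyTorch actually produces.
    assert torch.Size(ours) == a.view(*shape).shape
    print(shape, "->", ours)
```

For the four shapes used above this prints (8, 2), (2, 2, 4), (2, 4, 2) and (4, 2, 2), matching the printed tensors.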
