PyTorch - contiguous()

春和景丽 2020-12-22 15:40

I was going through this example of an LSTM language model on GitHub (link). What it does in general is pretty clear to me, but I'm still struggling to understand what calling contiguous() does…

6 Answers
  •  悲&欢浪女
    2020-12-22 16:20

    The accepted answer is great, and I tried to reproduce the effect of the transpose() function. I wrote two helper functions: samestorage() checks whether two tensors share the same underlying storage, and contiguous() checks whether a tensor is contiguous.

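    import torch

    def samestorage(x, y):
        # Views share one underlying storage, so compare the storages' start addresses.
        # (On PyTorch 2.0+, x.untyped_storage() avoids a TypedStorage deprecation warning.)
        if x.storage().data_ptr() == y.storage().data_ptr():
            print("same storage")
        else:
            print("different storage")

    def contiguous(y):
        # is_contiguous() is True when y's strides match a row-major layout of its shape.
        if y.is_contiguous():
            print("contiguous")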
        else:
            print("non contiguous")
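
What is_contiguous() checks can be made concrete with strides. The following is a minimal sketch under my own assumptions, not PyTorch's actual implementation: the helpers row_major_strides() and looks_contiguous() are hypothetical, and the comparison skips edge cases such as empty tensors. It treats a tensor as contiguous when each stride equals the row-major stride for its shape, ignoring dimensions of size 1.

```python
import torch

def row_major_strides(shape):
    """Hypothetical helper: strides (in elements) of a fresh row-major tensor of this shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

def looks_contiguous(t):
    """Hypothetical helper: compare t.stride() with row-major strides.
    Dimensions of size 1 are skipped because their stride never matters."""
    expected = row_major_strides(t.shape)
    return all(size == 1 or s == e for size, s, e in zip(t.shape, t.stride(), expected))

x = torch.randn(3, 2)
y = x.transpose(0, 1)
print(x.stride(), row_major_strides(x.shape))  # (2, 1) (2, 1)
print(y.stride(), row_major_strides(y.shape))  # (1, 2) (3, 1)
print(looks_contiguous(x), x.is_contiguous())  # True True
print(looks_contiguous(y), y.is_contiguous())  # False False
```

After transpose() the strides are simply swapped, so they no longer match the row-major strides of the new shape, which is why y reports as non contiguous even though no data moved.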
    

    I ran both checks on several operations and summarized the results in a table.

    You can review the checker code below, but first here is an example with a non-contiguous tensor. We cannot simply call view() on such a tensor; we need to call reshape() instead, or call .contiguous().view().

    import torch

    x = torch.randn(3, 2)
    y = x.transpose(0, 1)  # y is a non-contiguous view of x
    try:
        y.view(6)
    except RuntimeError as err:
        # view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
        print(err)

    # reshape() works: it falls back to copying when a view is impossible
    y.reshape(6)

    # contiguous() first copies y into a new row-major tensor, so view() then works
    y.contiguous().view(6)
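
The two fixes differ from a plain view() in that they copy memory here. A minimal sketch, using torch.arange instead of torch.randn so the values are predictable (my choice, not the original answer's), and data_ptr() comparisons as a rough check of whether new memory was allocated:

```python
import torch

x = torch.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
y = x.transpose(0, 1)                # [[0, 2, 4], [1, 3, 5]], a non-contiguous view

r = y.reshape(6)                     # a view is impossible here, so reshape() copies
c = y.contiguous().view(6)           # contiguous() copies; view() then reinterprets the copy

print(r.tolist())                    # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]
print(torch.equal(r, c))             # True
print(r.data_ptr() == x.data_ptr())  # False: r lives in new memory
print(c.data_ptr() == x.data_ptr())  # False: so does c

# contiguous() on an already contiguous tensor returns the tensor itself
print(x.contiguous() is x)           # True

# Writing to the copy does not touch x
r[0] = 100.0
print(x[0, 0].item())                # 0.0
```

So .contiguous().view() and reshape() give the same values; reshape() simply decides on its own whether a copy is needed, while contiguous() makes that copy explicit.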
    

    Note also that some operations return contiguous tensors and others return non-contiguous ones. Some operations work on the same storage (they return a view), while others, such as flip(), allocate new storage (that is, they copy the tensor) before returning.
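
Whether an operation returns a view or a copy matters beyond performance: writes through a view show up in the original tensor, while writes to a copy do not. A minimal sketch contrasting transpose() (a view) with flip() (a copy), using zeros so the effect is easy to read:

```python
import torch

x = torch.zeros(3, 2)
t = x.transpose(0, 1)  # view: shares storage with x, t[i, j] is x[j, i]
f = x.flip(0)          # copy: new storage, f[0] holds the values of x[2]

t[0, 1] = 1.0          # writes x[1, 0]
f[0, 0] = 2.0          # writes only into f's own memory

print(x.tolist())                         # [[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]]
print(t.is_contiguous(), f.is_contiguous())  # False True
```

The view is non-contiguous but aliases x; the flipped copy is contiguous but independent.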

    The checker code:

    import torch
    # Uses the samestorage() and contiguous() helpers defined above
    x = torch.randn(3,2)
    y = x.transpose(0, 1)  # swaps two axes
    print("\ntranspose")
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y)
    
    print("\nnarrow")
    x = torch.randn(3,2)
    y = x.narrow(0, 1, 2)  # dim, start, length
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y)
    
    print("\npermute")
    x = torch.randn(3,2)
    y = x.permute(1, 0) # sets the axis order
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y)
    
    print("\nview")
    x = torch.randn(3,2)
    y = x.view(2, 3)
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y)
    
    print("\nreshape")
    x = torch.randn(3,2)
    y = x.reshape(6,1)
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y)
    
    print("\nflip")
    x = torch.randn(3,2)
    y = x.flip(0)
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y)
    
    print("\nexpand")
    x = torch.randn(3,2)
    y = x.expand(2,-1,-1)
    print(x)
    print(y)
    contiguous(y)
    samestorage(x,y) 
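
The checker code above can be condensed into a loop that produces one row per operation. This is my own sketch rather than the original author's code: the names shares_storage(), OPS and summarize() are hypothetical, it uses Tensor.untyped_storage() and so assumes PyTorch 2.0 or later, and whether reshape() returns a view is an implementation detail that the PyTorch documentation says not to rely on.

```python
import torch

def shares_storage(a, b):
    # Compare the start of the underlying storage rather than a.data_ptr(),
    # because a view such as narrow() may begin at an offset into that storage.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

# Hypothetical table of the operations exercised by the checker code
OPS = {
    "transpose": lambda x: x.transpose(0, 1),
    "narrow": lambda x: x.narrow(0, 1, 2),
    "permute": lambda x: x.permute(1, 0),
    "view": lambda x: x.view(2, 3),
    "reshape": lambda x: x.reshape(6, 1),
    "flip": lambda x: x.flip(0),
    "expand": lambda x: x.expand(2, -1, -1),
}

def summarize():
    """Return {op name: (result is contiguous, result shares storage with input)}."""
    rows = {}
    for name, op in OPS.items():
        x = torch.randn(3, 2)
        y = op(x)
        rows[name] = (y.is_contiguous(), shares_storage(x, y))
    return rows

if __name__ == "__main__":
    print(f"{'op':<10} {'contiguous':<11} same storage")
    for name, (contig, same) in summarize().items():
        print(f"{name:<10} {str(contig):<11} {same}")
```

Each row pairs is_contiguous() with the shared-storage check, the same two facts the helper functions print, so the results can be compared side by side.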
    
