PyTorch - contiguous()

春和景丽 2020-12-22 15:40

I was going through this example of an LSTM language model on GitHub (link). What it does in general is pretty clear to me. But I'm still struggling to understand what calling contiguous() does.

6 Answers
  •  爱一瞬间的悲伤
    2020-12-22 16:34

    tensor.contiguous() will create a copy of the tensor whose elements are stored contiguously in memory (if the tensor is already contiguous, it simply returns the same tensor). The contiguous() function is usually needed when we first transpose() a tensor and then want to reshape (view) it. First, let's create a contiguous tensor:

    import torch

    aaa = torch.Tensor( [[1,2,3],[4,5,6]] )
    print(aaa.stride())
    print(aaa.is_contiguous())
    #(3,1)
    #True
    

    stride() returning (3, 1) means: to move one step along the first dimension (from one row to the next), we need to skip 3 elements in memory; to move one step along the second dimension (from one column to the next), we need to skip only 1 element. This indicates that the elements of the tensor are stored contiguously in row-major order.
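
    In other words, the element at index (i, j) lives at storage offset i*stride(0) + j*stride(1). A small sketch (standard PyTorch calls) verifying this against the flat memory order:

    i, j = 1, 2
    offset = i * aaa.stride(0) + j * aaa.stride(1)
    print(aaa.view(-1)[offset])
    #tensor(6.)   the same element as aaa[1,2]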

    Now let's apply some functions to the tensor:

    bbb = aaa.transpose(0,1)
    print(bbb.stride())
    print(bbb.is_contiguous())
    
    #(1, 3)
    #False
    
    
    ccc = aaa.narrow(1,1,2)   ## equivalent to matrix slicing aaa[:,1:3]
    print(ccc.stride())
    print(ccc.is_contiguous())
    
    #(3, 1)
    #False
    
    
    ddd = aaa.repeat(2,1)   # repeat the first dimension twice and the second dimension once
    print(ddd.stride())
    print(ddd.is_contiguous())
    
    #(3, 1)
    #True
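
    Note that bbb above still shares memory with aaa, while ddd owns freshly copied memory. We can check this with data_ptr(), which returns the address of a tensor's first element (a small sketch):

    print(bbb.data_ptr() == aaa.data_ptr())
    print(ddd.data_ptr() == aaa.data_ptr())
    
    #True   transpose() returned a view of aaa
    #False  repeat() copied the data into new memory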
    
    
    ## expand() is different from repeat(): it does not copy data.
    ## A tensor of shape [d1,d2,1] can be expanded with expand(d1,d2,d3),
    ## which repeats the singleton dimension d3 times.
    eee = aaa.unsqueeze(2).expand(2,3,3)
    print(eee.stride())
    print(eee.is_contiguous())
    
    #(3, 1, 0)
    #False
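
    The 0 in eee's stride shows that expand() did not copy anything: every index along the expanded dimension maps to the same memory location. Again, a quick data_ptr() check:

    print(eee.data_ptr() == aaa.data_ptr())
    
    #True   unsqueeze() and expand() returned views of aaa's storage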
    
    
    fff = aaa.unsqueeze(2).repeat(1,1,8).view(2,-1,2)
    print(fff.stride())
    print(fff.is_contiguous())
    
    #(24, 2, 1)
    #True
    

    OK, we can see that transpose(), narrow(), tensor slicing, and expand() make the generated tensor discontiguous: as the data_ptr() checks above showed, they return views that merely reinterpret the strides of the original storage. Interestingly, repeat() and view() do not make the tensor discontiguous, since repeat() copies the data into new contiguous memory, and view() only reinterprets the shape of an already-contiguous tensor. So now the question is: what happens if I use a discontiguous tensor?

    The answer is that the view() function cannot be applied to a discontiguous tensor. This is because view() requires the tensor to be contiguously stored so that it can reshape it in memory without copying. e.g.:

    bbb.view(-1,3)
    

    we will get the error:

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input> in <module>()
    ----> 1 bbb.view(-1,3)
    
    RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /pytorch/aten/src/TH/generic/THTensor.cpp:203
    

    To solve this, simply call contiguous() on the discontiguous tensor to create a contiguous copy, and then apply view():

    bbb.contiguous().view(-1,3)
    #tensor([[1., 4., 2.],
    #        [5., 3., 6.]])
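
    As a side note, PyTorch also provides reshape(), which returns a view when the existing strides allow it and otherwise makes a contiguous copy automatically, so it can be applied directly to the discontiguous tensor:

    bbb.reshape(-1,3)
    #tensor([[1., 4., 2.],
    #        [5., 3., 6.]])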
    
