Does the PyTorch Linear layer now automatically reshape the input?

Submitted by 纵饮孤独 on 2021-01-28 11:47:27

Question


I remember that in the past, nn.Linear only accepted 2D tensors.

But today I discovered that nn.Linear now accepts 3D tensors, or even tensors with an arbitrary number of dimensions.

import torch
import torch.nn as nn

X = torch.randn((20,20,20,20,10))
linear_layer = nn.Linear(10,5)
output = linear_layer(X)
print(output.shape)
>>> torch.Size([20, 20, 20, 20, 5])

When I check the PyTorch documentation, it does say that the layer now takes

Input: (N, *, H_in), where * means any number of additional dimensions and H_in = in_features

So it seems to me that PyTorch's nn.Linear now reshapes the input with x.view(-1, input_dim) automatically.
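If that were the case, calling the layer on a 5D tensor should give the same values as manually flattening the leading dimensions, applying the layer to the resulting 2D tensor, and reshaping back. A quick check of that equivalence (the reshape code here is mine, just to test the hypothesis):

import torch
import torch.nn as nn

X = torch.randn(20, 20, 20, 20, 10)
linear_layer = nn.Linear(10, 5)

direct = linear_layer(X)                                        # call on the 5D tensor directly
manual = linear_layer(X.view(-1, 10)).view(20, 20, 20, 20, 5)   # flatten, apply, restore shape

print(torch.allclose(direct, manual, atol=1e-6))  # expected: True (up to float error)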

But I cannot find any x.shape or x.view in the source code:

class Linear(Module):
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

Can anyone confirm this?


Answer 1:


torch.nn.Linear uses the torch.nn.functional.linear function under the hood; that's where the operations take place (see documentation).

It looks like this (docstrings and decorators removed for brevity):

def linear(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret

The first case is addmm, which implements beta*mat + alpha*(mat1 @ mat2) and is supposedly faster (see here for an example).
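A quick check (my own snippet, not part of the PyTorch source) that the fused addmm call matches the unfused matmul-plus-bias path, with the default beta = alpha = 1:

import torch

input = torch.randn(8, 10)    # 2D input: the case that takes the addmm branch
weight = torch.randn(5, 10)   # stored as (out_features, in_features), like nn.Linear's weight
bias = torch.randn(5)

fused = torch.addmm(bias, input, weight.t())      # bias + input @ weight.t()
unfused = input.matmul(weight.t()) + bias

print(torch.allclose(fused, unfused, atol=1e-6))  # True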

The second operation is matmul, and as one can read in its docs it performs various operations based on the shapes of the tensors provided (five cases, which I won't copy here).
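In the case relevant to nn.Linear, the second operand weight.t() is a plain 2D matrix, so matmul simply broadcasts it over all leading batch dimensions of the input. A small illustration (shapes chosen arbitrarily):

import torch

X = torch.randn(4, 3, 10)   # two batch dimensions, feature dimension 10
W = torch.randn(5, 10)      # like nn.Linear's weight: (out_features, in_features)

out = X.matmul(W.t())       # the 2D matrix is broadcast over the batch dimensions
print(out.shape)            # torch.Size([4, 3, 5])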

In summary, it preserves all dimensions between the first batch dimension and the last feature dimension. No view() is used whatsoever, and certainly not x.view(-1, input_dim); check the code below:

import torch

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)

print(torch.matmul(tensor1, tensor2).shape)
print(torch.matmul(tensor1, tensor2).view(-1, tensor1.shape[1]).shape)

which gives:

torch.Size([10, 3, 5]) # preserves the input's 3
torch.Size([50, 3]) # the view() even destroys the batch dimension


Source: https://stackoverflow.com/questions/57138540/pytorch-linear-layer-now-automatically-reshape-the-input
