Question
I remember that in the past, nn.Linear only accepted 2D tensors.
But today I discovered that nn.Linear now accepts 3D tensors, or even tensors with an arbitrary number of dimensions.
import torch
import torch.nn as nn

X = torch.randn((20, 20, 20, 20, 10))
linear_layer = nn.Linear(10, 5)
output = linear_layer(X)
print(output.shape)
>>> torch.Size([20, 20, 20, 20, 5])
When I check the PyTorch documentation, it does say that it now takes
Input: (N, *, H_in), where * means any number of additional dimensions and H_in = in_features.
So it seems to me that PyTorch's nn.Linear now reshapes the input with x.view(-1, input_dim) automatically.
But I cannot find any x.shape or x.view in the source code:
class Linear(Module):
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
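The forward above simply delegates to F.linear, so any shape handling must happen there; a quick sanity check illustrates this (a minimal sketch, reusing the layer and input from my example above):

import torch
import torch.nn as nn
import torch.nn.functional as F

X = torch.randn(20, 20, 20, 20, 10)
linear_layer = nn.Linear(10, 5)

# The module's forward is just a thin wrapper around F.linear,
# so the two calls produce identical results.
print(torch.equal(linear_layer(X), F.linear(X, linear_layer.weight, linear_layer.bias)))  # True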
Can anyone confirm this?
Answer 1:
torch.nn.Linear uses the torch.nn.functional.linear function under the hood; that's where the operations take place (see documentation).
It looks like this (docstrings and decorators removed for brevity):
def linear(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret
The first case is addmm, which implements beta*mat + alpha*(mat1 @ mat2) and is supposedly faster (see here for example).
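A minimal sketch of that formula, assuming the default beta=1 and alpha=1, so that torch.addmm(bias, input, weight.t()) is exactly bias + input @ weight.t():

import torch

input = torch.randn(8, 10)   # hypothetical 2D batch: 8 samples, 10 features
weight = torch.randn(5, 10)  # Linear(10, 5)-style weight
bias = torch.randn(5)

# addmm computes beta*mat + alpha*(mat1 @ mat2); with the defaults this is
# bias + input @ weight.t(), i.e. the 2D linear layer in one fused call.
fused = torch.addmm(bias, input, weight.t())
manual = bias + input.matmul(weight.t())
print(torch.allclose(fused, manual))  # True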
The second operation is matmul, and as one can read in its documentation, it performs different operations depending on the shapes of the tensors provided (five cases, which I won't copy here).
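The case relevant to Linear is an N-dimensional tensor multiplied by a 2D matrix: matmul treats all leading dimensions as batch dimensions and only contracts the last feature dimension. A minimal sketch mirroring the shapes from the question:

import torch

x = torch.randn(20, 20, 20, 20, 10)  # arbitrary leading dims, 10 input features
w_t = torch.randn(10, 5)             # what weight.t() looks like for Linear(10, 5)

# matmul broadcasts the matrix product over every leading dimension,
# so only the last dimension (10) is contracted away.
print(torch.matmul(x, w_t).shape)    # torch.Size([20, 20, 20, 20, 5])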
In summary, it preserves the dimensions between the first batch dimension and the last feature dimension. No view() is used whatsoever, and especially not x.view(-1, input_dim); check the code below:
import torch
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
print(torch.matmul(tensor1, tensor2).shape)
print(torch.matmul(tensor1, tensor2).view(-1, tensor1.shape[1]).shape)
which gives:
torch.Size([10, 3, 5])  # preserves the input's 3
torch.Size([50, 3])     # even destroys the batch dimension
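That said, the broadcasted matmul produces the same numbers an explicit x.view(-1, input_dim) round trip would; PyTorch simply never needs to materialize that reshape. A small check, assuming the layer and input shapes from the question:

import torch
import torch.nn as nn
import torch.nn.functional as F

linear_layer = nn.Linear(10, 5)
X = torch.randn(20, 20, 20, 20, 10)

out = linear_layer(X)  # broadcasted matmul under the hood, no view()

# The "reshape" computation the question hypothesized, done explicitly for comparison.
flat = F.linear(X.reshape(-1, 10), linear_layer.weight, linear_layer.bias)
print(torch.allclose(out, flat.reshape(20, 20, 20, 20, 5)))  # True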
Source: https://stackoverflow.com/questions/57138540/pytorch-linear-layer-now-automatically-reshape-the-input