How to flatten input in `nn.Sequential` in PyTorch

Submitted by a 夏天 on 2019-12-23 09:57:44

Question


How do I flatten the input inside `nn.Sequential`?

Model = nn.Sequential(x.view(x.shape[0],-1),
                     nn.Linear(784,256),
                     nn.ReLU(),
                     nn.Linear(256,128),
                     nn.ReLU(),
                     nn.Linear(128,64),
                     nn.ReLU(),
                     nn.Linear(64,10),
                     nn.LogSoftmax(dim=1))

Answer 1:


You can create a new module/class as below and use it inside the sequential model just as you would any other module (by calling `Flatten()`).

import torch

class Flatten(torch.nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]        # keep the batch dimension
        return x.view(batch_size, -1)  # collapse all remaining dimensions
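
For completeness, a minimal sketch of how this module slots into the questioner's model (the layer sizes are taken from the question and assume flattened 28×28, MNIST-style inputs):

model = torch.nn.Sequential(
    Flatten(),                      # replaces the x.view(...) call from the question
    torch.nn.Linear(784, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
    torch.nn.LogSoftmax(dim=1),
)

x = torch.randn(32, 1, 28, 28)      # hypothetical batch of images
out = model(x)                      # shape: (32, 10)

Note that recent PyTorch versions (1.2+) also ship a built-in `torch.nn.Flatten` module that does the same thing.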

Ref: https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983




Answer 2:


The flatten method, defined as

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

is comparable in speed to view(), but reshape() is even faster (see the timings below).

import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        # keep the batch dimension, flatten the rest
        return input.view(input.size(0), -1)

flatten = Flatten()

t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)


#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)


#https://pytorch.org/docs/master/tensors.html#torch.Tensor.view
f = t.view(t.size(0), -1)
print(f, f.shape)


#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)

Speed check
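
The numbers below were presumably produced with IPython's %timeit; a sketch of the commands, reusing the tensor t defined above:

# In IPython/Jupyter, reusing t from above
%timeit f = torch.flatten(t, start_dim=1, end_dim=-1)
%timeit f = t.view(t.size(0), -1)
%timeit f = t.reshape(t.size(0), -1)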

# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

If we use the Flatten class from above:

flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)


5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

This result shows that creating a class is the slower approach, since the module call adds overhead on top of the underlying view(). This is why it is faster to flatten tensors inside forward(). I think this is the main reason nn.Flatten hasn't been promoted.

So my suggestion would be to flatten inside forward() for speed. Something like this:

out = inp.reshape(inp.size(0), -1)
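
Put into a full module, the suggestion looks roughly like this (a minimal sketch; the class name is hypothetical and the layer sizes are reused from the question):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, inp):
        out = inp.reshape(inp.size(0), -1)  # flatten inside forward, as suggested
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        return F.log_softmax(self.fc4(out), dim=1)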


Source: https://stackoverflow.com/questions/53953460/how-to-flatten-input-in-nn-sequential-in-pytorch
