PyTorch: Understanding how the nn.Module class works internally

Submitted by  ̄綄美尐妖づ on 2020-04-13 17:58:09

Question


Typically, you subclass nn.Module as below:

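import torch
import torch.nn as nn

def init_weights(m):
    # Initialize the weights of every nn.Linear layer with Xavier/Glorot uniform
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # in-place; xavier_uniform (no underscore) is deprecated

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)  # 20 input features -> 1 output
        self.apply(init_weights)     # run init_weights on every submodule and on self

    def forward(self, x):
        x = self.fc1(x)
        return x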

My first question: why can I simply run the code below even though my __init__ doesn't take any positional argument for training_signals? It looks like training_signals is passed to the forward() method. How does that work?

model = LinearRegression()
training_signals = torch.rand(1000,20)
model(training_signals)

My second question: how does self.apply(init_weights) work internally? Is it executed before the forward method is called?


Answer 1:


Q1: Why can I simply run the code below even though my __init__ doesn't take any positional argument for training_signals, and it looks like training_signals is passed to the forward() method? How does it work?

First, the __init__ is called when you run this line:

model = LinearRegression()

You pass no arguments, and you shouldn't: your __init__ takes none besides self. Constructing the model and calling it are two separate steps, so __init__ never sees training_signals. Inside your __init__ you also call the base-class constructor with super(LinearRegression, self).__init__(); in the nn.Module source at the time of writing, that constructor's signature was simply def __init__(self) (just like yours).
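To illustrate that constructor arguments and call arguments are independent, here is a minimal sketch assuming a recent PyTorch. ConfigurableRegression and its in_features parameter are hypothetical, introduced only for this example: __init__ receives configuration, while the data goes to forward.

```python
import torch
import torch.nn as nn


class ConfigurableRegression(nn.Module):
    # Hypothetical example: __init__ takes configuration, forward takes data
    def __init__(self, in_features: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc1(x)


# Construction: only the configuration is passed (this runs __init__)
model = ConfigurableRegression(in_features=20)

# Call: only the data is passed (this goes through __call__ to forward)
training_signals = torch.rand(1000, 20)
output = model(training_signals)
print(output.shape)  # torch.Size([1000, 1])
```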

Second, model is now an object. When you run the line below:

model(training_signals)

You are actually calling the __call__ method and passing training_signals as a positional argument. In the nn.Module source, among other things (such as running any registered hooks), __call__ calls the forward method:

result = self.forward(*input, **kwargs)

passing all of __call__'s arguments, positional and keyword, on to forward.
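The delegation described above can be sketched in plain Python. This is a deliberately simplified stand-in, not PyTorch's actual code: MiniModule and Doubler are hypothetical names, and the real nn.Module.__call__ also handles hooks and other bookkeeping that this sketch omits.

```python
class MiniModule:
    """Hypothetical, heavily simplified stand-in for nn.Module."""

    def __call__(self, *args, **kwargs):
        # Python runs this when an instance is called like a function: obj(...)
        # It forwards every positional and keyword argument to forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward()")


class Doubler(MiniModule):
    def __init__(self):
        # No data arguments here, just like LinearRegression.__init__
        super().__init__()

    def forward(self, x):
        return 2 * x


doubler = Doubler()   # runs __init__
print(doubler(21))    # runs __call__, which runs forward: prints 42
```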

Q2: How does self.apply(init_weights) work internally? Is it executed before the forward method is called?

PyTorch is open source, so you can simply check the source code. The implementation of nn.Module.apply is quite simple:

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self
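Since each module first calls apply on its children and only then calls fn(self), the traversal is depth-first and visits children before their parent. A quick check, assuming a recent PyTorch; the nested nn.Sequential model and the record helper are illustrative only:

```python
import torch.nn as nn

# Illustrative nested model: a Sequential containing Linear, ReLU, Linear
net = nn.Sequential(nn.Linear(2, 3), nn.ReLU(), nn.Linear(3, 1))

visited = []


def record(m):
    # fn receives one Module at a time; apply ignores its return value
    visited.append(type(m).__name__)


returned = net.apply(record)

print(visited)          # ['Linear', 'ReLU', 'Linear', 'Sequential']
print(returned is net)  # True: apply returns self, so calls can be chained
```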

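Quoting the function's documentation, it "applies fn recursively to every submodule (as returned by .children()) as well as self". From the implementation you can also infer the requirements on fn:

  • fn must be callable;
  • fn receives a single argument, a Module object, and its return value is ignored.

As for the second part of your question: self.apply(init_weights) is called inside your __init__, so it runs once, when model = LinearRegression() constructs the object. By the time forward is first called the weights are already initialized, and apply is not run again on later forward passes.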
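One way to see when the initialization runs is to count how often the function passed to apply is called. The call_count dictionary is hypothetical instrumentation added for this sketch, and xavier_uniform_ (with a trailing underscore) is the in-place initializer in current PyTorch releases:

```python
import torch
import torch.nn as nn

call_count = {"init_weights": 0}  # hypothetical counter, for illustration only


def init_weights(m):
    call_count["init_weights"] += 1  # count every module that apply() hands us
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)


class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)  # runs here, inside the constructor

    def forward(self, x):
        return self.fc1(x)


model = LinearRegression()
print(call_count["init_weights"])  # 2: once for fc1, once for the model itself

model(torch.rand(1000, 20))
model(torch.rand(1000, 20))
print(call_count["init_weights"])  # still 2: forward() does not re-run apply()
```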


Source: https://stackoverflow.com/questions/58795601/pytorch-understand-how-nn-module-class-internally-work
