Calling super's forward() method

Posted by 允我心安 on 2020-12-30 05:45:01

Question


What is the most appropriate way to call the forward() method of a parent Module? For example, if I subclass the nn.Linear module, I might do the following:

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super().forward(x)
        z = do_other_stuff(y)
        return z

However, the docs say not to call the forward() method directly:

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

which makes me think that calling super().forward(x) could result in some unexpected errors. Is this true, or am I misunderstanding inheritance?


Answer 1:


TL;DR

You can call super().forward(...) freely, even when hooks are registered, including hooks registered through super().

Explanation

As stated in this answer, __call__ exists so that the registered hooks (e.g. those added with register_forward_hook) are run.

Suppose you inherit from a module and want to reuse the base class's forward, for example:
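To make the division of labour between __call__ and forward concrete, here is a minimal pure-Python sketch. ToyModule, ToyParent and ToyChild are hypothetical stand-ins, not PyTorch's actual implementation; the sketch only assumes that __call__ runs forward first and then lets each forward hook replace the output when it returns something other than None.

```python
class ToyModule:
    """Simplified stand-in for torch.nn.Module (hypothetical, not the real code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        # Run the (possibly overridden) forward first ...
        output = self.forward(x)
        # ... then let every registered hook optionally replace the output.
        for hook in self._forward_hooks:
            result = hook(self, (x,), output)
            if result is not None:
                output = result
        return output


class ToyParent(ToyModule):
    def forward(self, x):
        return x + 1


class ToyChild(ToyParent):
    def forward(self, x):
        # Ordinary method call on the parent class; no hook logic is involved.
        return super().forward(x) + 1


toy = ToyChild()
toy.register_forward_hook(lambda module, inputs, output: output + 1)

called = toy(1)          # forward, then the hook
direct = toy.forward(1)  # forward only, the hook is skipped
print(called, direct)    # 4 3
```

In this model, super().forward(...) inside forward never touches the hook list, which is why it is safe: the hooks run once, in __call__, after the whole forward chain has finished.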

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still

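You are perfectly fine as long as you invoke the module through __call__. Calling forward directly does not run the hook, so you get 3, as shown above.

You are unlikely to want to register a hook through super(), but consider such an example anyway. Note that super() is only a proxy for self, so the hook ends up on the same module instance: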

def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here is 3

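Using super().forward(...) is perfectly fine, and hooks still work correctly as long as the module itself is invoked through __call__ (that is the main idea of using __call__ instead of forward).

By the way, calling super().__call__(...) inside forward would recurse infinitely and raise a RecursionError, because __call__ dispatches back to self.forward.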
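Applied to the nn.Linear case from the question, a sketch could look like the following. LinearWithReLU and record_shape are hypothetical names, and torch.relu is only a stand-in for the question's do_other_stuff.

```python
import torch
from torch import nn


class LinearWithReLU(nn.Linear):  # hypothetical subclass for illustration
    def forward(self, x):
        # Reuse nn.Linear's forward, then post-process the result.
        y = super().forward(x)
        return torch.relu(y)  # stand-in for do_other_stuff(y)


calls = []


def record_shape(module, inputs, output):
    # Returns None, so the module's output is left unchanged.
    calls.append(tuple(output.shape))


layer = LinearWithReLU(3, 2)
layer.register_forward_hook(record_shape)

x = torch.randn(4, 3)
out_call = layer(x)             # runs forward and then the hook
out_forward = layer.forward(x)  # runs forward only

print(calls)                                # the hook fired once, for layer(x)
print(torch.equal(out_call, out_forward))   # same values either way
```

The hook fires only for layer(x), while the super().forward(x) call inside forward behaves like any other inherited method call.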
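The recursion can be reproduced with a small sketch. RecursiveChild is a hypothetical name; the example assumes no hooks are registered and simply catches the error Python raises when the call stack limit is hit.

```python
import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class RecursiveChild(Parent):  # hypothetical class showing the pitfall
    def forward(self, tensor):
        # __call__ looks up self.forward, which is this very method,
        # so the call never reaches Parent.forward.
        return super().__call__(tensor) + 1


try:
    RecursiveChild()(torch.tensor(1))
    raised = False
except RecursionError:
    raised = True

print(raised)
```

Inside forward, super().forward(...) is therefore the call to use; __call__ belongs to the outside caller of the module.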



Source: https://stackoverflow.com/questions/54752983/calling-supers-forward-method
