How to get the output from a specific layer from a PyTorch model?

本秂侑毒 提交于 2021-02-18 07:22:18

问题


How to extract the features from a specific layer from a pre-trained PyTorch model (such as ResNet or VGG), without doing a forward pass again?


回答1:


You can register a forward hook on the specific layer you want. Something like:

def some_specific_layer_hook(module, input_, output):
    pass  # the value is in 'output'

model.some_specific_layer.register_forward_hook(some_specific_layer_hook)

model(some_input)

For example, to obtain res5c output in ResNet, you may want to use a nonlocal variable (or global in Python 2):

res5c_output = None

def res5c_hook(module, input_, output):
    nonlocal res5c_output
    res5c_output = output

resnet.layer4.register_forward_hook(res5c_hook)

resnet(some_input)

# Then, use `res5c_output`.


来源:https://stackoverflow.com/questions/52796121/how-to-get-the-output-from-a-specific-layer-from-a-pytorch-model

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!