PyTorch, nn.Sequential(), access weights of a specific module in nn.Sequential()

执笔经年 2021-01-12 09:51

This should be a quick one. When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrapped the modu

3 Answers
  •  长情又很酷
    2021-01-12 10:17

    You can access modules by name using _modules:

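    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Assigning a module as an attribute registers it in
            # self._modules under the attribute name ('conv1').
            self.conv1 = nn.Conv2d(3, 3, 3)

        def forward(self, input):
            return self.conv1(input)

    model = Net()
    print(model._modules['conv1'])         # the Conv2d submodule
    print(model._modules['conv1'].weight)  # its weight Parameter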
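
The same lookup should carry over to the nn.Sequential case the question asks about. The following is a minimal sketch, not a definitive recipe: the layer sizes and the names `seq`, `named`, `conv1` and `relu` are assumptions chosen for illustration. An unnamed nn.Sequential registers its children under the string keys "0", "1", ..., so integer indexing and `_modules` reach the same object; passing an OrderedDict gives the children names instead.

```python
from collections import OrderedDict

import torch.nn as nn

# Unnamed layers: children are registered under the string keys "0", "1", ...
seq = nn.Sequential(nn.Conv2d(3, 3, 3), nn.ReLU())
first_conv = seq[0]             # public integer indexing
same_conv = seq._modules["0"]   # same module through the private dict
print(first_conv.weight.shape)  # (out_channels, in_channels, kH, kW)

# Named layers (hypothetical names): build the container from an OrderedDict
named = nn.Sequential(OrderedDict([
    ("conv1", nn.Conv2d(3, 3, 3)),
    ("relu", nn.ReLU()),
]))
conv1_weight = named.conv1.weight  # attribute access by name

# named_parameters() lists every parameter with its dotted name
param_names = [name for name, _ in named.named_parameters()]
print(param_names)
```

Since `_modules` is a private attribute, integer indexing or attribute access by name is probably the safer choice in code that has to survive library upgrades.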
    
