PyTorch, nn.Sequential(), access weights of a specific module in nn.Sequential()

Backend · Unresolved · 3 answers · 1198 views

执笔经年 2021-01-12 09:51

This should be a quick one. When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrap the module in nn.Sequential() first?

3 Answers
  •  滥情空心
    2021-01-12 10:29

    An easy way to access the weights is to use the state_dict() of your model.

    This should work in your case:

    for k, v in model_2.state_dict().items():
        print("Layer {}".format(k))
        print(v)
    
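    As a self-contained sketch (assuming model_2 is an nn.Sequential built from, say, two Linear layers; the layer sizes here are made up for illustration), the state_dict() keys follow the "<child index>.<parameter name>" pattern:

    ```python
    import torch.nn as nn

    # Hypothetical stand-in for model_2: layers wrapped in nn.Sequential
    model_2 = nn.Sequential(
        nn.Linear(4, 8),
        nn.ReLU(),
        nn.Linear(8, 2),
    )

    # state_dict() maps parameter names to tensors; nn.Sequential names
    # its children by integer index, so keys look like "0.weight", "0.bias"
    for k, v in model_2.state_dict().items():
        print("Layer {}".format(k), tuple(v.shape))
    ```

    Note that the parameter-free nn.ReLU at index 1 contributes no entries, so the keys are 0.weight, 0.bias, 2.weight, and 2.bias.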

    Another option is to get the modules() iterator. If you know beforehand the type of your layers this should also work:

    for layer in model_2.modules():
        if isinstance(layer, nn.Linear):
            print(layer.weight)
    
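    A third option, if you know the position of the layer you want: nn.Sequential supports integer indexing, so a specific submodule's weights can be read directly. A minimal sketch (the model here is a made-up stand-in for model_2):

    ```python
    import torch.nn as nn

    model_2 = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    # Index into the Sequential to get the child module itself,
    # then access its parameters as usual
    first_linear = model_2[0]
    print(first_linear.weight.shape)  # torch.Size([8, 4])
    print(first_linear.bias.shape)    # torch.Size([8])
    ```

    Unlike state_dict(), which returns detached copies of the tensors, indexing gives you the live module, so this also works for modifying weights in place.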
