Model summary in PyTorch

南旧 2020-12-02 05:45

Is there any way I can print the summary of a model in PyTorch, like the model.summary() method does in Keras, as follows?

Model Summary:
___________

12 answers
  •  暖寄归人
    2020-12-02 06:45

    Simply print the model after instantiating the model class. print() lists every submodule registered as an attribute in __init__:

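    import torch.nn as nn

    class RNN(nn.Module):
        def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
            super().__init__()
            # Layers assigned as attributes are the ones print(model) shows
            self.embedding = nn.Embedding(input_dim, embedding_dim)
            self.rnn = nn.RNN(embedding_dim, hidden_dim)
            self.fc = nn.Linear(hidden_dim, output_dim)

        def forward(self, text):
            # text: [seq_len, batch] tensor of token indices
            embedded = self.embedding(text)      # [seq_len, batch, embedding_dim]
            output, hidden = self.rnn(embedded)  # hidden: [1, batch, hidden_dim]
            return self.fc(hidden.squeeze(0))    # [batch, output_dim]

    # Illustrative sizes; use your own vocabulary and layer sizes
    input_dim, embedding_dim, hidden_dim, output_dim = 1000, 100, 256, 1
    model = RNN(input_dim, embedding_dim, hidden_dim, output_dim)
    print(model)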
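    print(model) shows the layer structure, but not the per-layer output shapes and parameter counts that Keras's model.summary() reports. Below is a minimal sketch of a Keras-style table. It assumes the model takes a single tensor input and only reports the model's direct children. The helper name summary is hypothetical, not part of PyTorch. It relies only on the standard nn.Module APIs named_children(), parameters(), register_forward_hook() and Tensor.numel(), and runs one forward pass on an example input so the hooks can record output shapes.

```python
import torch
import torch.nn as nn


def summary(model: nn.Module, example_input: torch.Tensor):
    """Hypothetical helper (not a PyTorch API): print a Keras-style table of
    (layer, output shape, params) for the direct children of `model` and
    return the rows plus the total parameter count."""
    rows = []
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            # nn.RNN returns (output, h_n); report the first tensor's shape
            out = output[0] if isinstance(output, tuple) else output
            params = sum(p.numel() for p in module.parameters())
            rows.append((f"{name} ({type(module).__name__})", tuple(out.shape), params))
        return hook

    # Attach a forward hook to every direct child module
    for name, child in model.named_children():
        hooks.append(child.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(example_input)  # one forward pass fills `rows` in call order
    finally:
        for h in hooks:
            h.remove()  # detach the hooks so the model is left unchanged

    total = sum(p.numel() for p in model.parameters())
    print(f"{'Layer (type)':<25}{'Output Shape':<25}{'Param #':>10}")
    print("=" * 60)
    for layer, shape, params in rows:
        print(f"{layer:<25}{str(shape):<25}{params:>10,}")
    print("=" * 60)
    print(f"Total params: {total:,}")
    return rows, total


# Usage with the same RNN as above (sizes are illustrative)
class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        embedded = self.embedding(text)
        output, hidden = self.rnn(embedded)
        return self.fc(hidden.squeeze(0))


model = RNN(1000, 100, 256, 1)
tokens = torch.randint(0, 1000, (20, 4))  # [seq_len=20, batch=4]
rows, total = summary(model, tokens)
```

    For these sizes the table lists embedding, rnn and fc in call order and reports 191,905 total parameters. Nested submodules are counted inside their parent's row rather than listed separately. Third-party packages such as torchinfo offer a more complete version of the same idea.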
    
