Is there any way to print a summary of a model in PyTorch, like the `model.summary()` method does in Keras, as shown below?
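Model Summary:

(The Keras summary table, with one row per layer and columns for its output shape and parameter count, was lost in extraction.)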
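Simply print the model after instantiating the model class. This lists every submodule with its configuration. It does not show the output shapes or parameter counts that Keras' `summary()` reports; those counts can be computed from `model.parameters()`: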
class RNN(nn.Module):
    """Embedding -> single-layer RNN -> linear output layer.

    Printing an instance shows these three submodules with their
    constructor arguments (the repr is provided by ``nn.Module``).

    Args:
        input_dim: number of embeddings, i.e. rows of the embedding table
            (e.g. the vocabulary size).
        embedding_dim: size of each embedding vector; also the RNN input size.
        hidden_dim: size of the RNN hidden state.
        output_dim: number of outputs produced by the final linear layer.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Map a batch of token-index sequences to output scores.

        Args:
            text: integer tensor of token indices. ``nn.RNN`` is built with
                its default ``batch_first=False``, so the expected layout is
                (seq_len, batch).

        Returns:
            Tensor of shape (batch, output_dim), computed from the RNN's
            final hidden state.
        """
        embedded = self.embedding(text)   # (seq_len, batch, embedding_dim)
        _, hidden = self.rnn(embedded)    # hidden: (1, batch, hidden_dim)
        # Drop the num_layers dimension (single layer) before the linear head.
        return self.fc(hidden.squeeze(0))
# input_dim, embedding_dim, hidden_dim and output_dim are placeholders that
# must be defined before this point; they are not shown in this snippet.
model = RNN(input_dim, embedding_dim, hidden_dim, output_dim)

# The nn.Module repr lists each submodule with its constructor arguments.
print(model)

# print(model) does not report parameter counts the way Keras' summary()
# does, so compute them from the model's parameters.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")