How do I visualize a net in PyTorch?

太阳男子 2020-12-08 02:38
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dse

5 answers
  •  自闭症患者
    2020-12-08 02:51

    make_dot expects a variable (i.e., a tensor with a grad_fn), not the model itself.
    Try:

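    from torchviz import make_dot  # make_dot is provided by the torchviz package

    # resnet is assumed to be the nn.Module instance from the question
    x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
    out = resnet(x)  # forward pass; out gets a grad_fn from the model's parameters
    make_dot(out)  # plot graph of variable, not of a nn.Module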
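
For reference, here is a minimal self-contained sketch of the same idea. The small `nn.Sequential` model is a hypothetical stand-in for `resnet` (any `nn.Module` should behave the same way), and the names `model`, `has_graph`, and `dot` are illustrative only. It assumes `make_dot` comes from the third-party `torchviz` package, and it falls back gracefully if that package is not installed.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the `resnet` in the answer; any nn.Module works.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# Dummy input with the same shape as in the answer.
x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = model(x)

# The input does not require grad, but the model's parameters do,
# so the output is still attached to an autograd graph.
has_graph = out.grad_fn is not None

try:
    from torchviz import make_dot  # third-party package, assumed installed

    # Pass the output tensor, not the module; params labels the parameter nodes.
    dot = make_dot(out, params=dict(model.named_parameters()))
    # dot.render("model_graph", format="png")  # requires the Graphviz binaries
except ImportError:
    dot = None  # torchviz not available; the autograd graph still exists
```

Passing `params` is optional; it only replaces the anonymous parameter nodes in the plot with their names from `named_parameters()`.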
    
