Using torchviz:
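import torch
from torchviz import make_dot
# NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH are placeholders for your input dimensions
# requires_grad_(True) puts the input into the autograd graph, so its node (with its shape) is drawn
inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH).requires_grad_(True)
model = vgg()  # model is an instance of your vgg class
vis_graph = make_dot(model(inputs_fake), params=dict(list(model.named_parameters()) + [('x', inputs_fake)]))
vis_graph.view()  # renders the graph to a PDF and opens it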
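Why requires_grad_(True) matters: make_dot starts from the output's grad_fn and walks backward through next_functions, so it only reaches tensors that are part of the autograd graph. An input created without requires_grad is never linked into that graph, so it cannot appear. The toy sketch below mimics that backward walk in plain Python; Node and to_dot are hypothetical stand-ins for illustration, not torchviz's or PyTorch's API.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for an autograd grad_fn node (not a torch class)."""
    name: str
    parents: List["Node"] = field(default_factory=list)
    shape: Optional[Tuple[int, ...]] = None  # set for leaf tensors we want to label


def to_dot(output: Node) -> str:
    """Walk backward from the output, in the spirit of make_dot, and emit DOT text."""
    lines = ["digraph {"]
    seen = set()
    stack = [output]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue  # shared sub-graphs are drawn only once
        seen.add(id(node))
        # Leaf nodes carry their shape in the label, like the input 'x' in make_dot
        label = node.name if node.shape is None else f"{node.name}\\n{node.shape}"
        lines.append(f'  n{id(node)} [label="{label}"]')
        for parent in node.parents:
            lines.append(f"  n{id(parent)} -> n{id(node)}")  # edge in forward direction
            stack.append(parent)
    lines.append("}")
    return "\n".join(lines)


# Input tracked by autograd: it is reachable from the output, so it is drawn
w = Node("weight", shape=(8, 3, 3, 3))
x = Node("x", shape=(1, 3, 32, 32))
print(to_dot(Node("ReluBackward", [Node("ConvolutionBackward", [x, w])])))
```

If the input node is left out of the parents (the analogue of not calling requires_grad_(True)), the walk never reaches it and no 'x' node appears in the output.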
Using tensorboardX:
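import torch
from tensorboardX import SummaryWriter
inputs_fake = torch.rand(NUM_SAMPLES, NUM_CHANNELS, HEIGHT, WIDTH)  # no requires_grad needed for tracing
with SummaryWriter(comment='vgg') as w:  # 'vgg' is appended to the default run directory name under ./runs
    w.add_graph(model, (inputs_fake,))  # traces model with the dummy input and logs the graph
# view it with: tensorboard --logdir runs  (then open the Graphs tab)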
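To find where the graph was written: when no log directory is given, SummaryWriter creates a run folder under ./runs whose name combines the current time, the hostname and the comment. The sketch below reproduces that naming as I understand tensorboardX's default (runs/<MonDD_HH-MM-SS>_<hostname><comment>); default_log_dir is a hypothetical helper, not part of tensorboardX, so check the folder actually created on your machine.

```python
import os
import socket
from datetime import datetime
from typing import Optional


def default_log_dir(comment: str = "", now: Optional[datetime] = None,
                    hostname: Optional[str] = None) -> str:
    """Hypothetical helper: approximate the run directory SummaryWriter picks
    when no logdir is passed (assumed format: runs/<time>_<host><comment>)."""
    now = now or datetime.now()
    hostname = hostname or socket.gethostname()
    current_time = now.strftime("%b%d_%H-%M-%S")  # e.g. 'Jan02_03-04-05'
    return os.path.join("runs", current_time + "_" + hostname + comment)


print(default_log_dir("vgg"))  # something like runs/<time>_<your-host>vgg
```

If a predictable name is preferred, passing an explicit directory (tensorboardX's SummaryWriter accepts a logdir argument) avoids depending on the timestamp.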
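torchviz's view() writes a PDF under a default name: make_dot returns a graphviz Digraph, which by default saves its DOT source as Digraph.gv and the rendered file as Digraph.gv.pdf in the current directory. To choose the name, pass a filename to render() (or view()): vis_graph.render('vgg_graph') should write vgg_graph (the DOT source) and vgg_graph.pdf, and vis_graph.render('vgg_graph', format='png') produces a PNG instead.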