How do I visualize a net in PyTorch?

太阳男子 2020-12-08 02:38
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dse         


        
  •  醉话见心
    2020-12-08 03:11

    Here is how to do it with torchviz if you want to save the graph as an image:

    # http://www.bnikolic.co.uk/blog/pytorch-detach.html
    
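    import torch
    from torchviz import make_dot  # rendering also needs the Graphviz binaries
    
    x = torch.ones(10, requires_grad=True)
    weights = {'x': x}  # name -> tensor, used to label leaf nodes in the graph
    
    y = x**2
    z = x**3
    r = (y + z).sum()
    
    # Draw the autograd graph that produced r; render() writes the DOT source
    # to a file named "attached" and the picture to "attached.png"
    make_dot(r, params=weights).render("attached", format="png")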
    

    [Screenshot of the resulting graph image]

    source: http://www.bnikolic.co.uk/blog/pytorch-detach.html
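If Graphviz cannot be installed, or you just want to see what `make_dot` walks, the sketch below builds the same kind of graph by hand. It is a minimal illustration, not torchviz's implementation: the function name `autograd_graph_to_dot` and the small `nn.Sequential` network are hypothetical. It follows `grad_fn` and `next_functions` from the output back to the leaves, and labels leaf (`AccumulateGrad`) nodes through their `variable` attribute. The `params` argument plays the same role as the optional `params` dict of torchviz's `make_dot`. The result is plain DOT text, so only `torch` is needed.

```python
import torch


def autograd_graph_to_dot(root, params=None):
    """Return the autograd graph behind `root` as Graphviz DOT text.

    Hypothetical helper: walks root.grad_fn -> next_functions (similar in
    spirit to torchviz's make_dot) but only builds a string, so no Graphviz
    install is needed. `params` (name -> tensor) labels leaf tensors.
    """
    names = {id(t): name for name, t in (params or {}).items()}
    seen = {}  # id(node) -> node; holding references keeps ids unique
    lines = ["digraph autograd {"]
    stack = [root.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or id(fn) in seen:
            continue
        seen[id(fn)] = fn
        if hasattr(fn, "variable"):
            # AccumulateGrad node: wraps a leaf tensor such as a parameter
            var = fn.variable
            label = f"{names.get(id(var), 'leaf')}\\n{tuple(var.shape)}"
        else:
            label = type(fn).__name__  # e.g. AddmmBackward0, SumBackward0
        lines.append(f'  "{id(fn)}" [label="{label}"];')
        for parent, _ in fn.next_functions:
            if parent is not None:  # None: an input that does not require grad
                lines.append(f'  "{id(parent)}" -> "{id(fn)}";')
                stack.append(parent)
    lines.append("}")
    return "\n".join(lines)


# Usage on a small example network (layer sizes are arbitrary)
net = torch.nn.Sequential(
    torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 1)
)
out = net(torch.randn(2, 4)).sum()
dot = autograd_graph_to_dot(out, dict(net.named_parameters()))
print(dot)
```

Passing `dict(net.named_parameters())` makes the leaves show up as `0.weight`, `2.bias` and so on instead of anonymous tensors. If you save the printed text to a file, the Graphviz command line can turn it into an image, e.g. `dot -Tpng net.dot -o net.png`.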
