Deriving the structure of a PyTorch network

Submitted by 我是研究僧i on 2021-02-04 08:03:10

Question


For my use case, I need to take a PyTorch module and interpret the sequence of layers in it, so that I can describe the "connections" between the layers in some file format. Let's say I have a simple module like the one below:

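import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super(mymodel, self).__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x  # residual add: a plain tensor op, not an nn.Module
        return out


if __name__ == "__main__":
    net = mymodel(5)

    # modules() yields the model itself, then every registered submodule
    for mod in net.modules():
        print(mod)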

The output is:

mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)

As you can see, the += (add) operation is not captured, because it is not an nn.Module; it only happens inside the forward function. My goal is to derive a graph of connections from the PyTorch module object, for example something like this in JSON:

{
  "layers": {
    "fc": {
      "inputTensor": "t0",
      "outputTensor": "t1"
    },
    "addOp": {
      "inputTensor": "t1",
      "outputTensor": "t2"
    }
  }
}

The tensor names are arbitrary, but the format captures the essence of the graph and the connections between layers.

My question is: is there a way to extract this information from a PyTorch model object? I was thinking of using .modules(), but then realized that hand-written operations are not captured as modules this way. I guess that if everything were an nn.Module, .modules() might give me the network's layer arrangement. I want to know the connections between tensors so that I can create a format like the one above. Any help is appreciated.


Answer 1:


The information you are looking for is not stored in the nn.Module, but rather in the grad_fn attribute of the output tensor:

import torch

channels = 5
model = mymodel(channels)
pred = model(torch.rand((1, channels)))
print(pred.grad_fn)  # all the information is in the computation graph of the output tensor
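To give a sense of what walking grad_fn involves, here is a minimal sketch (not part of the original answer) that follows next_functions backwards from the output and records producer -> consumer pairs. The helper name autograd_edges is hypothetical; mapping leaf nodes to parameter names through their .variable attribute mirrors what torchviz does. Backward node names such as AddmmBackward0 vary between PyTorch versions.

```python
import torch
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def autograd_edges(output, model):
    """Hypothetical helper: (producer, consumer) pairs from output's autograd graph."""
    # AccumulateGrad leaf nodes expose the parameter tensor as `.variable`;
    # map it back to its name in the model (the same trick torchviz uses).
    param_names = {id(p): name for name, p in model.named_parameters()}

    def label(fn):
        if hasattr(fn, "variable"):
            return param_names.get(id(fn.variable), "leaf")
        return type(fn).__name__  # e.g. AddBackward0; exact names vary by version

    edges, seen, stack = [], set(), [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        # next_functions is a tuple of (node, input_nr); node is None for
        # inputs that do not require grad, such as x in this example.
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                edges.append((label(next_fn), label(fn)))
                stack.append(next_fn)
    return edges


if __name__ == "__main__":
    model = mymodel(5)
    pred = model(torch.rand((1, 5)))
    for producer, consumer in autograd_edges(pred, model):
        print(producer, "->", consumer)
```

On the question's model this shows the add consuming the Addmm node, but not the raw input x, because x does not require grad. The nodes are also backward operations rather than modules, so mapping AddmmBackward0 back to fc needs extra bookkeeping.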

It is not trivial to extract this information. You might want to look at the torchviz package, which draws a nice graph from the grad_fn information.
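As an alternative not mentioned in the original answer: PyTorch 1.8 and later ship torch.fx, whose symbolic_trace records module calls and plain tensor ops (such as the +=) as graph nodes. The sketch below, assuming a model that symbolic tracing can handle (no data-dependent control flow), builds a structure close to the one the question asks for. The helper layer_map, the "op" field and the list-valued "inputTensors" key are hypothetical choices; a list is used because the add reads two tensors.

```python
import json

import torch
import torch.fx
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def layer_map(model):
    """Hypothetical helper: {layer: {op, inputTensors, outputTensor}}."""
    graph = torch.fx.symbolic_trace(model).graph
    tensor_of = {}  # fx node -> tensor name ("t0", "t1", ...)
    layers = {}
    # graph.nodes is in topological order, so inputs are named before use
    for node in graph.nodes:
        if node.op == "output":
            continue
        tensor_of[node] = f"t{len(tensor_of)}"
        if node.op == "placeholder":  # model inputs only produce a tensor
            continue
        layers[node.name] = {
            "op": str(node.target),  # "fc" for the module, a function for the add
            "inputTensors": [tensor_of[n] for n in node.all_input_nodes],
            "outputTensor": tensor_of[node],
        }
    return layers


if __name__ == "__main__":
    print(json.dumps({"layers": layer_map(mymodel(5))}, indent=2))
```

Unlike the grad_fn walk, this keeps the module name fc and also records the residual input t0 feeding the add.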



Source: https://stackoverflow.com/questions/58253003/deriving-the-structure-of-a-pytorch-network
