pytorch的模型解析
如何获取pytorch的动态图? model = torch.jit.load("test.pth") graph = model.graph.copy() torch._C._jit_pass_inline(graph) node_list = graph.nodes() 加载模型后,获取模型的graph,这个graph就是需要的动态图。graph node就是计算图的计算节点(有序),关于各个层的相关参数都可以从node节点中获取,各个参数的相对位置需要查找一下该op的实现。 需要注意的是,需要使用 _jit_pass_inline来将graph的sub module展开。 如何获取pytorch的权重等参数? 对于非量化模型: 可以通过named_parameters或者state_dict获取。 对于量化模型: 在一次次的尝试和接口的设置中终于找到了!!! a = 0 for model_name, module in model.named_modules(): print(model_name) print(module) if a == 2: mod_c = module._c #print(mod_c.dump()) param = module.__getattr__('_packed_params') print(param) print(type(param)