pytorch模型中间层特征的提取

匿名 (未验证) 提交于 2019-12-03 00:43:02

转载自:https://blog.csdn.net/LXX516/article/details/80132228

定义一个特征提取的类:

 #中间特征提取  class FeatureExtractor(nn.Module):      def __init__(self, submodule, extracted_layers):          super(FeatureExtractor,self).__init__()          self.submodule = submodule          self.extracted_layers= extracted_layers         def forward(self, x):          outputs = []          for name, module in self.submodule._modules.items():              if name is "fc": x = x.view(x.size(0), -1)              x = module(x)              print(name)              if name in self.extracted_layers:                  outputs.append(x)          return outputs
 #特征输出  myresnet=resnet18(pretrained=False)  myresnet.load_state_dict(torch.load('cafir_resnet18_1.pkl'))   exact_list=["conv1","layer1","avgpool"]  myexactor=FeatureExtractor(myresnet,exact_list)  x=myexactor(img)

在这里主要应用的是:

for nama, module in model._modules.items():

所以要根据自己的情况重写这个类,这个类提供个一个很不错的想法

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!