I want to convert a pre-trained CNN (like VGG-16) to a fully convolutional network in PyTorch. How can I do so?
Any tips would be helpful.
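Hopefully it's not too late. I'm not entirely sure which parts you want to retain, but if you want to manipulate a pre-trained model or re-use parts of it, my approach would be the following:

Download the pre-trained model, e.g.

```python
import torch
import torch.nn as nn
from torchvision import models

# torchvision >= 0.13; older versions use models.resnet101(pretrained=True)
model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
```

Extract the parts you're interested in and create a new model from these parts, e.g.

```python
list(model.children())  # inspect the top-level modules of your model
my_model = nn.Sequential(*list(model.children())[:-1])  # strips off the last linear layer (fc)
```

Use `children()` rather than `modules()` for the slicing: `modules()` recurses into every sub-module and yields the model itself first, so `list(model.modules())[:-1]` does not give you the top-level layers.

Of course, you can extract and re-use any parts of the model you want, as well as add new modules; just modify the list accordingly.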
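To see the difference between `children()` and `modules()` concretely, here is a small self-contained sketch. `TinyNet` is a hypothetical stand-in for a pre-trained model (nothing is downloaded); its layout loosely mirrors ResNet's top level: a body, a pooling layer and a final `fc` head.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pre-trained network: body -> pool -> fc head.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.body(x))
        return self.fc(torch.flatten(x, 1))

net = TinyNet()

# modules() walks the whole tree recursively and yields the root module first.
all_modules = list(net.modules())
print(all_modules[0] is net)  # True
print(len(all_modules))       # 6: TinyNet, body, Conv2d, ReLU, pool, fc

# children() yields only the direct sub-modules, in registration order.
top_level = list(net.children())
print(len(top_level))         # 3: body, pool, fc

# Dropping the last child removes the Linear head; the rest still runs.
backbone = nn.Sequential(*top_level[:-1])
out = backbone(torch.randn(2, 3, 16, 16))
print(out.shape)              # torch.Size([2, 8, 1, 1])
```

Rebuilding with `nn.Sequential` only reproduces the original forward pass when that pass is a plain chain of the children. Any extra logic in `forward()` (here the `torch.flatten` call) is dropped, which is why the backbone returns a 4-D tensor.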
You can do that as follows (see comments for description):
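```python
import torch
import torch.nn as nn
from torchvision import models

# 1. LOAD PRE-TRAINED VGG16
# (torchvision >= 0.13; older versions use models.vgg16(pretrained=True))
model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# 2. GET CONV LAYERS
features = model.features  # outputs 512 x 7 x 7 for a 224 x 224 input

# 3. GET FULLY CONNECTED LAYERS
# Stop before the last layer (the 1000-way ImageNet Linear classifier).
fcLayers = nn.Sequential(*list(model.classifier.children())[:-1])

# 4. CONVERT FULLY CONNECTED LAYERS TO CONVOLUTIONAL LAYERS
### Convert the first fc layer (25088 -> 4096) to a conv layer with a 512x7x7 kernel
fc = fcLayers[0].state_dict()
in_ch = 512
out_ch = fc["weight"].size(0)
# kernel_size=7 with the default stride 1, so the layer slides over larger inputs
# (a 4th positional argument would set the stride, not the kernel width)
firstConv = nn.Conv2d(in_ch, out_ch, kernel_size=7)
### Reuse the fc weights, reshaped from (out_ch, 512*7*7) to (out_ch, 512, 7, 7)
firstConv.load_state_dict({"weight": fc["weight"].view(out_ch, in_ch, 7, 7),
                           "bias": fc["bias"]})

# CREATE A LIST OF CONVS
convList = [firstConv]

# Similarly convert the remaining linear layers to 1x1 conv layers
for module in fcLayers[1:]:
    if isinstance(module, nn.Linear):
        # Convert the nn.Linear to nn.Conv2d
        fc = module.state_dict()
        in_ch = fc["weight"].size(1)
        out_ch = fc["weight"].size(0)
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        conv.load_state_dict({"weight": fc["weight"].view(out_ch, in_ch, 1, 1),
                              "bias": fc["bias"]})
        convList.append(conv)
    else:
        # Append other layers such as ReLU and Dropout unchanged
        convList.append(module)

# Set the conv layers as a nn.Sequential module
convLayers = nn.Sequential(*convList)

# Full network: features followed by the converted classifier.
# Accepts inputs of at least 224 x 224 and returns a spatial map of 4096-d features.
fcn = nn.Sequential(features, convLayers)
```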
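The reshaping trick works because a `Linear` layer applied to a flattened `C x H x W` activation computes exactly what a `Conv2d` with an `H x W` kernel computes at a single position, since flattening uses channel-row-column order. A minimal sketch checking this, with small made-up sizes (C=4, H=W=3 instead of VGG-16's 512 x 7 x 7) and random weights, so nothing is downloaded:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes standing in for VGG-16's 512 x 7 x 7 -> 4096 layer.
C, H, W, OUT = 4, 3, 3, 5

fc = nn.Linear(C * H * W, OUT)

# Same parameters, reshaped: row i of the Linear weight becomes the
# (C, H, W) filter for output channel i.
conv = nn.Conv2d(C, OUT, kernel_size=(H, W))
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(OUT, C, H, W))
    conv.bias.copy_(fc.bias)

# On an input of exactly H x W, both layers give the same numbers.
x = torch.randn(2, C, H, W)
y_fc = fc(torch.flatten(x, 1))  # shape (2, OUT)
y_conv = conv(x)                # shape (2, OUT, 1, 1)
print(torch.allclose(y_fc, y_conv.flatten(1), atol=1e-5))  # True

# On a larger input the stride-1 conv slides the classifier over the map,
# producing one prediction per window position.
big = torch.randn(1, C, H + 4, W + 2)
y_big = conv(big)
print(y_big.shape)              # torch.Size([1, 5, 5, 3])

# The top-left output equals the Linear layer applied to the top-left crop.
crop = big[:, :, :H, :W]
print(torch.allclose(y_big[:, :, 0, 0], fc(torch.flatten(crop, 1)), atol=1e-5))  # True
```

This is why the converted VGG-16 above can take inputs larger than 224 x 224: instead of a single 4096-d vector it returns a grid of them, one per 7 x 7 window of the feature map.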