How to convert pretrained FC layers to CONV layers in PyTorch

北荒 2020-12-16 21:44

I want to convert a pre-trained CNN (like VGG-16) to a fully convolutional network in PyTorch. How can I do so?

Any tip would be helpful.

1 Answer
  • 2020-12-16 22:23

    You can do that as follows (see comments for description):

    import torch
    import torch.nn as nn
    from torchvision import models
    
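    # 1. LOAD PRE-TRAINED VGG16
    # (torchvision >= 0.13 deprecates pretrained=True in favour of
    #  weights=models.VGG16_Weights.DEFAULT)
    model = models.vgg16(pretrained=True)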
    
    # 2. GET CONV LAYERS
    features = model.features
    
    # 3. GET FULLY CONNECTED LAYERS
    fcLayers = nn.Sequential(
        # keep everything except the final classification layer (Linear 4096 -> 1000)
        *list(model.classifier.children())[:-1]
    )
    
    # 4. CONVERT FULLY CONNECTED LAYERS TO CONVOLUTIONAL LAYERS
    
    ### convert first fc layer to conv layer with 512x7x7 kernel
    fc = fcLayers[0].state_dict()
    in_ch = 512
    out_ch = fc["weight"].size(0)
    
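    # 7x7 kernel with the default stride of 1, so the layer slides densely
    # over feature maps larger than 7x7
    firstConv = nn.Conv2d(in_ch, out_ch, kernel_size=7)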
    
    ### get the weights from the fc layer
    firstConv.load_state_dict({"weight":fc["weight"].view(out_ch, in_ch, 7, 7),
                               "bias":fc["bias"]})
    
    # CREATE A LIST OF CONVS
    convList = [firstConv]
    
    # Similarly convert the remaining linear layers to conv layers
    for module in fcLayers[1:]:
        if isinstance(module, nn.Linear):
            # Convert the nn.Linear to an equivalent 1x1 nn.Conv2d
            fc = module.state_dict()
            in_ch = fc["weight"].size(1)
            out_ch = fc["weight"].size(0)
            conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
    
            conv.load_state_dict({"weight": fc["weight"].view(out_ch, in_ch, 1, 1),
                                  "bias": fc["bias"]})
    
            convList += [conv]
        else:
            # Append other layers such as ReLU and Dropout
            convList += [module]
    
    # Set the conv layers as a nn.Sequential module
    convLayers = nn.Sequential(*convList)  
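To sanity-check the conversion, it can help to compare a linear layer with its convolutional copy on the same input. The sketch below is only an illustration under assumptions: `linear_to_conv` is a hypothetical helper (not part of PyTorch), and a small randomly initialised layer stands in for VGG's 512*7*7 -> 4096 layer so that nothing has to be downloaded.

```python
import torch
import torch.nn as nn

def linear_to_conv(fc, in_ch, k):
    """Hypothetical helper: copy an nn.Linear that expects a flattened
    (in_ch, k, k) input into an nn.Conv2d with a k x k kernel."""
    out_ch = fc.out_features
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=k)  # stride defaults to 1
    with torch.no_grad():
        conv.weight.copy_(fc.weight.view(out_ch, in_ch, k, k))
        conv.bias.copy_(fc.bias)
    return conv

torch.manual_seed(0)
fc = nn.Linear(8 * 7 * 7, 16)            # small stand-in for 512*7*7 -> 4096
conv = linear_to_conv(fc, in_ch=8, k=7)

x = torch.randn(2, 8, 7, 7)
with torch.no_grad():
    out_fc = fc(x.flatten(1))            # shape (2, 16)
    out_conv = conv(x)                   # shape (2, 16, 1, 1)
    max_diff = (out_fc - out_conv.flatten(1)).abs().max().item()

    # On a larger feature map the conv layer yields a spatial grid of outputs,
    # which is the point of the fully convolutional form.
    big = conv(torch.randn(2, 8, 10, 10))  # shape (2, 16, 4, 4)
```

Here `max_diff` should be on the order of floating-point rounding error. To run the whole converted network, one option is `nn.Sequential(features, convLayers)`; this skips `model.avgpool`, which in recent torchvision versions pools the feature map to 7x7, and for a 224x224 input the feature map is already 512x7x7.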
    