How can I use a pre-trained neural network with grayscale images?

无人及你 2020-12-02 09:35

I have a dataset containing grayscale images and I want to train a state-of-the-art CNN on them. I'd very much like to fine-tune a pre-trained model (like the ones here).

11 Answers
  •  天涯浪人
    2020-12-02 10:08

    Converting grayscale images to RGB, as the currently accepted answer suggests, is one approach to this problem, but not the most efficient. You most certainly can modify the weights of the model's first convolutional layer and achieve the stated goal. The modified model will both work out of the box (with reduced accuracy) and be fine-tunable. Modifying the weights of the first layer does not render the rest of the weights useless, as others have suggested.

    To do this, you'll have to add some code where the pretrained weights are loaded. In your framework of choice, grab the weights of the first convolutional layer and modify them before assigning them to your 1-channel model. The required modification is to sum the weight tensor over the input-channel dimension. This works because convolution is linear: feeding the same grayscale plane to all three input channels gives the same response as applying the summed kernel to that single plane. How the weight tensor is laid out varies from framework to framework. The PyTorch default is [out_channels, in_channels, kernel_height, kernel_width]. In TensorFlow I believe it is [kernel_height, kernel_width, in_channels, out_channels].
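
    As a rough illustration of the layout point, the sketch below uses NumPy with random stand-in weights (not real pretrained ones; the array names are made up for the example) and collapses the input-channel axis in each of the two layouts mentioned above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a first-layer kernel in the PyTorch layout:
# [out_channels, in_channels, kernel_height, kernel_width]
w_oihw = rng.standard_normal((64, 3, 7, 7))

# The same kernel rearranged into the TensorFlow-style layout:
# [kernel_height, kernel_width, in_channels, out_channels]
w_hwio = np.transpose(w_oihw, (2, 3, 1, 0))

# Sum over the input-channel axis, keeping it as a size-1 dimension.
w_oihw_gray = w_oihw.sum(axis=1, keepdims=True)   # shape (64, 1, 7, 7)
w_hwio_gray = w_hwio.sum(axis=2, keepdims=True)   # shape (7, 7, 1, 64)

# Both routes describe the same 1-channel kernel.
same = np.allclose(np.transpose(w_hwio_gray, (3, 2, 0, 1)), w_oihw_gray)
print(w_oihw_gray.shape, w_hwio_gray.shape, same)
```

    The only thing that changes between frameworks is which axis index holds the input channels (1 in the first layout, 2 in the second).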

    Using PyTorch as an example, in a ResNet50 model from Torchvision (https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), the shape of the weights for conv1 is [64, 3, 7, 7]. Summing over dimension 1 results in a tensor of shape [64, 1, 7, 7]. At the bottom I've included a snippet of code that would work with the ResNet models in Torchvision assuming that an argument (inchans) was added to specify a different number of input channels for the model.
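
    A minimal sketch of that shape change, assuming only that PyTorch is installed. The layer below copies the conv1 hyperparameters of the Torchvision ResNet but uses random weights rather than pretrained ones, and the variable names are illustrative. It also checks that the 1-channel layer reproduces the 3-channel layer's output on a grayscale image copied into three channels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for ResNet50's conv1; random weights, not pretrained.
conv_rgb = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# 1-channel version of the same layer, initialised from the summed weights.
conv_gray = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
with torch.no_grad():
    conv_gray.weight.copy_(conv_rgb.weight.sum(dim=1, keepdim=True))

gray = torch.rand(2, 1, 224, 224)        # batch of 1-channel images
gray_as_rgb = gray.repeat(1, 3, 1, 1)    # same image copied into 3 channels

with torch.no_grad():
    out_rgb = conv_rgb(gray_as_rgb)
    out_gray = conv_gray(gray)

print(conv_gray.weight.shape)            # torch.Size([64, 1, 7, 7])
# Compare the two outputs up to floating-point rounding.
print(torch.allclose(out_rgb, out_gray, atol=1e-4))
```

    The match is exact (up to rounding) only when the three channels fed to the 3-channel layer are identical; any preprocessing that treats the channels differently breaks it.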

    To check that this works, I did three runs of ImageNet validation on ResNet50 with pretrained weights. The numbers for runs 2 and 3 differ slightly, but the gap is minimal and should be irrelevant once the model is fine-tuned.

    1. Unmodified ResNet50 w/ RGB Images : Prec @1: 75.6, Prec @5: 92.8
    2. Unmodified ResNet50 w/ 3-chan Grayscale Images: Prec @1: 64.6, Prec @5: 86.4
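    3. Modified 1-chan ResNet50 w/ 1-chan Grayscale Images: Prec @1: 63.8, Prec @5: 86.1

    # Assumes this lives in torchvision/models/resnet.py, where ResNet, Bottleneck
    # and model_urls are defined, and that ResNet.__init__ has been extended with
    # an `inchans` argument that sets the number of input channels of conv1.
    from torch.utils import model_zoo

    def _load_pretrained(model, url, inchans=3):
        state_dict = model_zoo.load_url(url)
        if inchans == 1:
            # conv1.weight is [out_channels, in_channels, kH, kW]; collapse the
            # three RGB input channels into one by summing over dim 1.
            conv1_weight = state_dict['conv1.weight']
            state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
        elif inchans != 3:
            raise ValueError("Invalid number of inchans for pretrained weights")
        model.load_state_dict(state_dict)

    def resnet50(pretrained=False, inchans=3):
        """Constructs a ResNet-50 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            inchans (int): Number of input channels (3 for RGB, 1 for grayscale)
        """
        model = ResNet(Bottleneck, [3, 4, 6, 3], inchans=inchans)
        if pretrained:
            _load_pretrained(model, model_urls['resnet50'], inchans=inchans)
        return model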
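
    If editing the library source is not an option, the same idea can be applied to an already-constructed model by swapping its first layer. The sketch below is one possible way to do that, not part of Torchvision: convert_first_conv_to_gray and TinyNet are hypothetical names, and TinyNet is only a small stand-in for a pretrained network so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


def convert_first_conv_to_gray(model, attr="conv1"):
    """Replace model.<attr> (a 3-channel nn.Conv2d) with a 1-channel copy
    whose weights are the sum over the original input channels."""
    old = getattr(model, attr)
    if old.in_channels != 3 or old.groups != 1:
        raise ValueError("expected an ungrouped conv with 3 input channels")
    new = nn.Conv2d(
        1, old.out_channels,
        kernel_size=old.kernel_size, stride=old.stride,
        padding=old.padding, dilation=old.dilation,
        bias=old.bias is not None, padding_mode=old.padding_mode,
    ).to(old.weight.device)
    with torch.no_grad():
        new.weight.copy_(old.weight.sum(dim=1, keepdim=True))
        if old.bias is not None:
            new.bias.copy_(old.bias)  # bias does not depend on input channels
    setattr(model, attr, new)
    return model


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained network with a `conv1` layer."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        return self.fc(self.pool(x).flatten(1))


net = convert_first_conv_to_gray(TinyNet())
logits = net(torch.rand(4, 1, 32, 32))   # 1-channel input now works
print(net.conv1.weight.shape, logits.shape)
```

    Since the first layer of the Torchvision ResNet is also named conv1, the same helper should apply to a ResNet50 that was loaded with its pretrained weights in the usual way; all other layers are left untouched.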
    
