How to convert RGB images to grayscale in PyTorch dataloader?

I've downloaded some sample images from the MNIST dataset in .jpg format. Now I'm loading those images for testing my pre-trained model.

# transforms to apply to the data
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# MNIST dataset
test_dataset = dataset.ImageFolder(root=DATA_PATH, transform=trans)

# Data loader
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

Here DATA_PATH contains a subfolder with the sample image.

Here's my network definition

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.network2D = nn.Sequential(
           nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
           nn.MaxPool2d(kernel_size=2, stride=2),
           nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
           nn.MaxPool2d(kernel_size=2, stride=2))
        self.network1D = nn.Sequential(
           nn.Linear(7 * 7 * 64, 1000),
           nn.Linear(1000, 10))

    def forward(self, x):
        out = self.network2D(x)
        out = out.reshape(out.size(0), -1)
        out = self.network1D(out)
        return out

And this is my inference part

# Test the model
model = torch.load("mnist_weights_5.pth.tar")

for images, labels in test_loader:
   outputs = model(images.cuda())

When I run this code, I get the following error:

RuntimeError: Given groups=1, weight of size [32, 1, 5, 5], expected input[1, 3, 28, 28] to have 1 channels, but got 3 channels instead

I understand that the images are getting loaded as 3 channels (RGB). So how do I convert them to single channel in the dataloader?

Update: I changed transforms to include Grayscale option

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Grayscale(num_output_channels=1)])

But now I get this error

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>


When you use the ImageFolder class without a custom loader, torchvision loads each image with PIL and converts it to RGB. This is the default loader when the torchvision image backend is PIL:

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

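You can use torchvision's Grayscale transform. It converts the 3-channel RGB image into a 1-channel grayscale image. In the torchvision version from the question it expects a PIL image, so it has to come before ToTensor. The torchvision transforms documentation has more details.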

Sample code is below:

import torchvision as tv
import torch.utils.data as data

dataDir         = 'D:\\general\\ML_DL\\datasets\\CIFAR'
# Grayscale runs on the PIL image, ToTensor converts it, Normalize then sees 1 channel
trainTransform  = tv.transforms.Compose([tv.transforms.Grayscale(num_output_channels=1),
                                    tv.transforms.ToTensor(),
                                    tv.transforms.Normalize((0.5,), (0.5,))])
trainSet        = tv.datasets.CIFAR10(dataDir, train=True, download=False, transform=trainTransform)
dataloader      = data.DataLoader(trainSet, batch_size=1, shuffle=False, num_workers=0)
images, labels  = next(iter(dataloader))
print(images.size())


I found a very simple solution to this problem. The model expects a tensor of shape [1,1,28,28], whereas the input tensor has shape [1,3,28,28]. The sample images are grayscale digits saved as RGB, so the three channels should hold (almost) the same values, and I only need to read one of them

images = images[:,0,:,:]

This gives me a tensor of shape [1,28,28]. Now I need to turn it into a tensor of shape [1,1,28,28] by adding the channel dimension back at position 1 (this also works for batch sizes larger than 1)

images = images.unsqueeze(1)

Putting the two lines together, the prediction part of the code can be written like this

for images, labels in test_loader:
   images = images[:,0,:,:].unsqueeze(1) ## Extract single channel and restore the channel dimension
   outputs = model(images.cuda())
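
With a batch size of 1, adding the new axis at position 0 or at position 1 gives the same shape, but with larger batches only position 1 keeps the batch dimension intact. A small shape check, using a made-up random batch rather than the MNIST images:

```python
import torch

# Hypothetical batch of 4 RGB images, 28x28 each (random values for illustration).
images = torch.rand(4, 3, 28, 28)

first_channel = images[:, 0, :, :]     # shape [4, 28, 28]
wrong = first_channel.unsqueeze(0)     # shape [1, 4, 28, 28]: the batch lands on the channel axis
right = first_channel.unsqueeze(1)     # shape [4, 1, 28, 28]
sliced = images[:, 0:1, :, :]          # shape [4, 1, 28, 28], no unsqueeze needed

print(wrong.shape, right.shape, sliced.shape)
```

Slicing with `0:1` instead of `0` keeps the channel axis, so it is a slightly shorter way to get the same result.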


You could also skip ImageFolder and write your own dataset class (a data generator) that loads the images directly in its __getitem__ function, e.g. with PIL's Image.open(".."), then converts them to grayscale, to a NumPy array, and finally to a Tensor.
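
A minimal sketch of such a dataset is shown below. The class name `GrayscaleImageDataset`, the flat folder of .jpg files, and the default mean/std (taken from the question's Normalize call) are all assumptions for illustration; it returns only the image tensor, so labels would have to be added for real evaluation.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class GrayscaleImageDataset(Dataset):
    """Hypothetical dataset: loads every .jpg in one folder as a 1-channel tensor."""

    def __init__(self, folder, mean=0.1307, std=0.3081):
        self.paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(".jpg")
        )
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Open the file and convert it to single-channel grayscale ("L" mode).
        with Image.open(self.paths[index]) as img:
            gray = img.convert("L")
        # PIL -> NumPy (H, W) scaled to [0, 1] -> float tensor (1, H, W).
        array = np.asarray(gray, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(array).unsqueeze(0)
        # Same normalization as in the question's transform.
        return (tensor - self.mean) / self.std
```

An instance can be passed to `DataLoader` in place of the ImageFolder dataset; each batch then has shape [batch_size, 1, H, W].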

Another option is to compute the grayscale (Y) channel from RGB with the formula Y = 0.299 R + 0.587 G + 0.114 B: slice the array into its channels, take the weighted sum, and keep the result as a single channel.
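
The formula can be sketched directly on a batch tensor as below. The helper name `rgb_to_luma` is made up for this example, and it assumes the input is an RGB batch of shape (N, 3, H, W) that has not been normalized yet.

```python
import torch


def rgb_to_luma(images):
    """Convert a batch of shape (N, 3, H, W) to (N, 1, H, W) using the Y formula."""
    weights = torch.tensor([0.299, 0.587, 0.114],
                           dtype=images.dtype, device=images.device)
    # Weighted sum over the channel axis; keepdim preserves the channel dimension.
    return (images * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)


batch = torch.rand(2, 3, 28, 28)  # hypothetical RGB batch
gray = rgb_to_luma(batch)
print(gray.shape)  # torch.Size([2, 1, 28, 28])
```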

But how did you train your model? Usually the training and test data are loaded the same way.

