How to get mini-batches in pytorch in a clean and efficient way?

北恋 2021-01-30 08:48

I was trying to do a simple thing: train a linear model with Stochastic Gradient Descent (SGD) using torch:

import numpy as np

import torch
from torch.autograd import Variable


        
6 Answers
  •  悲&欢浪女
    2021-01-30 09:25

    Use data loaders.

    Data Set

    First you define a dataset. You can use the packaged datasets in torchvision.datasets or the ImageFolder dataset class, which follows the directory structure of ImageNet.

    trainset=torchvision.datasets.ImageFolder(root='/path/to/your/data/trn', transform=generic_transform)
    testset=torchvision.datasets.ImageFolder(root='/path/to/your/data/val', transform=generic_transform)
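
    If your data is already in memory as tensors (as in the linear-model/SGD setting from the question) rather than in image folders, you can also build a dataset directly with TensorDataset. A minimal sketch, assuming hypothetical tensors X and y:

    import torch
    from torch.utils.data import TensorDataset

    X = torch.randn(1000, 5)      # hypothetical inputs: 1000 samples, 5 features
    y = torch.randn(1000, 1)      # hypothetical regression targets

    # TensorDataset wraps the tensors so that dataset[i] returns (X[i], y[i])
    trainset_tensors = TensorDataset(X, y)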
    

    Transforms

    Transforms are very useful for preprocessing loaded data on the fly. If you are using images, you have to use the ToTensor() transform to convert loaded PIL images to torch.Tensor. Multiple transforms can be packed into a composite transform as follows.

    generic_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),                                  # back to a PIL image for the custom resize below
        #transforms.CenterCrop(size=128),
        transforms.Lambda(lambda x: myimresize(x, (128, 128))),   # myimresize is a user-defined resize helper
        transforms.ToTensor(),                                    # PIL image -> torch.Tensor in [0, 1]
        transforms.Normalize((0., 0., 0.), (6, 6, 6))             # per-channel mean and std
    ])
    

    Data Loader

    Then you define a data loader, which prepares the next batch while you are training. You can set the number of worker processes used for data loading.

    trainloader=torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)
    testloader=torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
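
    The same applies to the in-memory TensorDataset sketched above: wrapping it in a DataLoader is all you need to get shuffled mini-batches out of plain tensors. A sketch, reusing the hypothetical trainset_tensors, X and y from above:

    tensor_loader = torch.utils.data.DataLoader(trainset_tensors, batch_size=32, shuffle=True)

    xb, yb = next(iter(tensor_loader))   # one mini-batch
    print(xb.shape, yb.shape)            # torch.Size([32, 5]) torch.Size([32, 1])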
    

    For training, you just enumerate over the data loader.

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        # continue training...
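
    Note that Variable was deprecated in PyTorch 0.4; on recent versions the same loop can be written with plain tensors. A sketch, assuming mynet, criterion and optimizer are defined elsewhere:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()                # reset gradients from the previous step
        outputs = mynet(inputs)              # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                      # backward pass
        optimizer.step()                     # parameter update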
    

    NumPy Stuff

    Yes. You have to convert a torch.Tensor to a numpy array with the .numpy() method to work on it in numpy. If you are using CUDA, you first have to move the data from the GPU to the CPU with the .cpu() method before calling .numpy(). Personally, coming from a MATLAB background, I prefer to do most of the work with torch tensors and convert the data to numpy only for visualisation. Also bear in mind that torch stores data in channel-first order, while numpy and PIL work with channel-last. This means you need np.rollaxis to move the channel axis to the last position. A sample code is below.

    np.rollaxis(make_grid(mynet.ftrextractor(inputs).data, nrow=8, padding=1).cpu().numpy(), 0, 3)
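
    As a smaller illustration of the channel-order point, here is a sketch that converts a single hypothetical CHW image tensor on the GPU into an HWC numpy array:

    img = torch.rand(3, 128, 128).cuda()   # hypothetical CHW image tensor on a CUDA device
    img_np = img.cpu().numpy()             # move to CPU first, then convert to numpy (still CHW)
    img_hwc = np.rollaxis(img_np, 0, 3)    # roll the channel axis to the end -> (128, 128, 3)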
    

    Logging

    The best method I found to visualise the feature maps is using TensorBoard. Example code is available at yunjey/pytorch-tutorial.
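
    The linked tutorial uses its own logger class; newer PyTorch versions also ship torch.utils.tensorboard, so a minimal sketch for logging the feature-map grid from the snippet above could look like this:

    from torch.utils.tensorboard import SummaryWriter
    from torchvision.utils import make_grid

    writer = SummaryWriter(log_dir='runs/feature_maps')     # event files readable by TensorBoard

    grid = make_grid(mynet.ftrextractor(inputs).data.cpu(), nrow=8, padding=1)
    writer.add_image('feature_maps', grid, global_step=0)   # add_image expects a CHW tensor
    writer.close()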
