How to get mini-batches in pytorch in a clean and efficient way?

北恋 2021-01-30 08:48

I was trying to do something simple, namely train a linear model with stochastic gradient descent (SGD) using torch:

import numpy as np

import torch
from torch.         


        
  •  悲&欢浪女
    2021-01-30 09:39

    You can use torch.utils.data.

    Assuming you have loaded the data from the directory into train and test NumPy arrays, you can subclass torch.utils.data.Dataset to create your dataset object:

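    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        def __init__(self, x, y):
            super().__init__()
            # x and y must have the same number of samples along dimension 0
            assert x.shape[0] == y.shape[0]
            self.x = x
            self.y = y

        def __len__(self):
            # Number of samples in the dataset
            return self.y.shape[0]

        def __getitem__(self, index):
            # One (input, target) pair; DataLoader collates these into mini-batches
            return self.x[index], self.y[index]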
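If the arrays can be converted to tensors up front and `__getitem__` needs no custom logic, the built-in `torch.utils.data.TensorDataset` behaves like `MyDataset` above without writing a class. A minimal sketch, assuming hypothetical random tensors stand in for the real training data:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical synthetic data: 1000 samples, 5 features, 1 target each
train_x = torch.randn(1000, 5)
train_y = torch.randn(1000, 1)

# TensorDataset checks that all tensors share the same first dimension
# and indexes them together, returning one (input, target) tuple per sample
traindata = TensorDataset(train_x, train_y)

x0, y0 = traindata[0]
print(len(traindata))      # 1000
print(x0.shape, y0.shape)  # torch.Size([5]) torch.Size([1])
```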

    Then create your dataset object:

    traindata = MyDataset(train_x, train_y)
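One thing to watch when the arrays come from NumPy: most NumPy constructors produce float64 arrays, while `nn.Linear` parameters are float32 by default, so mini-batches built from float64 arrays will not match a default linear model. A minimal sketch of casting the arrays before building the dataset; the random arrays and their shapes are hypothetical stand-ins for the real `train_x` and `train_y`:

```python
import numpy as np
import torch
from torch import nn

rng = np.random.default_rng(0)
train_x = rng.random((100, 3))  # float64 by default
train_y = rng.random((100, 1))

# Cast to float32 so batches match nn.Linear's default parameter dtype
train_x = train_x.astype(np.float32)
train_y = train_y.astype(np.float32)

model = nn.Linear(3, 1)  # weights and bias are float32 by default
out = model(torch.from_numpy(train_x[:8]))
print(out.dtype, out.shape)  # torch.float32 torch.Size([8, 1])
```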
    

    Finally, use DataLoader to create your mini-batches:

    trainloader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
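To tie this back to the question, here is a minimal end-to-end sketch of mini-batch SGD for a linear model using a DataLoader. It uses synthetic data and the built-in `TensorDataset` in place of `MyDataset` to stay short; the sizes, learning rate, and epoch count are arbitrary assumptions, not tuned values:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic regression data: y = x @ true_w + true_b + small noise
n_samples, n_features = 1000, 5
true_w = torch.randn(n_features, 1)
true_b = 0.5
train_x = torch.randn(n_samples, n_features)
train_y = train_x @ true_w + true_b + 0.01 * torch.randn(n_samples, 1)

traindata = TensorDataset(train_x, train_y)
trainloader = DataLoader(traindata, batch_size=64, shuffle=True)

model = nn.Linear(n_features, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(20):
    for xb, yb in trainloader:         # one shuffled mini-batch per step
        optimizer.zero_grad()          # clear gradients from the previous step
        loss = loss_fn(model(xb), yb)  # forward pass on the mini-batch
        loss.backward()                # backpropagate to get gradients
        optimizer.step()               # SGD parameter update

print(len(trainloader))  # 16
print(loss.item())       # loss on the last mini-batch of the last epoch
```

With 1000 samples and `batch_size=64`, the loader yields ceil(1000 / 64) = 16 batches per epoch: 15 full batches of 64 and a final batch of 40. Passing `drop_last=True` to `DataLoader` discards that partial batch, and `shuffle=True` visits the samples in a new random order each epoch.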
    
