How to get mini-batches in pytorch in a clean and efficient way?

北恋 2021-01-30 08:48

I was trying to do something simple: train a linear model with stochastic gradient descent (SGD) using torch:

import numpy as np

import torch
from torch.

6 answers
  •  無奈伤痛
    2021-01-30 09:25

    If I'm understanding your code correctly, your get_batch2 function draws random mini-batches from your dataset without tracking which indices you have already used in the current epoch, i.e. it samples with replacement. The problem is that some examples get drawn several times per epoch while others are never drawn, so it will likely not make use of all of your data.
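To make the coverage problem concrete, here is a small sketch (not part of the original answer) that counts how many distinct examples one epoch's worth of batches touches under each scheme. N and M are hypothetical sizes, and the with-replacement loop assumes get_batch2 draws indices with np.random.randint(0, N, size=M), as the last paragraph of this answer suggests.

```python
import numpy as np
import torch

N, M = 1000, 100      # hypothetical dataset size and batch size
n_batches = N // M    # number of batches in one "epoch"

# Sampling with replacement (assumed to mirror get_batch2):
# each batch is drawn independently of the previous ones.
np.random.seed(0)
seen_with_replacement = set()
for _ in range(n_batches):
    idx = np.random.randint(0, N, size=M)
    seen_with_replacement.update(idx.tolist())

# Sampling without replacement: shuffle once, then take consecutive slices.
torch.manual_seed(0)
permutation = torch.randperm(N)
seen_permutation = set()
for i in range(0, N, M):
    seen_permutation.update(permutation[i:i + M].tolist())

print("with replacement:", len(seen_with_replacement), "of", N)
print("permutation:     ", len(seen_permutation), "of", N)
```

On average the with-replacement scheme touches only about 1 - 1/e ≈ 63% of the examples per epoch, while the permutation visits every example exactly once.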

    The way I usually do batching is to create a random permutation of all the sample indices using torch.randperm(N) and loop through it in batches. For example:

    n_epochs = 100 # or whatever
    batch_size = 128 # or whatever
    
    for epoch in range(n_epochs):
    
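        # X and Y are plain tensors (Variable was merged into Tensor in PyTorch 0.4)
        # reshuffle the sample indices once per epoch
        permutation = torch.randperm(X.size(0))
    
        for i in range(0, X.size(0), batch_size):
            optimizer.zero_grad()
    
            # next slice of shuffled indices (the last batch may be smaller)
            indices = permutation[i:i + batch_size]
            batch_x, batch_y = X[indices], Y[indices]
    
            # in case you wanted a semi-full example
            outputs = model(batch_x)  # call the module rather than .forward() so hooks run
            loss = lossfunction(outputs, batch_y)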
    
            loss.backward()
            optimizer.step()
    

    If you copy and paste this, make sure you define optimizer, model and lossfunction somewhere before the epoch loop starts.
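For a version you can run as-is, the sketch below fills in those definitions with hypothetical choices: a one-feature nn.Linear model, nn.MSELoss as lossfunction, plain torch.optim.SGD, and synthetic data generated from y = 3x - 2. None of these come from the question; swap in your own.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical synthetic data: y = 3x - 2 plus a little noise.
N, D = 1000, 1
X = torch.randn(N, D)
Y = 3 * X - 2 + 0.01 * torch.randn(N, 1)

# Hypothetical model, loss and optimizer, defined before the epoch loop.
model = nn.Linear(D, 1)
lossfunction = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

n_epochs = 20
batch_size = 128

for epoch in range(n_epochs):
    # Fresh shuffle each epoch, so every example is used exactly once per epoch.
    permutation = torch.randperm(X.size(0))

    for i in range(0, X.size(0), batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        batch_x, batch_y = X[indices], Y[indices]

        outputs = model(batch_x)
        loss = lossfunction(outputs, batch_y)

        loss.backward()
        optimizer.step()

# The learned parameters should end up close to 3.0 and -2.0.
print(model.weight.item(), model.bias.item())
```

If you would rather not manage the indices yourself, torch.utils.data.DataLoader(TensorDataset(X, Y), batch_size=batch_size, shuffle=True) gives the same per-epoch reshuffling (its default sampler draws without replacement) and yields (batch_x, batch_y) pairs directly.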

    Regarding your error, try using torch.from_numpy(np.random.randint(0, N, size=M)).long() instead of torch.LongTensor(np.random.randint(0, N, size=M)). I'm not sure whether this will fix the error you are getting, but it will prevent a future one.
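A small sketch of the suggested cast, with hypothetical N and M. torch.from_numpy shares memory with the NumPy array and keeps its integer dtype, which can be platform-dependent, so .long() makes the int64 index type explicit. torch.randint is shown as a NumPy-free alternative; note that it still samples with replacement, so on its own it does not address the coverage issue described at the top of this answer.

```python
import numpy as np
import torch

N, M = 1000, 128  # hypothetical dataset size and batch size

# NumPy indices -> tensor; cast explicitly to int64 (torch.long) for indexing.
idx_np = np.random.randint(0, N, size=M)
indices = torch.from_numpy(idx_np).long()

# NumPy-free equivalent; torch.randint returns int64 by default.
indices_torch = torch.randint(0, N, (M,))

print(indices.dtype, indices_torch.dtype)  # torch.int64 torch.int64
```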
