Mini-batches with DataLoader and a 3D input (PyTorch)

Posted by 做~自己de王妃 on 2021-01-28 06:50:47

Question


I have been struggling to create batches from a 3D tensor. I have used DataLoader before to create batches from a 1D tensor. However, in my current research, I need to create batches out of a tensor with shape (1024, 1024, 2).

I created a custom dataset to use as the input to PyTorch's DataLoader. For the 1D case, I wrote the following:

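from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, x_tensor, y_tensor):
        # Inputs and targets, indexed along their first dimension.
        self.xdomain = x_tensor
        self.ydomain = y_tensor

    def __getitem__(self, index):
        # Return one (input, target) pair.
        return (self.xdomain[index], self.ydomain[index])

    def __len__(self):
        # Number of samples = size of the first dimension.
        return len(self.xdomain)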

This works well. However, I realized that it does not work for tensors x_tensor and y_tensor of shape (1024, 1024, 2) and (1024, 1024, 1) respectively. I understand that I have to change the __getitem__ and __len__ methods so that they divide the tensors into batches.
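To make the behaviour concrete, here is a minimal sketch of what the class above produces when it is given 3D tensors. The random tensors and the batch size of 32 are placeholders, not values from the question. Because indexing and len() both act on the first dimension, each sample is one row of shape (1024, 2), and the DataLoader stacks those rows into a batch:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    def __init__(self, x_tensor, y_tensor):
        self.xdomain = x_tensor
        self.ydomain = y_tensor

    def __getitem__(self, index):
        # Indexing a 3D tensor with an int selects along the first dimension.
        return self.xdomain[index], self.ydomain[index]

    def __len__(self):
        # len() of a tensor is the size of its first dimension.
        return len(self.xdomain)


# Placeholder data with the shapes from the question.
x = torch.rand(1024, 1024, 2)
y = torch.rand(1024, 1024, 1)

dataset = CustomDataset(x, y)
loader = DataLoader(dataset, batch_size=32, shuffle=True)  # batch size is an assumption

xb, yb = next(iter(loader))
print(len(dataset))  # 1024
print(xb.shape)      # torch.Size([32, 1024, 2])
print(yb.shape)      # torch.Size([32, 1024, 1])
```

So the class does run on 3D tensors; the open question is whether "one row of the grid" is the sample the network should receive.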

I tried many things. One approach that I know could work is to flatten these tensors into shapes (1024 x 1024, 2) and (1024 x 1024, 1). However, I would then have to change not only my NN definition but most of my code.
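For reference, the flattening idea could look roughly like the sketch below. It uses TensorDataset from torch.utils.data instead of a custom class, and the batch size of 256 is an arbitrary assumption. Each sample becomes a single grid point with 2 input features and 1 target value:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder data with the shapes from the question.
x = torch.rand(1024, 1024, 2)
y = torch.rand(1024, 1024, 1)

# Merge the two grid dimensions into one sample dimension.
x_flat = x.reshape(-1, 2)  # (1024 * 1024, 2)
y_flat = y.reshape(-1, 1)  # (1024 * 1024, 1)

loader = DataLoader(TensorDataset(x_flat, y_flat), batch_size=256, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape)  # torch.Size([256, 2])
print(yb.shape)  # torch.Size([256, 1])

# The flattening is reversible, so the grid layout can be restored later.
x_back = x_flat.reshape(1024, 1024, 2)
print(torch.equal(x_back, x))  # True
```

This only fits a network that maps 2 features to 1 output per point, which is why it would force changes to a model written for whole grids.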

So I want to keep the shapes as they are and understand how to write these methods, if possible. My understanding of them is: __len__ lets len(dataset) return the size of the dataset, and __getitem__ supports indexing so that dataset[i] returns the i-th sample.

With this knowledge, I created a class that finds the indices along the first two dimensions (to locate the i-th sample). However, this made the input of the NN (1024 x 1024, 2) and the output (1024 x 1024, 1), and I want them to be (1024, 1024, 2) and (1024, 1024, 1).
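One possible reading of the goal is that a whole (1024, 1024, 2) grid should count as a single sample. Under that assumption, a sketch could look like the following. GridDataset is a hypothetical name, not the class from the question, and it expects an extra leading dimension that counts the grids:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class GridDataset(Dataset):
    """Hypothetical dataset in which one sample is one whole grid."""

    def __init__(self, x_grids, y_grids):
        # x_grids: (num_grids, H, W, 2), y_grids: (num_grids, H, W, 1)
        assert x_grids.shape[:3] == y_grids.shape[:3]
        self.x_grids = x_grids
        self.y_grids = y_grids

    def __len__(self):
        # The dataset size is the number of grids, not the number of points.
        return self.x_grids.shape[0]

    def __getitem__(self, index):
        # dataset[i] is the i-th full grid and its target.
        return self.x_grids[index], self.y_grids[index]


# Placeholder data with the shapes from the question.
x = torch.rand(1024, 1024, 2)
y = torch.rand(1024, 1024, 1)

# A single grid becomes a dataset of length 1 via a leading dimension.
dataset = GridDataset(x.unsqueeze(0), y.unsqueeze(0))

# With automatic batching, the batch dimension is added in front.
xb, yb = next(iter(DataLoader(dataset, batch_size=1)))
print(xb.shape)  # torch.Size([1, 1024, 1024, 2])

# batch_size=None disables automatic batching and yields samples as they are.
xs, ys = next(iter(DataLoader(dataset, batch_size=None)))
print(xs.shape)  # torch.Size([1024, 1024, 2])
```

With only one grid there is nothing to split into mini-batches this way; batches of several grids would need several grids stacked along the leading dimension.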

If someone with a better understanding of DataLoader and mini-batches could explain what I am missing, that would be amazing. And first of all, is this possible?

Thanks for reading this, and sorry if this question is too basic. I hope it is clear.

Source: https://stackoverflow.com/questions/64455422/mini-batches-with-dataloader-and-a-3d-input-pytorch
