How do you alter the size of a PyTorch Dataset?

Submitted by 放荡痞女 on 2020-01-02 03:56:07

Question


Say I am loading MNIST from torchvision.datasets.MNIST, but I only want to load 10000 images in total. How would I slice the data to limit it to a given number of data points? I understand that the DataLoader is a generator yielding data in batches of the specified batch size, but how do you slice datasets?

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)

Answer 1:


It is important to note that when you create the DataLoader object, it doesn't immediately load all of your data (that would be impractical for large datasets). It provides an iterator that you can use to access the data one batch at a time.

Unfortunately, DataLoader doesn't give you a direct way to control the number of samples you wish to extract. You will have to use the typical ways of slicing iterators.

The simplest thing to do (without any libraries) is to stop once the required number of samples has been reached.

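nsamples = 10000
for i, (image, label) in enumerate(train_loader):
    # Each iteration yields one batch, so i * batch_size samples have been seen so far.
    if i * train_loader.batch_size >= nsamples:
        break

    # Your training code here.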

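Or you could use itertools.islice (after import itertools) to take only the first part of the loader. Note that islice takes its stop value as a positional argument, not a keyword, and that the loader yields batches, so the stop value is a number of batches rather than a number of samples. Like so:

for image, label in itertools.islice(train_loader, 10000 // args.batch_size):

    # Your training code here.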
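The batch-versus-sample distinction can be checked without PyTorch. The following is only a minimal sketch: fake_loader is a hypothetical stand-in for a DataLoader (it is not part of any library), and the batch size of 64 is an assumed value.

```python
import itertools
import math


def fake_loader(num_samples, batch_size):
    """Hypothetical stand-in for a DataLoader: yields (images, labels) batches."""
    for start in range(0, num_samples, batch_size):
        end = min(start + batch_size, num_samples)
        indices = list(range(start, end))
        yield indices, indices  # pretend these are an image batch and a label batch


nsamples = 10000
batch_size = 64  # assumed value for illustration

# Number of batches needed to cover at least nsamples samples.
nbatches = math.ceil(nsamples / batch_size)

seen = 0
for images, labels in itertools.islice(fake_loader(60000, batch_size), nbatches):
    seen += len(images)

print(nbatches, seen)
```

Rounding up means the loop may see slightly more than nsamples samples; rounding down (integer division) would see slightly fewer. Neither gives exactly 10000 unless the batch size divides it evenly.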



Answer 2:


Another quick way of slicing a dataset is to use torch.utils.data.random_split() (supported in PyTorch v0.4.1+). It randomly splits a dataset into non-overlapping new datasets of the given lengths, so you can keep the first part and discard the rest.

So we can have something like the following:

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)

part_tr = torch.utils.data.random_split(tr, [tr_split_len, len(tr)-tr_split_len])[0]
part_te = torch.utils.data.random_split(te, [te_split_len, len(te)-te_split_len])[0]

train_loader = DataLoader(part_tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(part_te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)

Here you can set tr_split_len and te_split_len to the required split lengths for the training and testing datasets, respectively.
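To try this without downloading MNIST, here is a small sketch on synthetic data. The TensorDataset shapes and split_len = 100 are assumptions chosen for illustration. It also shows torch.utils.data.Subset, which is not mentioned in the answers above but may be useful when you want the first N samples deterministically rather than a random selection.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset, random_split

# Synthetic stand-in for MNIST: 1000 fake 28x28 images with labels 0-9.
full = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))

split_len = 100  # hypothetical target size

# Random, non-overlapping split; keep only the first part.
part_random, _ = random_split(full, [split_len, len(full) - split_len])

# Deterministic alternative: keep the first split_len samples by index.
part_first = Subset(full, range(split_len))

# The reduced dataset plugs into a DataLoader like any other dataset.
loader = DataLoader(part_random, batch_size=32, shuffle=True)
total = sum(images.size(0) for images, _ in loader)
print(len(part_random), len(part_first), total)
```

Either reduced dataset limits what the DataLoader can ever see, so no early break is needed in the training loop.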



Source: https://stackoverflow.com/questions/44856691/how-do-you-alter-the-size-of-a-pytorch-dataset
