PyTorch:数据读取1
-柚子皮- 什么是Datasets? 在输入流水线中,准备数据的代码是这么写的 data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True) datasets.CIFAR10 就是一个 Datasets 子类, data 是这个类的一个实例。 为什么要定义Datasets? PyTorch 提供了一个工具函数 torch.utils.data.DataLoader 。通过这个类,我们可以让数据变成mini-batch,且在准备 mini-batch 的时候可以多线程并行处理,这样可以加快准备数据的速度。 Datasets 就是构建这个类的实例的参数之一。 DataLoader的使用参考[]。 -柚子皮- 自定义Datasets 框架 import torch.utils.data as data class CustomDataset(data.Dataset): # 继承data.Dataset """Custom data.Dataset compatible with data.DataLoader.""" def __init__(self, filename, data_info, oth_params): """Reads source and target