PyTorch 之 Datasets

匿名 (未验证) 提交于 2019-12-02 23:56:01

Datasets 的框架:

class CustomDataset(data.Dataset):  # must subclass data.Dataset
    """Skeleton of a map-style dataset: fill in the three methods below."""

    def __init__(self):
        """TODO: initialize file paths or a list of file names."""

    def __getitem__(self, index):
        # TODO:
        # 1. Read the sample at ``index`` from disk
        #    (e.g. with numpy.fromfile or PIL.Image.open).
        # 2. Preprocess the sample (e.g. with torchvision transforms).
        # 3. Return a data pair (e.g. an image and its label).
        return None

    def __len__(self):
        # TODO: replace 0 with the total number of samples in your dataset.
        return 0

下面是官方 MNIST 的例子:

class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
    class_to_idx = {_class: i for i, _class in enumerate(classes)}

    @property
    def targets(self):
        """Labels of the selected split (``train_labels`` or ``test_labels``)."""
        return self.train_labels if self.train else self.test_labels

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Only the requested split is loaded; the other split's attributes stay unset.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        """Return True when both processed ``.pt`` files are present under root."""
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def _makedirs(self):
        """Create the raw and processed folders under root, tolerating ones that already exist."""
        # Each directory gets its own try: an existing raw folder must not
        # prevent the processed folder from being created.
        for folder in (self.raw_folder, self.processed_folder):
            path = os.path.join(self.root, folder)
            try:
                os.makedirs(path)
            except OSError as e:
                if e.errno != errno.EEXIST or not os.path.isdir(path):
                    raise

    @staticmethod
    def _gunzip(file_path):
        """Decompress ``file_path`` next to itself, delete the archive, and return the output path.

        Raises:
            ValueError: if ``file_path`` does not end with ``.gz``.
        """
        import gzip

        suffix = '.gz'
        if not file_path.endswith(suffix):
            raise ValueError('Expected a .gz file, got ' + file_path)
        # Strip only the trailing suffix; '.gz' elsewhere in the path must survive.
        out_path = file_path[:-len(suffix)]
        with open(out_path, 'wb') as out_f, gzip.GzipFile(file_path) as zip_f:
            out_f.write(zip_f.read())
        os.unlink(file_path)
        return out_path

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from contextlib import closing
        from six.moves import urllib

        if self._check_exists():
            return

        self._makedirs()

        # download files
        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            # closing() releases the HTTP response even if the write fails.
            with closing(urllib.request.urlopen(url)) as response, open(file_path, 'wb') as f:
                f.write(response.read())
            self._gunzip(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(len(self))
        # Same truthiness test as __init__/__len__, so a non-bool ``train`` reports the split actually loaded.
        tmp = 'train' if self.train else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!