MNIST 数据集已经是一个被”嚼烂”了的数据集, 很多教程都会对它”下手”, 几乎成为一个 “典范”. 不过有些人可能对它还不是很了解, 下面来介绍一下.
http://yann.lecun.com/exdb/mnist/
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
- Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.
图片是以字节的形式进行存储, 我们需要把它们读取到 NumPy array 中, 以便训练和测试算法.
import os import struct import numpy as np def load_mnist(path, kind='train'): """Load MNIST data from `path`""" labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind) images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind) with open(labels_path, 'rb') as lbpath: magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8) with open(images_path, 'rb') as imgpath: magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16)) images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) return images, labels
load_mnist
images
load_mnist
labels
) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).
第一次见的话, 可能会觉得我们读取图片的方式有点奇怪:
magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8)
- 为了理解这两行代码, 我们先来看一下 MNIST 网站上对数据集的介绍:
TRAINING SET LABEL FILE (train-labels-idx1-ubyte): [offset] [type] [value] [description] 0000 32 bit integer 0x00000801(2049) magic number (MSB first) 0004 32 bit integer 60000 number of items 0008 unsigned byte ?? label 0009 unsigned byte ?? label ........ xxxx unsigned byte ?? label The labels values are 0 to 9.
- magic number
fromfile
item 数(n)struct.unpack
>II
>
EndiannessI
: 这是指一个无符号整数.
通过执行下面的代码, 我们将会从刚刚解压 MNIST 数据集后的 mnist 目录下加载 60,000 个训练样本和 10,000 个测试样本.
imshow
import matplotlib.pyplot as plt fig, ax = plt.subplots( nrows=2, ncols=5, sharex=True, sharey=True, ) ax = ax.flatten() for i in range(10): img = X_train[y_train == i][0].reshape(28, 28) ax[i].imshow(img, cmap='Greys', interpolation='nearest') ax[0].set_xticks([]) ax[0].set_yticks([]) plt.tight_layout() plt.show()
我们现在应该可以看到一个 2*5 的图片, 里面分别是 0-9 单个数字的图片.
此外, 我们还可以绘制某一数字的多个样本图片, 来看一下这些手写样本到底有多不同:
fig, ax = plt.subplots( nrows=5, ncols=5, sharex=True, sharey=True, ) ax = ax.flatten() for i in range(25): img = X_train[y_train == 7][i].reshape(28, 28) ax[i].imshow(img, cmap='Greys', interpolation='nearest') ax[0].set_xticks([]) ax[0].set_yticks([]) plt.tight_layout() plt.show()
执行上面的代码后, 我们应该看到数字 7 的 25 个不同形态:
另外, 我们也可以选择将 MNIST 图片数据和标签保存为 CSV 文件, 这样就可以在不支持特殊的字节格式的程序中打开数据集. 但是, 有一点要说明, CSV 的文件格式将会占用更多的磁盘空间, 如下所示:
- train_img.csv: 109.5 MB
- train_labels.csv: 120 KB
- test_img.csv: 18.3 MB
- test_labels: 20 KB
如果我们打算保存这些 CSV 文件, 在将 MNIST 数据集加载入 NumPy array 以后, 我们应该执行下列代码:
np.savetxt('train_img.csv', X_train, fmt='%i', delimiter=',') np.savetxt('train_labels.csv', y_train, fmt='%i', delimiter=',') np.savetxt('test_img.csv', X_test, fmt='%i', delimiter=',') np.savetxt('test_labels.csv', y_test, fmt='%i', delimiter=',')
genfromtxt
X_train = np.genfromtxt('train_img.csv', dtype=int, delimiter=',') y_train = np.genfromtxt('train_labels.csv', dtype=int, delimiter=',') X_test = np.genfromtxt('test_img.csv', dtype=int, delimiter=',') y_test = np.genfromtxt('test_labels.csv', dtype=int, delimiter=',')
不过, 从 CSV 文件中加载 MNIST 数据将会显著发给更长的时间, 因此如果可能的话, 还是建议你维持数据集原有的字节格式.
出处:https://blog.csdn.net/simple_the_best/article/details/75267863