使用cifar-10数据集

匿名 (未验证) 提交于 2019-12-03 00:22:01
#-*-coding:utf-8-*- import numpy as np import matplotlib.pyplot as plt from data_utils import load_CIFAR10 cifar10_dir = 'datasets/cifar-10-batches-py' X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir) print("训练数据:",X_train.shape) classes = ['plane', 'cat', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] num_classes = len(classes) samples_per_class = 7 for y,cls in enumerate(classes):     idxs = np.flatnonzero(y_train == y)     idxs = np.random.choice(idxs,samples_per_class, replace=False)     for i, idx in enumerate(idxs):         plt_idx = i * num_classes + y + 1         plt.subplot(samples_per_class, num_classes, plt_idx)         plt.imshow(X_train[idx].astype('uint8'))         plt.axis('off')         if i == 0:             plt.title(cls) plt.show() 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!