How to sort MNIST digits into each class label?

孤街浪徒 submitted on 2019-12-23 23:26:47

Question


I'm importing the MNIST dataset from Keras using (x_train, y_train), (x_test, y_test) = mnist.load_data(), and I want to sort the samples by their corresponding digit. I imagine there is some trivial way to do this, but I can't find any label attribute on the data. Is there a simple way to do this?


Answer 1:


y_train and y_test are vectors holding the label of each image in x_train and x_test, respectively, so they tell you which digit each image shows. To sort by digit, get the indices that would sort the label vector with np.argsort, then use those indices to reorder both the image array and the label vector:

import numpy as np

idx = np.argsort(y_train)
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]
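
To check that the images and labels stay aligned after reordering, here is a minimal sketch on a small synthetic array instead of the real MNIST data. The names y_demo, x_demo and idx_demo are made up for illustration, and passing kind="stable" is an optional addition (not part of the snippet above) that keeps samples with the same label in their original relative order.

```python
import numpy as np

# Synthetic stand-in for MNIST (an assumption for illustration):
# six tiny 2x2 "images", each filled with its own original index.
y_demo = np.array([3, 0, 3, 1, 0, 1])
x_demo = np.arange(6).reshape(6, 1, 1) * np.ones((1, 2, 2), dtype=int)

# kind="stable" keeps samples with equal labels in their original order.
idx_demo = np.argsort(y_demo, kind="stable")
x_demo_sorted = x_demo[idx_demo]
y_demo_sorted = y_demo[idx_demo]

print(y_demo_sorted)           # labels in ascending order
print(x_demo_sorted[:, 0, 0])  # original index of each image after sorting
```

Because the same index array is applied to both arrays, each image still sits at the same position as its label.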

If you only want all the images for a particular digit, you can grab them with a boolean mask on the labels:

x_train_zeros = x_train[y_train == 0]
x_train_ones = x_train[y_train == 1]
# and so on...

Notice that the boolean mask works on the unsorted arrays, so in this case you don't need to pre-sort the data.
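
Rather than writing one line per digit, the same masking idea can be looped over all ten classes. The sketch below is only an illustration: it uses placeholder arrays with MNIST-like shapes (x_fake, y_fake) rather than the real dataset, and the dictionary name images_by_digit is hypothetical.

```python
import numpy as np

# Placeholder data with MNIST-like shapes (an assumption for illustration):
# 100 blank 28x28 images whose labels cycle through 0..9.
y_fake = np.arange(100) % 10
x_fake = np.zeros((100, 28, 28), dtype=np.uint8)

# Map each digit to the sub-array of images carrying that label.
images_by_digit = {d: x_fake[y_fake == d] for d in range(10)}

for d, imgs in images_by_digit.items():
    print(d, imgs.shape)
```

With the real data you would substitute x_train and y_train for the placeholder arrays.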



Source: https://stackoverflow.com/questions/52618700/how-to-sort-mnist-digits-into-each-class-label
