Display MNIST image using matplotlib [duplicate]

最后都变了 - posted on 2019-12-03 14:38:06

You are casting an array of floats (as described in the docs) to uint8, which truncates every value below 1.0 to 0. You should either round them, use them as floats, or multiply by 255 before casting.
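To make the truncation concrete, here is a small sketch with made-up pixel values (not real MNIST data), assuming only that the values are floats in the range 0 to 1:

```python
import numpy as np

# Hypothetical pixel values in [0, 1], standing in for part of an MNIST image.
pixels = np.array([0.0, 0.25, 0.75, 0.99, 1.0])

# Direct cast: the fractional part is dropped, so only 1.0 survives as non-zero.
truncated = pixels.astype(np.uint8)

# Scale to 0..255 first, then round and cast, which keeps the gray levels.
scaled = np.round(pixels * 255).astype(np.uint8)

print(truncated)  # [0 0 0 0 1]
print(scaled)     # [  0  64 191 252 255]
```

With the direct cast almost the whole image ends up as 0, which is why it looks nearly black.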

I am not sure why you don't see the white background, but I would suggest using a well-defined gray scale anyway.
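One way to get a well-defined gray scale is to fix the colormap and its limits explicitly. This is only a sketch: `show_digit` is a hypothetical helper name, and the ramp below is synthetic data standing in for an MNIST image.

```python
import numpy as np
from matplotlib import pyplot as plt


def show_digit(flat_pixels):
    """Draw a flat 784-value float vector as a 28x28 grayscale image."""
    pixels = np.asarray(flat_pixels, dtype=float).reshape(28, 28)
    # vmin/vmax pin 0.0 to black and 1.0 to white, whatever the data range is.
    return plt.imshow(pixels, cmap='gray', vmin=0.0, vmax=1.0)


# Synthetic stand-in for an MNIST image: a horizontal ramp from 0 to 1.
demo = np.tile(np.linspace(0.0, 1.0, 28), 28)
image = show_digit(demo)
# Call plt.show() here to open the window.
```

Without `vmin` and `vmax`, matplotlib stretches the colormap to the minimum and maximum of each array, so the same pixel value can come out as a different shade from one image to the next.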

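Here is the complete code for showing an image using matplotlib:

import numpy as np
from matplotlib import pyplot as plt

# mnist is the dataset object returned by input_data.read_data_sets(...)
first_image = mnist.test.images[0]
first_image = np.array(first_image, dtype='float')
pixels = first_image.reshape((28, 28))
plt.imshow(pixels, cmap='gray')
plt.show()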

The following code displays example images from the MNIST digit database used for training neural networks. It combines pieces of code from around Stack Overflow and avoids PIL.

# Tested with Python 3.5.2 with tensorflow and matplotlib installed.
from matplotlib import pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


def gen_image(arr):
    # Reshape the flat 784-value vector to 28x28 and scale it to 0..255.
    two_d = (np.reshape(arr, (28, 28)) * 255).astype(np.uint8)
    plt.imshow(two_d, interpolation='nearest')
    return plt

# Get a batch of two random images and show in a pop-up window.
batch_xs, batch_ys = mnist.test.next_batch(2)
gen_image(batch_xs[0]).show()
gen_image(batch_xs[1]).show()

The definition of mnist is at: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py

The tensorflow neural network that led me to the need to display the MNIST images is at: https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/examples/tutorials/mnist/mnist_deep.py

Since I have only been programming in Python for two hours, I might have made some newbie errors. Please feel free to correct them.

For those of you who want to do it with PIL.Image:

import numpy as np
import PIL.Image as pil
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('mnist')

testImage = (np.array(mnist.test.images[0], dtype='float')).reshape(28,28)

img = pil.fromarray(np.uint8(testImage * 255), 'L')
img.show()