tensorflow2.0手写数字识别

旧时模样 提交于 2019-12-01 16:02:06
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


datapath  = r'D:\data\ml\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(datapath)

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3)


val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)

i = 103
plt.imshow(x_test[i],cmap=plt.cm.binary)
plt.show()

predictions = model.predict(x_test)
print(np.argmax(predictions[i]))

其中mnist.npz文件可以从google下载 

https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!