初识GAN之MNIST手写数字的识别

南笙酒味 提交于 2019-12-05 10:38:54

初识GAN,因为刚好在尝试用纯python实现手写数字的识别,所以在这里也尝试了一下。笔者也是根据网上教程一步步来的,不多说了,代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)#下载文件在MNIST_data文件夹中
sess=tf.InteractiveSession()

def conv2d(x,w):
    return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME') 

def max_pool_2x2(x):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')

x=tf.placeholder(tf.float32,[None,784])
y_=tf.placeholder(tf.float32,[None,10])
x_img=tf.reshape(x,[-1,28,28,1])

#第一个卷积层和池化层
w_conv1=tf.Variable(tf.truncated_normal([3,3,1,32],stddev=0.1))# 生成矩阵,矩阵中元素是均值为0,标准差为0.1的随机数,权值使用方差为0.1的截断正态分布(指最大值不超过方差两倍的分布)来初始化,偏置的初值设定为常值0.1。
b_conv1=tf.Variable(tf.constant(0.1,shape=[32])) #创建一个常数张量,传入list或者数值来填充
h_conv1=tf.nn.relu(conv2d(x_img,w_conv1)+b_conv1) 
h_pool1=max_pool_2x2(h_conv1)

#第二个卷积层和池化层
w_conv2=tf.Variable(tf.truncated_normal([3,3,32,50],stddev=0.1)) 
b_conv2=tf.Variable(tf.constant(0.1,shape=[50])) 
h_conv2=tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2) 
h_pool2=max_pool_2x2(h_conv2)

#第一个全连接层
w_fc1=tf.Variable(tf.truncated_normal([7*7*50,1024],stddev=0.1))
b_fc1=tf.Variable(tf.constant(0.1,shape=[1024])) 
h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*50]) 
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)

#dropout(随机权重失活)
keep_prob=tf.placeholder(tf.float32)
h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)

#第二个全连接层
w_fc2=tf.Variable(tf.truncated_normal([1024,10],stddev=0.1)) 
b_fc2=tf.Variable(tf.constant(0.1,shape=[10])) 
y_out=tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)

#loss function,交叉熵,配置Adam优化器,学习速率:1e-4
loss=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_out),reduction_indices=[1]))
train_step=tf.train.AdamOptimizer(1e-4).minimize(loss)

#建立正确率计算表达式
correct_prediction=tf.equal(tf.argmax(y_out,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#测试
tf.global_variables_initializer().run() 
for i in range(20000): 
    batch=mnist.train.next_batch(50) 
    if i%100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1})
        print("step",i, ",train_accuracy",train_accuracy)

print("test_accuracy=",accuracy.eval(feed_dict={x:mnist.test.images,y_:mnist.test.labels,keep_prob:1}))
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!