Tensorflow define · Dyting's Blog

故事扮演 提交于 2019-11-28 16:19:35

Tensorflow 概念

Tensor

Tensor是TensorFlow中主要的数据结构,是一个多维数组。例如可以讲一小组图像集表示成一个四维的浮点数数组,这四个维度分别是[batch,height,width,channels].

创建tensor有两种方式,一是直接用tensorflow自带的函数创建,二是用Python的numpy库创建。
第一种如下:

12
import tensorflow as tftf.zeros([row_dim,col_dim])

第二种方式:

123
import numpy as npx_data=np.array([[1,2,3],[4,5,6]])tf.convert_to_tensor(x_data,dtype=tf.float32)

graph图

一个完整的TF代码主要分成2个部分,定义(构建)和执行。通常在构建阶段创建一个图表示和训练神经网络。
如下图所示:


Tensor在一个或者多个由节点和边组成的图中流动,边代表tensors,节点代表对tensors的操作。
如图,节点a接收了一个1-D tensor,该tensor从节点a流出后,分别流向节点b和c。。
一旦开始一个任务,一个默认的图已经创建好了,可以通过调用tf.get_default_graph)()来访问。在默认的图里面添加操作,如下例子所示:
12345
import tensorflow as tfimport numpy as npc=tf.constant(value=1)print(c.graph)print(tf.get_default_graph)

另一个用法就是可以创建图覆盖默认图:

12345
import tensorflow as tfimport numpy as npwith tf.Graph.as_default() as g:     d=tf.constant(value=1)     print(d.graph)

Session 会话

TF阶段分为两个构建和执行,构建是建造图,执行就是启动图,启动图的第一步就是创建一盒Session对象。

1234567
#启动默认图sess=tf.Session()........result=sess.run(d)#结束关闭sess.close()

可以用with代码自动关闭进程。

12
with tf.Session() as sess:     result=sess.run([product])

Session可以交互使用,避免一个变量持有会话。
用InteractiveSession代替Session,使用Tensor.eval()和Operation.run()方法代替Session.run().

12345678910
import tensorflow as tfsess=tf.InteractiveSession()x=tf.Variable([1,2])a=tf.constant([3,3])x.initializer.run()#变量必须初始化才能使用。sub=tf.sub(x,a)print sub.eval()

在执行run之前操作都不会被真正的执行
Session.run()方法有两个参数,分别是fetches和feed_dict. 传递给fetches的参数既可以是tensor也可以是operation,也可以是list。feed_dict 作用是替换图中某个tensor的值。

123456789101112131415161718
sess.run(fetches=d)sess.run(d)a=tf.add(2,5)b=tf.multiply(a,3)sess=tf.Session()sess.run(b)replace_dict={a:15}sess.run(b,feed_dict=replace_dict)#结果是45,直接提供a=15,不经过2+5#feed_dict用来设置graph的输入值input1=tf.placeholder(tf.float32)input2=tf.placeholder(tf.float32)output=tf.mul(input1,input2)with tf.Session() as sess:     print sess.run([output],feed_dict={input1:[7.],input2:[2.])

placeholder

上面提到placeholder,input1和input2不是 tensor而是placeholder,没有具体的值,在后面的使用和一般的tensor一样,不过在运行的时候,需要用feed_dict 把具体的值提供给placeholder。

Variable 变量

可以改变自身状态的值,但是必须要在Session中初始化才能显示。

12345678910
my_var=tf.Variable(1)my_var_times_two=my_var.assign(my_var*2)#赋值#初始化操作init=tf.global_variable_initializer()sess=tf.Session()#初始化变量sess.run(init)print sess.run(my_var_times_two)

TensorFlow-进阶

保存和加载模型

保存

利用Saver类实现,它处理图中数据的保存和恢复,我们需要做的就是告诉Saver类我们需要保存哪个图和哪些变量。默认情况下,Sever处理默认图中所有变量,但是,也可以创建更多的图来保存任何想要的子图。

12345678910111213141516
import tensorflow as tfv1=tf.Variable(1,name='v1')v2=tf.Variable(2,name='v2')a=tf.add(v1,v2)all_saver=tf.train.Saver()#保存所有的变量v2_saver=tf.tf.train.Saver("v2":v2)#保存想要保存的变量with tf.Sessin() as sess:     #初始化所有变量     sess.run(tf.global_variables_initializer())     #保存变量     all_saver.save(sess,'data.chpk')          v2_sever.save(sess,'data-v2.chkp')

提取

1234567891011121314
# Create some variables.v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess:  # Restore variables from disk.  saver.restore(sess, "/tmp/model.ckpt")  print "Model restored."  # Do some work with the model

checkpoint文档就是用来保存变量的。

原文链接 大专栏  https://www.dazhuanlan.com/2019/08/24/5d6119336bc29/

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!