Tensorflow中保存与恢复模型tf.train.Saver类讲解(恢复部分模型参数的方法)

匿名 (未验证) 提交于 2019-12-03 00:22:01




在做这些之前,先对Saver类说明一下,其中有一个很重要的点要get到:


这个是官网的一个例子,请看下面这一句:

saver = tf.train.Saver(...variables...)

其中这个Saver是一个类,上面的那一句就是通过类取得Saver的对象,里面的variables是构造函数传入的参数,请看这个构造函数对这个参数的解释:

__init__

__init__是构造器,里面可以传很多参数,其中第一个参数就是var_list,也就是上面的variables.

下面是对var_list参数的解释:

Saver.

The constructor adds ops to save and restore variables.

var_listdict

  • dict
  • A list of variables: The variables will be keyed with their op name in the checkpoint files.

注意到红字所表达的意思:




所以里面传的参数是要保存和恢复的变量,举个例子说明问题:

保存参数:

weight=[weights['wc1'],weights['wc2'],weights['wc3a']] saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值 saver.save(sess,'model.ckpt')

上面的意思是,只保存weight里的这些变量,如果saver=tf.train.Saver()里面不传入参数,默认保存全部变量

恢复参数:

weight=[weights['wc1'],weights['wc2'],weights['wc3a']] saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值 saver.restore(sess, model_filename)
上面这个恢复参数要注意,model_filename是你要恢复的模型文件,整段代码的意思是从model_filename文件里只恢复weight的这些参数,如果model_filename里面没有这些参数,则报错。(当然这些变量你不一定都一一列出,你可以通过遍历的算法得到,详细请看下面的参考文献)




像我的这种情况应该怎么恢复变量呢,也是分为两步:

一,恢复部分预训练模型的参数。

weight=[weights['wc1'],weights['wc2'],weights['wc3a']] saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值 saver.restore(sess, model_filename)

二,手动初始化剩下的(预训练模型中没有的)参数。

var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())


保存的时候怎么保存呢?我想保存全部变量,所以要重新写一个对象,名字和恢复的那个saver对象不同:

saver_out=tf.train.Saver() saver_out.save(sess,'file_name')

这个时候就保存了全部变量,如果你想保存部分变量,只需要在构造器里传入想要保存的变量的名字就行了。


更多关于变量恢复的文件类型问题,请参考:

1.https://blog.csdn.net/leo_xu06/article/details/79200634

2.https://blog.csdn.net/b876144622/article/details/79962727




易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!