TF LSTM: Save State from training session for prediction session later

Submitted by ≡放荡痞女 on 2019-12-02 05:09:28

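The issue is that a tf.Variable created after the Saver has been constructed is unknown to that Saver. With the default var_list=None, the Saver collects the variables that exist when it is built, which happens inside its constructor. The new variable is still recorded in the metagraph, but its value is not saved in the checkpoint: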

```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  saver = tf.train.Saver()  # var_list=None: collects the variables that exist right now
  var_b = tf.get_variable("b", shape=[])  # created too late; this Saver never sees it
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  initializer = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run(initializer)
    saver.save(session, "/tmp/model", global_step=0)

with tf.Graph().as_default():
  # The metagraph still defines both "a" and "b"...
  new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
  # ...but the checkpoint only holds a value for "a".
  print(tf.train.list_variables("/tmp/model-0")) # [('a', [])]
  with tf.Session() as session:
    new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored; "b" stays uninitialized!
```

I've annotated the quick reproduction of your issue above with the variables that the Saver knows about.

Now, the solution is relatively easy. I would suggest creating the Variable before the Saver, then using tf.assign to update its value (make sure you run the op returned by tf.assign; creating the op alone does not change the variable). The assigned value will then be saved in checkpoints and restored just like any other variable.
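A minimal sketch of this approach for the LSTM case, assuming the TF1 graph-mode API accessed through `tensorflow.compat.v1` (so it should also run on TF 2.x). The names `lstm_state` and `new_state`, the `[1, 4]` state shape, and the temporary checkpoint path are hypothetical; in a real model you would feed the LSTM's final state into the assign op instead of a constant array.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")  # hypothetical checkpoint location
STATE_SHAPE = [1, 4]  # hypothetical LSTM state shape (batch, units)

# Training session: create the state variable BEFORE the Saver.
with tf.Graph().as_default():
  state_var = tf.get_variable(
      "lstm_state", shape=STATE_SHAPE,
      initializer=tf.zeros_initializer(), trainable=False)
  # Hypothetical placeholder standing in for the LSTM's final state tensor.
  new_state = tf.placeholder(tf.float32, shape=STATE_SHAPE)
  assign_op = tf.assign(state_var, new_state)  # does nothing until it is run
  saver = tf.train.Saver()  # built after state_var exists, so it includes it
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    final_state = np.arange(4, dtype=np.float32).reshape(STATE_SHAPE)
    session.run(assign_op, feed_dict={new_state: final_state})  # actually run the assign
    saved_path = saver.save(session, ckpt_prefix, global_step=0)

# Prediction session: rebuild the graph from the metagraph and restore.
with tf.Graph().as_default():
  restorer = tf.train.import_meta_graph(saved_path + ".meta")
  with tf.Session() as session:
    restorer.restore(session, saved_path)
    (restored_var,) = [v for v in tf.global_variables()
                       if v.op.name == "lstm_state"]
    restored_state = session.run(restored_var)

print(restored_state)  # [[0. 1. 2. 3.]]
```

Creating `lstm_state` before `tf.train.Saver()` is what puts it in the Saver's variable list; `trainable=False` keeps optimizers from treating the stored state as a model weight.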

The Saver could handle this better as a special case when var_list is None (the default): instead of fixing its variable list at construction time, it could pick up new variables automatically. Feel free to open a feature request on GitHub for this.
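One workaround, not part of the original answer, is to pass `defer_build=True` to `tf.train.Saver`, which postpones resolving `var_list=None` until you call `build()`. A sketch, again assuming `tensorflow.compat.v1` and a hypothetical temporary checkpoint path; variables created after `build()` would still be missed, so this only moves the snapshot later rather than making it automatic.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")  # hypothetical path

with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  # defer_build=True: the Saver does not collect its variables yet.
  saver = tf.train.Saver(defer_build=True)
  var_b = tf.get_variable("b", shape=[])  # created after the Saver object
  saver.build()  # var_list=None is resolved here, so "b" is included
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    saved_path = saver.save(session, ckpt_prefix, global_step=0)

# list_variables returns (name, shape) pairs sorted by name.
saved_names = [name for name, _ in tf.train.list_variables(saved_path)]
print(saved_names)  # ['a', 'b']
```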
