How to get the global_step when restoring checkpoints in Tensorflow?


Question:

I'm saving my session state like so:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)

When I later restore, I want to get the value of global_step for the checkpoint I restore from, in order to set some hyperparameters from it.

The hacky way to do this would be to run through and parse the file names in the checkpoint directory. But surely there has to be a better, built-in way to do this?

Answer 1:

The general pattern is to have a global_step variable to keep track of steps:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

Then you can save with

saver.save(sess, save_path, global_step=global_step) 

When you restore, the value of global_step is restored as well.
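For completeness, a minimal restore sketch (here checkpoint_dir is a placeholder for wherever the checkpoints were written, and global_step is the variable defined above):

# A sketch, assuming global_step is defined in the graph as above and
# checkpoint_dir is the directory the checkpoints were saved to
ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    # global_step comes back with the value it had at save time
    print('Restored at step %d' % sess.run(global_step))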



Answer 2:

This is a bit of a hack, but the other answers did not work for me at all:

import os

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
# Extract the step from the checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
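Here ckpt.model_checkpoint_path is the path of the most recent checkpoint, e.g. checkpoint_dir/model.ckpt-1000, so splitting on '-' recovers 1000 as the step.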

Update 9/2017

I'm not sure if this started working due to updates, but the following method seems to be effective in getting global_step to update and load properly:

Create a variable to hold global_step and an op to increment it:

global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step, 1, name='increment_global_step')

Now in your training loop run the increment op every time you run your training op.

sess.run([train_op, increment_global_step], feed_dict=feed_dict)

If you ever want to retrieve your global step value as an integer at any point, just use the following command after loading the model:

sess.run(global_step) 

This can be useful for creating filenames or calculating your current epoch without having a second TensorFlow Variable to hold that value. For instance, calculating the current epoch on loading would be something like:

# epochs completed = (steps taken * batch_size) / number of training records
loaded_epoch = sess.run(global_step) * batch_size // num_train_records
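To make the arithmetic concrete with hypothetical numbers: with batch_size = 100 and num_train_records = 50000, a restored global_step of 25000 gives 25000 * 100 // 50000 = 50 completed epochs.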


Answer 3:

I had the same issue as Lawrence Du: I could not find a way to get the global_step by restoring the model. So I applied his hack to the Inception v3 training code in the TensorFlow/models GitHub repo that I'm using. The code below also contains a fix related to pretrained_model_checkpoint_path.

If you have a better solution, or know what I'm missing please leave a comment!

In any case, this code works for me:

...

# When not restoring, start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
    # A model consists of three files; use the base name of the model in
    # the checkpoint path, e.g. my-model-path/model.ckpt-291500
    #
    # Because we need to give the base name, you can't assert (it will always fail):
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)

    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # HACK: the global step is not restored for some unknown reason,
    # so recover it from the checkpoint filename instead
    last_step = int(os.path.basename(
        FLAGS.pretrained_model_checkpoint_path).split('-')[1])

    # assign it to the global step variable
    sess.run(global_step.assign(last_step))

...

for step in range(last_step + 1, FLAGS.max_steps):
    ...


Answer 4:

The current 0.10rc0 version seems to be different: there's no tf.saver() any more; now it's tf.train.Saver(). Also, the save command appends the global_step to the save_path filename, so we can't just call restore on the same save_path, since that is not the actual save file.

The easiest way I see right now is to use the SessionManager along with a saver like this:

my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# Try to find the latest checkpoint in my_checkpoint_dir and create a session
# with it restored; if there is no such checkpoint, create a new session and
# run init_op instead.
sess = sm.prepare_session("", init_op=init, saver=saver,
                          checkpoint_dir=my_checkpoint_dir)

That's it. Now you have a session that's either restored from my_checkpoint_dir (make sure that directory exists before calling this), or, if there's no checkpoint there, a new session whose variables were initialized by the init_op.

When you want to save, you just save to any name you want in that directory and pass the global_step in. Here's an example where I save the step variable of a loop as the global_step, so that if you kill the program and restart it, the restored checkpoint brings you back to that point:

checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)

This creates files in my_checkpoint_dir like "model.ckpt-1000", where 1000 is the global_step passed in. If it keeps running, you get more of them, like "model.ckpt-2000". The SessionManager above picks up the latest of these when the program is restarted. The checkpoint_path can be whatever file name you want, as long as it's in the checkpoint_dir; save() will create that file with the global_step appended (as shown above). It also creates a "checkpoint" index file, which is how the SessionManager then finds the latest saved checkpoint.
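Putting it together, here is a rough sketch of resuming a training loop after a restart (global_step, train_op, and max_steps are assumed to be defined as in the earlier answers; sess comes from sm.prepare_session() above):

# Sketch: global_step must be wired into the graph, e.g. via
# optimizer.minimize(loss, global_step=global_step) as in Answer 1
start_step = sess.run(global_step)  # 0 for a fresh session, else the restored value
for step in range(start_step, max_steps):
    sess.run(train_op)  # each training step increments global_step
    if (step + 1) % 1000 == 0:
        saver.save(sess, checkpoint_path, global_step=global_step)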



Answer 5:

Just a note on my solution for saving and restoring the global step.

Save:

global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=global_step)

Restore:

if os.path.exists(model_path):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print("Model restore finished, current global step: %d" % global_step.eval())
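Note that global_step.eval() only works if there is a default session (for example, inside a with tf.Session() as sess: block); otherwise pass the session explicitly as global_step.eval(session=sess).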

