How to restore variables using CheckpointReader in Tensorflow

Backend · Unresolved · 3 answers · 862 views
谎友^ 2020-12-06 08:42

I'm trying to restore some variables from a checkpoint file if the same variable name is in the current model. And I found that there is a way to do this, as shown in the TensorFlow GitHub.

3 Answers
  • 2020-12-06 08:57

    You could use string.split() to get the tensor name:

    ...    
    reader = tf.train.NewCheckpointReader(ckpt_path)
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        print(tensor_name)
        if reader.has_tensor(tensor_name):
            print('has tensor')
    ...
    
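    The reason for the split is that variable names in the graph carry an output-index suffix like `:0`, while checkpoint keys do not. A TensorFlow-free sketch of that parsing (the second name is just an illustrative example):

```python
# Graph variable names look like 'v1:0'; checkpoint keys look like 'v1'.
# Stripping everything after the ':' recovers the checkpoint key.
def to_checkpoint_key(variable_name):
    return variable_name.split(':')[0]

names = ['v1:0', 'fire9/squeeze1x1/kernels:0']
keys = [to_checkpoint_key(n) for n in names]
print(keys)  # ['v1', 'fire9/squeeze1x1/kernels']
```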

    Next, let me use an example to show how I would restore every possible variable from a .ckpt file. First, let's save v2 and v3 in tmp.ckpt:

    import tensorflow as tf
    
    v1 = tf.Variable(tf.ones([1]), name='v1')
    v2 = tf.Variable(2 * tf.ones([1]), name='v2')
    v3 = tf.Variable(3 * tf.ones([1]), name='v3')
    
    saver = tf.train.Saver({'v2': v2, 'v3': v3})
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
        saver.save(sess, 'tmp.ckpt')
    

    Here is how I would restore every variable (belonging to a new graph) that shows up in tmp.ckpt:

    with tf.Graph().as_default():
        assert len(tf.trainable_variables()) == 0
        v1 = tf.Variable(tf.zeros([1]), name='v1')
        v2 = tf.Variable(tf.zeros([1]), name='v2')
    
        reader = tf.train.NewCheckpointReader('tmp.ckpt')
        restore_dict = dict()
        for v in tf.trainable_variables():
            tensor_name = v.name.split(':')[0]
            if reader.has_tensor(tensor_name):
                print('has tensor ', tensor_name)
                restore_dict[tensor_name] = v
    
        saver = tf.train.Saver(restore_dict)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
            saver.restore(sess, 'tmp.ckpt')
            print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]
    

    Also, you may want to ensure that shapes and dtypes match.
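    That shape check can be sketched without TensorFlow. Below, plain dicts stand in for the map returned by `reader.get_variable_to_shape_map()` and for the shapes of the current model's trainable variables (all names and shapes here are made up for illustration):

```python
# Stand-in for reader.get_variable_to_shape_map() -- what the checkpoint holds.
saved_shapes = {'v2': [1], 'v3': [1]}
# Stand-in for the shapes of tf.trainable_variables() in the new graph.
model_shapes = {'v1': [1], 'v2': [1], 'v3': [2]}

# Restore only variables whose name AND shape both match the checkpoint.
restore_names = [
    name for name, shape in model_shapes.items()
    if name in saved_shapes and saved_shapes[name] == shape
]
print(sorted(restore_names))  # ['v2']  (v1 missing from ckpt, v3 shape mismatch)
```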

  • 2020-12-06 09:08

    tf.train.NewCheckpointReader is a nifty method that creates a CheckpointReader object. CheckpointReader has several very useful methods. The method that would be the most relevant to your question would be get_variable_to_shape_map().

    • get_variable_to_shape_map() provides a dictionary with variable names and shapes:

    saved_shapes = reader.get_variable_to_shape_map()
    print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])

    Please take a look at this quick tutorial: Loading Variables from Existing Checkpoints

  • 2020-12-06 09:09

    Simple answer:

    reader = tf.train.NewCheckpointReader(checkpoint_file)
    
    variable1 = reader.get_tensor('layer_name1/layer_type_name')
    variable2 = reader.get_tensor('layer_name2/layer_type_name')
    
    

    Now, after modifying these variables, you can assign them back (here `layer_name1_var` is assumed to be a Keras layer):

    layer_name1_var.set_weights([variable1, variable2])
    
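    Note that `get_tensor()` returns plain NumPy arrays, so any NumPy transformation can be applied before pushing the values back with `set_weights()`. A small sketch (the array values here are hypothetical stand-ins, not real checkpoint contents):

```python
import numpy as np

# Stand-in for reader.get_tensor('layer_name1/layer_type_name'):
# get_tensor() returns an ordinary NumPy array.
variable1 = np.ones((2, 2), dtype=np.float32)

# Any NumPy op works, e.g. scaling a kernel before set_weights().
scaled = variable1 * 0.5
print(scaled.sum())  # 2.0
```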