Tensorflow: how to use pretrained weights in new graph?

感动是毒 2021-01-05 14:24

I'm trying to build an object detector with a CNN using TensorFlow in Python. I would like to train my model to do just object recognition (classification) at first …

2 Answers
  •  夕颜 (OP)
    2021-01-05 14:40

    Create a tf.train.Saver() with no arguments to save every variable in the graph, i.e. the entire model.

    import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs under TF 2.x
    tf.disable_v2_behavior()

    tf.reset_default_graph()
    v1 = tf.get_variable("v1", [3], initializer=tf.random_normal_initializer())
    v2 = tf.get_variable("v2", [2], initializer=tf.random_normal_initializer())
    saver = tf.train.Saver()  # no var_list: saves every global variable

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_path='./test-case.ckpt')

        print(v1.eval())
        print(v2.eval())
    saver = None
    
    Output (the values are random):
    
    v1 = [ 2.1882825   1.159807   -0.26564872]
    v2 = [0.11437789 0.5742971 ]
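Before restoring, it can help to check which names the checkpoint actually contains, because those names are what a Saver's var_list has to match. The sketch below assumes the tf.compat.v1 API (TF 1.15 or 2.x). It writes a small checkpoint to a temporary directory in place of ./test-case.ckpt and uses a hypothetical helper, checkpoint_names, that wraps tf.train.list_variables.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def checkpoint_names(ckpt_path):
    """Hypothetical helper: map each tensor name stored in a checkpoint to its shape."""
    # tf.train.list_variables reads the checkpoint files directly;
    # it returns (name, shape) pairs and needs no graph or session.
    return dict(tf.train.list_variables(ckpt_path))


# Write a small checkpoint similar to the one in the answer (temporary path assumed).
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer=tf.random_normal_initializer())
v2 = tf.get_variable("v2", [2], initializer=tf.random_normal_initializer())
ckpt = os.path.join(tempfile.mkdtemp(), "test-case.ckpt")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt)

names = checkpoint_names(ckpt)
print(names)  # one entry per saved variable: 'v1' and 'v2' with their shapes
```

The keys printed here ("v1", "v2") are the strings to use on the left-hand side of a {"name in checkpoint": variable} mapping.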
    

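    Then, in the model that should start from those values, pass the Saver only the variables you want to restore. Pass either a list of variables, where each is matched to the checkpoint entry with the same name, or a dictionary of {"name in checkpoint": variable}, which loads a saved value into a variable with a different name.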

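    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    tf.reset_default_graph()
    b1 = tf.get_variable("b1", [3], initializer=tf.random_normal_initializer())
    b2 = tf.get_variable("b2", [3], initializer=tf.random_normal_initializer())
    # Restore the checkpoint entry named "v1" into the new variable b1.
    saver = tf.train.Saver(var_list={'v1': b1})

    with tf.Session() as sess:
        saver.restore(sess, "./test-case.ckpt")
        print(b1.eval())
        # b2 was neither restored nor initialized, so this raises FailedPreconditionError.
        print(b2.eval())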
    
    Output:
    
    INFO:tensorflow:Restoring parameters from ./test-case.ckpt
    b1 = [ 2.1882825   1.159807   -0.26564872]
    b2 → raises FailedPreconditionError: Attempting to use uninitialized value b2
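The error for b2 is expected: a restoring Saver only sets the variables in its var_list. In practice, when only part of a new model (for example a pretrained classifier backbone) comes from a checkpoint, a common pattern is to run the initializer for all variables first and then restore the pretrained subset, so the restore overwrites the random values. The sketch below shows that order under the tf.compat.v1 API. The temporary checkpoint path and the variable names are illustrative assumptions.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), "pretrained.ckpt")  # illustrative path

# 1) Stand-in for the pretrained model: save a single variable v1.
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer=tf.random_normal_initializer())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_v1 = sess.run(v1)
    tf.train.Saver().save(sess, ckpt)

# 2) New graph: b1 should reuse v1's weights, b2 is a freshly initialized layer.
tf.reset_default_graph()
b1 = tf.get_variable("b1", [3], initializer=tf.random_normal_initializer())
b2 = tf.get_variable("b2", [3], initializer=tf.random_normal_initializer())
restorer = tf.train.Saver(var_list={"v1": b1})

with tf.Session() as sess:
    # Initialize everything first so b2 gets a value...
    sess.run(tf.global_variables_initializer())
    # ...then overwrite b1 with the checkpointed v1 (order matters:
    # running the initializer after restore would discard the loaded values).
    restorer.restore(sess, ckpt)
    restored_b1, fresh_b2 = sess.run([b1, b2])

print(restored_b1)  # same values as saved_v1
print(fresh_b2)     # random values from b2's own initializer
```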
    
