TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  庸人自扰
    2020-11-21 12:23

    You can also use this simpler approach.

    Step 1: define your variables, giving each one an explicit name

    import tensorflow as tf
    # K = number of output channels (filters), defined elsewhere in your code
    W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
    B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
    
    Define W2, B2, W3, ... the same way.
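    A minimal, self-contained sketch of Step 1, assuming TensorFlow 1.15 or 2.x run as TF1-style graph code through `tf.compat.v1`; `K = 4` is a placeholder chosen only for illustration. It shows that `name=` fixes the op name ("W1"), which is the key `tf.train.Saver` stores the variable under by default.

```python
import tensorflow as tf

# Assumption: run TF1-style graph code (works on TF 1.15 and TF 2.x)
tf.compat.v1.disable_v2_behavior()

K = 4  # hypothetical number of output channels, for illustration only

graph_cnn = tf.Graph()
with graph_cnn.as_default():
    # name= sets the op name; Saver uses it as this variable's checkpoint key
    W1 = tf.compat.v1.Variable(
        tf.compat.v1.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
    B1 = tf.compat.v1.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

print(W1.op.name, W1.name)  # W1 W1:0
print(B1.op.name, B1.name)  # B1 B1:0
```

    Without explicit names, TensorFlow would generate names such as "Variable" and "Variable_1", which are harder to look up after restoring.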
    

    Step 2: create a Saver and use it to save the session's variables

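    # Create the Saver after all variables are defined; by default it saves all of them
    model_saver = tf.train.Saver()
    
    # Train the model in `session`, then save the variables at the end
    model_saver.save(session, "saved_models/CNN_New.ckpt")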
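    To see what Step 2 actually writes, here is a sketch under the same assumption (TF1-style code via `tf.compat.v1`); a temporary directory stands in for `saved_models/` and no real training is done. `Saver.save()` returns the checkpoint prefix, and `tf.train.list_variables()` lists the keys and shapes stored under it.

```python
import os
import tempfile

import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # assumption: TF1-style graph mode

K = 4  # hypothetical number of output channels
save_dir = tempfile.mkdtemp()  # stands in for "saved_models/"

graph_cnn = tf.Graph()
with graph_cnn.as_default():
    W1 = tf.compat.v1.Variable(
        tf.compat.v1.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
    B1 = tf.compat.v1.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
    init_op = tf.compat.v1.global_variables_initializer()
    # Build the Saver after every variable exists; by default it covers all of them
    model_saver = tf.compat.v1.train.Saver()

with tf.compat.v1.Session(graph=graph_cnn) as session:
    session.run(init_op)
    # (training steps would run here)
    ckpt_prefix = model_saver.save(session, os.path.join(save_dir, "CNN_New.ckpt"))

saved_files = sorted(os.listdir(save_dir))
saved_keys = dict(tf.train.list_variables(ckpt_prefix))

print(saved_files)  # .index, .data-*, .meta files plus a "checkpoint" state file
print(saved_keys)   # variable name -> shape
```

    The `.meta` file holds the graph structure, the `.index`/`.data-*` pair holds the variable values, and `checkpoint` records the most recent prefix for `tf.train.latest_checkpoint()`.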
    

    Step 3: restore the model

    with tf.Session(graph=graph_cnn) as session:
        # restore() assigns the saved values, so no initializer is needed
        model_saver.restore(session, "saved_models/CNN_New.ckpt")
        print("Model restored.")
    

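    Step 4: check a restored variable

    with tf.Session(graph=graph_cnn) as session:
        model_saver.restore(session, "saved_models/CNN_New.ckpt")
        # Evaluate W1 while the session is still open
        W1_value = session.run(W1)
        print(W1_value)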
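    Steps 3 and 4 can be checked end to end with a sketch like the following, under the same assumptions (TF1-style code through `tf.compat.v1`, a temporary directory instead of `saved_models/`). The second session never runs an initializer; `restore()` alone gives `W1` its value, and the restored array matches the saved one.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # assumption: TF1-style graph mode

save_dir = tempfile.mkdtemp()  # stands in for "saved_models/"

graph_cnn = tf.Graph()
with graph_cnn.as_default():
    W1 = tf.compat.v1.Variable(
        tf.compat.v1.truncated_normal([6, 6, 1, 4], stddev=0.1), name="W1")
    init_op = tf.compat.v1.global_variables_initializer()
    model_saver = tf.compat.v1.train.Saver()

# Session 1: initialize (and, in practice, train), then save
with tf.compat.v1.Session(graph=graph_cnn) as session:
    session.run(init_op)
    saved_W1 = session.run(W1)
    ckpt_prefix = model_saver.save(session, os.path.join(save_dir, "CNN_New.ckpt"))

# Session 2: a fresh session whose variables start uninitialized
with tf.compat.v1.Session(graph=graph_cnn) as session:
    model_saver.restore(session, ckpt_prefix)
    restored_W1 = session.run(W1)  # read while the session is still open

print(np.array_equal(saved_W1, restored_W1))  # True
```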
    

    When restoring in a different Python process (where the graph-building code has not been run), use:

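    import tensorflow as tf
    
    with tf.Session() as sess:
        # Find the latest checkpoint prefix in the save directory
        ckpt = tf.train.latest_checkpoint('saved_models/')
    
        # Rebuild the graph from the .meta file written by Saver.save();
        # this also returns a Saver for the restored variables
        saver = tf.train.import_meta_graph(ckpt + '.meta')
        saver.restore(sess, ckpt)
    
        # Do not run tf.global_variables_initializer() here: it would
        # overwrite the restored values with fresh initial values
    
        # Get default graph (supply your custom graph if you have one)
        graph = tf.get_default_graph()
    
        # Returns the tensor for the variable named "W1"
        W1 = graph.get_tensor_by_name('W1:0')
    
        # Evaluate it to get the value as a NumPy array
        W1_value = sess.run(W1)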
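    To mimic a separate Python process inside one script, the sketch below (same `tf.compat.v1` assumption, a temporary directory in place of `saved_models/`) saves from one graph and restores into a brand-new, empty graph. `tf.compat.v1.train.import_meta_graph()` rebuilds the graph from the `.meta` file and returns a Saver, so no model-building code runs on the restore side; `W1:0` is then fetched by name. It relies on TF1-style (non-resource) variables, which is why `disable_v2_behavior()` is called: with TF2 resource variables, `W1:0` would refer to a resource handle rather than the stored values.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # assumption: TF1-style graph mode, ref variables

save_dir = tempfile.mkdtemp()  # stands in for "saved_models/"

# --- "Training" side: build, initialize, save ---
train_graph = tf.Graph()
with train_graph.as_default():
    W1 = tf.compat.v1.Variable(
        tf.compat.v1.truncated_normal([6, 6, 1, 4], stddev=0.1), name="W1")
    init_op = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()

with tf.compat.v1.Session(graph=train_graph) as sess:
    sess.run(init_op)
    expected_W1 = sess.run(W1)
    saver.save(sess, os.path.join(save_dir, "CNN_New.ckpt"))

# --- "Restoring" side: only the files on disk are reused ---
restore_graph = tf.Graph()
with restore_graph.as_default():
    ckpt_prefix = tf.train.latest_checkpoint(save_dir)
    # Rebuild the saved graph (and a matching Saver) from the .meta file
    restorer = tf.compat.v1.train.import_meta_graph(ckpt_prefix + ".meta")

with tf.compat.v1.Session(graph=restore_graph) as sess:
    restorer.restore(sess, ckpt_prefix)  # no initializer: it would overwrite these values
    W1_tensor = restore_graph.get_tensor_by_name("W1:0")
    W1_value = sess.run(W1_tensor)

print(np.array_equal(expected_W1, W1_value))  # True
```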
    
