Loading two models from Saver in the same TensorFlow session

后悔当初 2020-12-14 05:17

I have two networks: a Model which generates output and an Adversary which grades the output.

Both have been trained separately but now I n

3 Answers
  •  借酒劲吻你
    2020-12-14 05:41

    Solving this problem took a long time, so I'm posting my (likely imperfect) solution in case anyone else needs it.

    To diagnose the problem, I manually looped through the variables and assigned them one by one. That is when I noticed that a variable's name changed after it was assigned. This is described here: TensorFlow checkpoint save and read

    Based on the advice in that post, I built each model in its own graph. This in turn meant running each graph in its own session, so session management had to be handled differently.
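    One plausible mechanism behind the renaming: within a single TF1 graph, a reused scope or variable name typically gets a numeric suffix, while a Saver matches checkpoint entries to variables by name. The toy below is plain Python that imitates this, not TensorFlow code; the scope name "rnn", the variable names and the checkpoint values are hypothetical, and it assumes both networks were trained under the same scope.

```python
from collections import Counter


class ToyGraph:
    """Pure-Python toy of a TF1 graph's naming rules (NOT TensorFlow's API).

    Within one graph, reusing a scope name gets a numeric suffix
    ("rnn", then "rnn_1", ...), loosely imitating TF1 name scopes.
    """

    def __init__(self):
        self._scope_counts = Counter()
        self.variables = {}  # full variable name -> current value

    def build(self, scope, var_names):
        """Create variables under `scope`; return their full names."""
        count = self._scope_counts[scope]
        self._scope_counts[scope] += 1
        prefix = scope if count == 0 else f"{scope}_{count}"
        names = [f"{prefix}/{v}" for v in var_names]
        for name in names:
            self.variables[name] = 0.0  # stands in for the initialiser
        return names


def restore(graph, names, checkpoint):
    """Copy checkpoint values into `graph`, matching variables by name."""
    missing = [n for n in names if n not in checkpoint]
    if missing:
        raise KeyError(f"not found in checkpoint: {missing}")
    for name in names:
        graph.variables[name] = checkpoint[name]


# Hypothetical checkpoints: both networks were trained under scope "rnn".
model_ckpt = {"rnn/weights": 1.0, "rnn/bias": 2.0}
adv_ckpt = {"rnn/weights": 3.0, "rnn/bias": 4.0}
VARS = ("weights", "bias")

# One shared graph: the adversary's scope is renamed to "rnn_1",
# so its names no longer match the keys stored in adv_ckpt.
shared = ToyGraph()
model_vars = shared.build("rnn", VARS)
adv_vars = shared.build("rnn", VARS)
restore(shared, model_vars, model_ckpt)
try:
    restore(shared, adv_vars, adv_ckpt)
    shared_ok = True
except KeyError:
    shared_ok = False

# One graph per network: each keeps the names it had at training time.
model_graph, adv_graph = ToyGraph(), ToyGraph()
restore(model_graph, model_graph.build("rnn", VARS), model_ckpt)
restore(adv_graph, adv_graph.build("rnn", VARS), adv_ckpt)
```

    In the toy, the adversary's restore fails in the shared graph but succeeds once each network has its own graph, because a fresh graph starts with no names in use.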

    First, I created two graphs:

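    import tensorflow as tf  # TensorFlow 1.x API

    # Build each network in its own graph so their variable names cannot collide.
    model_graph = tf.Graph()
    with model_graph.as_default():
        model = Model(args)

    adv_graph = tf.Graph()
    with adv_graph.as_default():
        adversary = Adversary(adv_args)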
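    The two graph-building blocks follow the same pattern, so they could be folded into one function. A sketch, where build_in_new_graph is a hypothetical name and the graph class is passed in (tf.Graph in the code above) so the helper itself needs no TensorFlow import:

```python
def build_in_new_graph(graph_factory, builder, *args, **kwargs):
    """Create a fresh graph and build one network inside it.

    graph_factory: a zero-argument callable returning a graph, e.g. tf.Graph.
    builder: the network constructor, e.g. Model or Adversary.
    Returns (graph, network).
    """
    graph = graph_factory()
    # In TF1, ops and variables created while `graph` is the default
    # graph are added to it, so this network gets its own namespace.
    with graph.as_default():
        network = builder(*args, **kwargs)
    return graph, network
```

    With it, model_graph, model = build_in_new_graph(tf.Graph, Model, args) and adv_graph, adversary = build_in_new_graph(tf.Graph, Adversary, adv_args) would reproduce the two blocks above.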
    

    Then I created two sessions, one per graph:

    adv_sess = tf.Session(graph=adv_graph)
    sess = tf.Session(graph=model_graph)
    

    Next, I initialised the variables in each session and restored each graph separately from its own checkpoint:

    with sess.as_default():
        with model_graph.as_default():
            # Optional here: restore() below overwrites every variable this Saver covers.
            tf.global_variables_initializer().run()
            # Created inside model_graph, so this Saver only sees the Model's variables.
            model_saver = tf.train.Saver(tf.global_variables())
            model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
            model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    with adv_sess.as_default():
        with adv_graph.as_default():
            tf.global_variables_initializer().run()
            # Likewise, this Saver only sees the Adversary's variables.
            adv_saver = tf.train.Saver(tf.global_variables())
            adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
            adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
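    If a save directory holds no checkpoint, tf.train.get_checkpoint_state returns None, and the .model_checkpoint_path lookup then fails with an AttributeError that does not say which directory was wrong. A small guard can turn that into a clearer error; the helper name restore_latest is hypothetical, and it relies only on the saver.restore(sess, path) call the restore code above already uses.

```python
def restore_latest(saver, sess, ckpt_state, save_dir):
    """Restore `sess` from the checkpoint described by `ckpt_state`.

    `ckpt_state` is assumed to be the result of
    tf.train.get_checkpoint_state(save_dir), which is None when save_dir
    has no "checkpoint" state file. Raising here names the directory,
    instead of failing later with an AttributeError on None.
    """
    if ckpt_state is None or not ckpt_state.model_checkpoint_path:
        raise FileNotFoundError(f"no checkpoint found in {save_dir!r}")
    path = ckpt_state.model_checkpoint_path
    saver.restore(sess, path)  # same call the answer makes directly
    return path
```

    For the Model this would read model_path = restore_latest(model_saver, sess, tf.train.get_checkpoint_state(args.save_dir), args.save_dir), and likewise for the adversary with adv_saver, adv_sess and adv_args.save_dir.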
    

    From here on, whenever I needed one of the sessions, I wrapped the TensorFlow calls for it in a with sess.as_default(): block (with adv_sess.as_default(): for the adversary). At the end, I closed both sessions manually:

    sess.close()
    adv_sess.close()
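    As far as I know, sess.as_default() in TF1 makes a session the default without also making its graph the default, which is why the restore code above nests the two. A small helper can enter both at once. This is a sketch assuming TF1-style objects, where a Session exposes .graph and both Session and Graph provide as_default(); the name in_session is hypothetical.

```python
import contextlib


@contextlib.contextmanager
def in_session(sess):
    """Make `sess` and its graph the defaults for the enclosed block.

    Enters sess.graph.as_default() first, then sess.as_default(), and
    exits them in reverse order. It does NOT close the session.
    """
    with sess.graph.as_default(), sess.as_default():
        yield sess
```

    Usage would be with in_session(adv_sess): followed by the adversary's ops. Like sess.as_default(), the helper leaves the session open, so the explicit close() calls are still needed. For a session that is only used within one block, with tf.Session(graph=model_graph) as sess: enters both defaults and closes the session on exit.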
    
