Tensorflow: How to replace a node in a calculation graph?

日久生厌 2020-11-27 14:38

If you have two disjoint graphs and want to link them, turning this:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)
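One way to splice two such graphs in TF1-style graph mode is tf.import_graph_def with input_map, which rewires every consumer of a named tensor in the imported graph to a tensor you supply. The following is a minimal sketch under stated assumptions: the bodies of f (2 * x for the first graph, y + 1 for the second), the tensor names, and the helper functions are hypothetical stand-ins for illustration.

```python
import tensorflow.compat.v1 as tf  # TF1-style graph API (also available in TF 2.x)

tf.disable_v2_behavior()  # placeholders and Sessions need graph mode


def build_first_graph():
    """Graph 1: y = f(x), with f(x) = 2 * x as a hypothetical stand-in."""
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, name="x")
        tf.multiply(x, 2.0, name="y")
    return g.as_graph_def()


def build_second_graph():
    """Graph 2: z = f(y), with f(y) = y + 1 as a hypothetical stand-in."""
    with tf.Graph().as_default() as g:
        y = tf.placeholder(tf.float32, name="y")
        tf.add(y, 1.0, name="z")
    return g.as_graph_def()


def link_graphs(x_value):
    with tf.Graph().as_default() as combined:
        # Import graph 1 unchanged and fetch its input and output tensors.
        x, y = tf.import_graph_def(
            build_first_graph(), return_elements=["x:0", "y:0"], name="first")
        # Import graph 2, rewiring every consumer of its placeholder "y:0"
        # to graph 1's output tensor instead.
        (z,) = tf.import_graph_def(
            build_second_graph(), input_map={"y:0": y},
            return_elements=["z:0"], name="second")
    with tf.Session(graph=combined) as sess:
        return sess.run(z, feed_dict={x: x_value})


print(link_graphs(3.0))  # 2 * 3 + 1 = 7.0
```

The second graph's placeholder is still imported (as "second/y"), but nothing consumes it any more, so it never needs to be fed.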
4 Answers
  •  刺人心 (OP)
    2020-11-27 15:13

    If you want to combine trained models (for example, to reuse parts of a pretrained model in a new model), you can use a tf.train.Saver to save a checkpoint of the first model, then restore that model (entirely or partially) into another model.

    For example, say you want to reuse model 1's weights w in model 2, and also convert x from a placeholder to a variable:

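    import tensorflow.compat.v1 as tf  # TF1 graph API (on TF 1.x, `import tensorflow as tf` also works)
    tf.disable_v2_behavior()  # placeholders and Sessions require graph mode
    
    with tf.Graph().as_default() as g1:
        x = tf.placeholder('float')
        w = tf.Variable(1., name="w")
        y = x * w
        saver = tf.train.Saver()
    
    with tf.Session(graph=g1) as sess:
        w.initializer.run()
        # train...
        saver.save(sess, "my_model1.ckpt")
    
    with tf.Graph().as_default() as g2:
        x = tf.Variable(2., name="v")  # x is now a variable, not a placeholder
        w = tf.Variable(0., name="w")  # same name as in model 1, so the Saver can match it
        z = x + w
        restorer = tf.train.Saver([w])  # only restore w
    
    with tf.Session(graph=g2) as sess:
        x.initializer.run()  # x is not in the checkpoint, so it must be initialized
        restorer.restore(sess, "my_model1.ckpt")  # restores w = 1.
        print(z.eval())  # prints 3.0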
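The code above uses the TF1 graph API. On TensorFlow 2.x in eager mode, a rough equivalent of the same partial-restore idea can be sketched with tf.train.Checkpoint, which matches variables by the keyword names passed to the Checkpoint (here the assumed key w) rather than by variable names. The helper reuse_w, the checkpoint prefix "model1", and the temporary directory are illustrative assumptions.

```python
import os
import tempfile

import tensorflow as tf  # TensorFlow 2.x, eager mode


def reuse_w(tmp_dir):
    # "Model 1": a trained weight w (here simply assumed to end up at 1.0).
    w1 = tf.Variable(1.0)
    save_path = tf.train.Checkpoint(w=w1).save(os.path.join(tmp_dir, "model1"))

    # "Model 2": a fresh w plus a new variable x that is not in the checkpoint.
    x = tf.Variable(2.0)
    w2 = tf.Variable(0.0)
    # Only w2 is attached (under the key "w"), so only it is restored;
    # in eager mode the value is assigned immediately and x keeps its value.
    tf.train.Checkpoint(w=w2).restore(save_path)
    return float(x + w2)


with tempfile.TemporaryDirectory() as d:
    print(reuse_w(d))  # 3.0
```

Because only w is attached to the second Checkpoint, x is untouched by the restore, which mirrors tf.train.Saver([w]) in the TF1 version.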
    
