Tensorflow: How to replace a node in a calculation graph?

前端 未结 4 607
日久生厌
日久生厌 2020-11-27 14:38

If you have two disjoint graphs, and want to link them, turning this:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

… into a single linked graph (the rest of the question is truncated in this copy).


        
4条回答
  •  时光说笑
    2020-11-27 15:13

    Practical example:

    import tensorflow as tf

    # This example builds graphs explicitly, so it needs the TF1-style graph API.
    # ``tf.compat.v1`` exposes that API on TensorFlow 1.13+ and on 2.x. Inside
    # ``with graph.as_default():`` TF2 does not execute eagerly, so placeholders
    # and sessions work there as well.
    tf1 = tf.compat.v1

    # --- Graph 1: g1_mul = g1_a * g1_b -----------------------------------------
    g1 = tf.Graph()
    with g1.as_default():
        # Create the placeholders. They are looked up again by name below.
        tf1.placeholder(tf.int32, [], name='g1_a')
        tf1.placeholder(tf.int32, [], name='g1_b')

        # Example of extracting tensors by name ("<op name>:<output index>").
        a = g1.get_tensor_by_name('g1_a:0')
        b = g1.get_tensor_by_name('g1_b:0')

        # operation ==>>     c = 2 * 3 = 6
        mul_op = tf.multiply(a, b, name='g1_mul')

    # Run graph 1 in a session bound to g1. The `with` block closes it afterwards.
    with tf1.Session(graph=g1) as sess:
        g1_mul_results = sess.run(mul_op, feed_dict={'g1_a:0': 2, 'g1_b:0': 3})
        print('graph1 mul = ', g1_mul_results)  # output = 6

    print('\ngraph01 operations/variables:')
    for op in g1.get_operations():
        print(op.name)

    # --- Graph 2: import graph 1 and wire its nodes into new operations --------
    g2 = tf.Graph()
    with g2.as_default():
        # Copy every node of g1 into g2. import_graph_def prefixes imported
        # node names with "import/" by default.
        tf1.import_graph_def(g1.as_graph_def())
        g2_c = tf1.placeholder(tf.int32, [], name='g2_c')

        # Example of extracting the imported g1 tensors by their prefixed names.
        g1_b = g2.get_tensor_by_name('import/g1_b:0')
        g1_mul = g2.get_tensor_by_name('import/g1_mul:0')

        # New g2 operations that consume imported g1 tensors.
        c_times_b = tf.multiply(g1_b, g2_c, name='g2_c_times_g1_b')
        mul_times_b = tf.multiply(g1_mul, g1_b, name='g1_mul_times_g1_b')

    print('\ngraph02 operations/variables:')
    for op in g2.get_operations():
        print(op.name)

    with tf1.Session(graph=g2) as sess:
        # graph1 placeholder 'b' times graph2 placeholder 'c'.
        # Passing the tensor objects (rather than hand-typed name strings)
        # avoids fetching a tensor name that does not exist.
        ans = sess.run(c_times_b, feed_dict={'g2_c:0': 4, 'import/g1_b:0': 5})
        print('\ngraph2 g2_c_times_g1_b = ', ans)  # output = 4*5 = 20

        # graph1 mul_op (a*b) times graph1 placeholder 'b'. Feeding the
        # imported placeholders by name overrides the values for this run.
        ans = sess.run(mul_times_b,
                       feed_dict={'import/g1_a:0': 6, 'import/g1_b:0': 7})
        print('\ngraph2 g1_mul_times_g1_b:0 = ', ans)  # output = (6*7)*7 = 294

    ''' output
    graph1 mul =  6

    graph01 operations/variables:
    g1_a
    g1_b
    g1_mul

    graph02 operations/variables:
    import/g1_a
    import/g1_b
    import/g1_mul
    g2_c
    g2_c_times_g1_b
    g1_mul_times_g1_b

    graph2 g2_c_times_g1_b =  20

    graph2 g1_mul_times_g1_b:0 =  294
    '''
    

    reference LINK

提交回复
热议问题