Connect input and output tensors of two different graphs in TensorFlow

暖寄归人 2020-12-05 12:26

I have two ProtoBuf files. I currently load each of them and run its forward pass separately, by calling:

out1=session.run(graph1out, feed_di         


        
2 answers
  • 2020-12-05 12:54

    Assuming that your Protobuf files contain serialized tf.GraphDef protos, you can use the input_map argument of tf.import_graph_def() to connect the two graphs:

    # Import graph1.
    graph1_def = ...  # tf.GraphDef object
    out1_name = "..."  # name of the graph1out tensor in graph1_def (a tensor name, e.g. ending in ":0").
    graph1out, = tf.import_graph_def(graph1_def, return_elements=[out1_name])
    
    # Import graph2 and connect it to graph1.
    graph2_def = ...  # tf.GraphDef object
    inp2_name = "..."  # name of the graph2inp tensor in graph2_def.
    out2_name = "..."  # name of the graph2out tensor in graph2_def.
    graph2out, = tf.import_graph_def(graph2_def, input_map={inp2_name: graph1out},
                                     return_elements=[out2_name])
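The pattern above can be tried end to end with a minimal sketch. It assumes TensorFlow is installed and uses the `tensorflow.compat.v1` graph API; the two toy graphs, their tensor names (`inp`, `out`), and the serialized bytes standing in for the contents of the two Protobuf files are all hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API (assumes TF >= 1.14 or 2.x)

# Stand-ins for the two .pb files: build toy graphs and serialize them to bytes.
with tf.Graph().as_default() as g1:
    a = tf.placeholder(tf.float32, (None, 2), name="inp")
    tf.multiply(a, 2.0, name="out")
graph1_bytes = g1.as_graph_def().SerializeToString()

with tf.Graph().as_default() as g2:
    b = tf.placeholder(tf.float32, (None, 2), name="inp")
    tf.add(b, 1.0, name="out")
graph2_bytes = g2.as_graph_def().SerializeToString()

# Parse the bytes back into GraphDef protos, as one would after reading a file.
graph1_def = tf.GraphDef()
graph1_def.ParseFromString(graph1_bytes)
graph2_def = tf.GraphDef()
graph2_def.ParseFromString(graph2_bytes)

with tf.Graph().as_default() as combined:
    x = tf.placeholder(tf.float32, (None, 2), name="x")
    # Feed x into graph1, then feed graph1's output into graph2's input.
    graph1out, = tf.import_graph_def(
        graph1_def, input_map={"inp:0": x},
        return_elements=["out:0"], name="graph1")
    graph2out, = tf.import_graph_def(
        graph2_def, input_map={"inp:0": graph1out},
        return_elements=["out:0"], name="graph2")

# A single session.run now executes both graphs in one pass.
with tf.Session(graph=combined) as sess:
    batch = np.array([[1.0, 2.0]], dtype=np.float32)
    result = sess.run(graph2out, feed_dict={x: batch})
```

Note that the strings in `return_elements` are tensor names (`"out:0"`), so Tensor objects come back; a bare operation name such as `"out"` would return an Operation instead.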
    
  • 2020-12-05 12:59

    The accepted answer does connect the two graphs, but it does not restore the collections or the global and trainable variables. After an exhaustive search I came to a better solution:

    import tensorflow as tf
    from tensorflow.python.framework import meta_graph
    
    with tf.Graph().as_default() as graph1:
        input = tf.placeholder(tf.float32, (None, 20), name='input')
        output = tf.identity(input, name='output')
    
    with tf.Graph().as_default() as graph2:
        input = tf.placeholder(tf.float32, (None, 20), name='input')
        output = tf.identity(input, name='output')
    
    graph = tf.get_default_graph()
    x = tf.placeholder(tf.float32, (None, 20), name='input')
    

    We use tf.train.export_meta_graph, which also exports the CollectionDef entries, and meta_graph.import_scoped_meta_graph to import the result. The connection happens in the import call, specifically in the input_map parameter.

    meta_graph1 = tf.train.export_meta_graph(graph=graph1)
    meta_graph.import_scoped_meta_graph(meta_graph1, input_map={'input': x}, import_scope='graph1')
    out1 = graph.get_tensor_by_name('graph1/output:0')
    
    meta_graph2 = tf.train.export_meta_graph(graph=graph2)
    meta_graph.import_scoped_meta_graph(meta_graph2, input_map={'input': out1}, import_scope='graph2')
    

    Now the graphs are connected, and the global variables have been re-mapped as well.

    print(tf.global_variables())
    

    You can also import meta graphs directly from a file.
