TensorFlow: How to replace a node in a computation graph?

日久生厌 2020-11-27 14:38

If you have two disjoint graphs, and want to link them, turning this:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)


        
4 answers
  •  执笔经年
    2020-11-27 15:04

    TL;DR: If you can define the two computations as Python functions, you should do that. If you can't, there's more advanced functionality in TensorFlow to serialize and import graphs, which allows you to compose graphs from different sources.
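As a minimal sketch of the "Python functions" route: if both computations are plain functions, chaining them in one graph needs no graph surgery at all. The bodies of f and g below are hypothetical stand-ins (the question does not say what they compute), and the code assumes the TF 1.x graph API, reached here through tensorflow.compat.v1 so it also runs on TensorFlow 2 installs.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF 1.x style graphs and sessions


def f(x):
    # Hypothetical first computation (not from the original question).
    return x * 2.0


def g(y):
    # Hypothetical second computation (not from the original question).
    return y + 1.0


with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, name="x")
    # Composing the functions links the two computations directly:
    # no intermediate placeholder for y is ever created.
    z = g(f(x))

with tf.Session(graph=graph) as sess:
    result = sess.run(z, feed_dict={x: 3.0})
```

Because y is just the Python return value of f, there is nothing to replace afterwards; the approach below is only needed when the graphs already exist separately.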

    One way to do this in TensorFlow is to build the disjoint computations as separate tf.Graph objects, then convert them to serialized protocol buffers using Graph.as_graph_def():

    import tensorflow as tf  # TF 1.x graph-mode API

    # f and g stand for the two computations from the question.
    with tf.Graph().as_default() as g_1:
      input = tf.placeholder(tf.float32, name="input")
      y = f(input)
      # NOTE: using identity to get a known name for the output tensor.
      output = tf.identity(y, name="output")
    
    gdef_1 = g_1.as_graph_def()
    
    with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1
      input = tf.placeholder(tf.float32, name="input")
      z = g(input)
      # Wrap z (not y, which belongs to g_1) to name the output of g_2.
      output = tf.identity(z, name="output")
    
    gdef_2 = g_2.as_graph_def()
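To see what the serialized form contains, the sketch below lists the node names in a GraphDef and round-trips it through bytes. The computation inside the graph is a hypothetical stand-in for f, and tensorflow.compat.v1 is assumed so that the TF 1.x calls are available. Tensor names used later in input_map and return_elements are these node names with an output index appended, such as "input:0".

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default() as g_1:
    inp = tf.placeholder(tf.float32, name="input")
    # Hypothetical computation standing in for f(input).
    out = tf.identity(inp * 2.0, name="output")

gdef_1 = g_1.as_graph_def()

# Every op in the graph is one NodeDef entry in the GraphDef.
node_names = [node.name for node in gdef_1.node]

# A GraphDef is a protocol buffer, so it can be written out as bytes
# and parsed back, for example to move a graph between processes.
serialized = gdef_1.SerializeToString()
restored = tf.GraphDef()
restored.ParseFromString(serialized)
restored_names = [node.name for node in restored.node]
```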
    

    Then you could compose gdef_1 and gdef_2 into a third graph, using tf.import_graph_def():

    with tf.Graph().as_default() as g_combined:
      x = tf.placeholder(tf.float32, name="")
    
      # Import gdef_1, which performs f(x).
      # "input:0" and "output:0" are the names of tensors in gdef_1.
      y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                               return_elements=["output:0"])
    
      # Import gdef_2, which performs g(y)
      z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                               return_elements=["output:0"])
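Putting the steps together, here is a self-contained sketch that can be run end to end. The helper build_graph_def and the bodies of f and g are hypothetical, introduced only for illustration, and the code assumes the TF 1.x API via tensorflow.compat.v1.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF 1.x style graphs and sessions


def f(x):
    return x * 2.0  # hypothetical stand-in for the first computation


def g(y):
    return y + 1.0  # hypothetical stand-in for the second computation


def build_graph_def(fn):
    """Build fn in its own graph, with tensors named "input:0" and "output:0"."""
    with tf.Graph().as_default() as graph:
        inp = tf.placeholder(tf.float32, name="input")
        tf.identity(fn(inp), name="output")
    return graph.as_graph_def()


gdef_1 = build_graph_def(f)
gdef_2 = build_graph_def(g)

with tf.Graph().as_default() as g_combined:
    x = tf.placeholder(tf.float32, name="x")
    # Feed x wherever gdef_1 reads its "input:0" placeholder.
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])
    # Feed the output of the first import into the second graph.
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])

with tf.Session(graph=g_combined) as sess:
    result = sess.run(z, feed_dict={x: 3.0})
```

Each call to tf.import_graph_def places the imported nodes under a name prefix ("import" by default, made unique on repeated imports), so the identically named "input" and "output" nodes of the two graphs do not collide.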
    
