Is it possible to modify an existing TensorFlow computation graph?

梦谈多话 2020-12-05 14:42

A TensorFlow graph is usually built gradually from inputs to outputs and then executed. Looking at the Python code, the input lists of operations are immutable, which suggest

3 Answers
  • 2020-12-05 14:49

    Yes, a tf.Graph is built in an append-only fashion, as @mrry puts it.

    But there is a workaround:

    Conceptually, you can modify an existing graph by cloning it and performing the needed modifications along the way. As of r1.1, TensorFlow provides a module named tf.contrib.graph_editor, which implements this idea as a set of convenient functions.
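
The clone-and-modify idea can be sketched without TensorFlow at all. The following is a toy model in plain Python, not the graph_editor implementation: `Node`, `evaluate`, and `clone_with_replacements` are hypothetical names invented for this illustration. It only aims to show that the original nodes stay untouched while a modified copy is built next to them.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple


# eq=False keeps identity-based hashing, so nodes can be used as dict keys.
@dataclass(frozen=True, eq=False)
class Node:
    """Toy immutable op (hypothetical, not a TensorFlow class)."""
    name: str
    fn: Callable[..., int]
    inputs: Tuple["Node", ...] = ()


def evaluate(node: Node) -> int:
    """Recursively evaluate a node from its inputs."""
    return node.fn(*(evaluate(i) for i in node.inputs))


def clone_with_replacements(
    target: Node,
    replacements: Dict[Node, Node],
    memo: Optional[Dict[Node, Node]] = None,
) -> Node:
    """Copy the subgraph ending at `target`, swapping in replaced nodes."""
    memo = {} if memo is None else memo
    if target in replacements:
        return replacements[target]
    if target in memo:
        return memo[target]
    new_inputs = tuple(
        clone_with_replacements(i, replacements, memo) for i in target.inputs
    )
    copied = Node(target.name + "_copy", target.fn, new_inputs)
    memo[target] = copied
    return copied


a = Node("a", lambda: 1)
b = Node("b", lambda: 2)
c = Node("c", lambda x, y: x + y, (a, b))

new_a = Node("new_a", lambda: 40)
new_c = clone_with_replacements(c, {a: new_a})

print(evaluate(c))      # 3, the original graph is unchanged
print(evaluate(new_c))  # 42, the copy uses new_a instead of a
```

In this sketch every node that is not replaced gets copied, which mirrors the point that a modification is really a rebuild of the affected part of the graph.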

  • 2020-12-05 15:03

    The TensorFlow tf.Graph class is an append-only data structure, which means that you can add nodes to the graph after executing part of the graph, but you cannot remove or modify existing nodes. Since TensorFlow executes only the necessary subgraph when you call Session.run(), there is no execution-time cost to having redundant nodes in the graph (although they will continue to consume memory).
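
To make the append-only behaviour concrete, here is a minimal sketch in plain Python. It is a toy model under stated assumptions, not TensorFlow code: the `AppendOnlyGraph` class and its `add` and `run` methods are hypothetical names for this illustration. It shows that nodes can be added after a run, that existing nodes cannot be redefined, and that a run only touches the nodes the target depends on.

```python
from typing import Callable, Dict, List, Sequence, Tuple


class AppendOnlyGraph:
    """Toy append-only graph (hypothetical, not tf.Graph)."""

    def __init__(self) -> None:
        self._nodes: Dict[str, Tuple[Callable[..., int], Tuple[str, ...]]] = {}

    def add(self, name: str, fn: Callable[..., int],
            inputs: Sequence[str] = ()) -> str:
        """Append a node; existing nodes can never be replaced."""
        if name in self._nodes:
            raise ValueError(f"node {name!r} already exists")
        for dep in inputs:
            if dep not in self._nodes:
                raise KeyError(f"unknown input {dep!r}")
        self._nodes[name] = (fn, tuple(inputs))
        return name

    def run(self, target: str) -> Tuple[int, List[str]]:
        """Evaluate `target` and report which nodes were executed."""
        cache: Dict[str, int] = {}
        executed: List[str] = []

        def visit(name: str) -> int:
            if name not in cache:
                fn, inputs = self._nodes[name]
                cache[name] = fn(*(visit(i) for i in inputs))
                executed.append(name)
            return cache[name]

        return visit(target), executed


g = AppendOnlyGraph()
g.add("a", lambda: 1)
g.add("b", lambda: 2)
g.add("c", lambda x, y: x + y, ("a", "b"))
g.add("unused", lambda x: x * 100, ("a",))  # redundant node, never run below

value, executed = g.run("c")
print(value, executed)  # 3 ['a', 'b', 'c']

# Nodes can still be appended after part of the graph has been executed.
g.add("d", lambda x: x + 39, ("c",))
print(g.run("d")[0])  # 42
```

The redundant node stays stored in the graph, which corresponds to the memory cost mentioned above, but it costs nothing at run time because it is never visited.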

    To remove all nodes in the graph, you can create a session with a new graph:

    with tf.Graph().as_default():  # Create a new graph, and make it the default.
      with tf.Session() as sess:  # `sess` will use the new, currently empty, graph.
        # Build graph and execute nodes in here.
        pass  # Placeholder so the block is valid Python.
    
  • 2020-12-05 15:06

    In addition to what @zaxily and @mrry say, I want to provide an example of how to actually modify the graph. In short:

    1. One cannot modify existing operations; all ops are final and immutable.
    2. One may copy an op, modify its inputs or attributes, and add the new op back to the graph.
    3. All downstream ops that depend on the new/copied op have to be recreated. Yes, a significant portion of the graph would be copied, which is not a problem.

    The code:

    # Requires TensorFlow 1.x (tf.contrib is not available in TensorFlow 2.x).
    import tensorflow as tf
    import tensorflow.contrib.graph_editor as ge
    from copy import deepcopy
    
    a = tf.constant(1)
    b = tf.constant(2)
    c = a+b
    
    def modify(t): 
        # illustrate operation copy&modification
        new_t = deepcopy(t.op.node_def)
        new_t.name = new_t.name+"_but_awesome"
        new_t = tf.Operation(new_t, tf.get_default_graph())
        # we got a tensor, let's return a tensor
        return new_t.outputs[0]
    
    def update_existing(target, updated):
        # illustrate how to use new op
        related_ops = ge.get_backward_walk_ops(target, stop_at_ts=updated.keys(), inclusive=True)
        new_ops, mapping = ge.copy_with_input_replacements(related_ops, updated)
        new_op = mapping._transformed_ops[target.op]
        return new_op.outputs[0]
    
    new_a = modify(a)
    new_b = modify(b)
    injection = new_a+39 # illustrate how to add another op to the graph
    new_c = update_existing(c, {a:injection, b:new_b})
    
    with tf.Session():
        print(c.eval()) # -> 3
        print(new_c.eval()) # -> 42
    