Building TensorFlow Graphs Inside of Functions

Posted by 戏子无情 on 2019-12-04 13:44:29

Question


I'm learning TensorFlow and trying to structure my code properly. I (more or less) know how to build graphs either at the top level of a script or as class methods, but I'm trying to figure out the best way to build them inside functions. I tried this simple example:

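import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g  # only the graph is returned; a and b stay local to this function

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))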

which should just print out 4. However, when I do that, I get the error:

Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", dtype=int8) is not an element of this graph.

I'm pretty sure this is because the placeholder inside build_graph() is private (local to the function), but shouldn't the with tf.Session(graph=graph) block take care of that? Is there a better way to use a feed dict in a situation like this?
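The error message hints at what is going on: it names a placeholder that exists but belongs to a different graph. Inside build_graph(), a and b are ordinary local variables, so the a used in the feed dict must be some other placeholder, most likely one left over from earlier experiments in the default graph. The pure-Python toy model below roughly mimics the ownership check the session performs on feed keys and fetches; ToyGraph, ToyNode and toy_run are hypothetical stand-ins written for illustration, not TensorFlow APIs.

```python
class ToyGraph:
    """Hypothetical stand-in for tf.Graph that only records which nodes it owns."""

    def __init__(self):
        self.nodes = []

    def add_node(self, name):
        node = ToyNode(name, self)
        self.nodes.append(node)
        return node


class ToyNode:
    """Hypothetical stand-in for a tensor: a name plus the graph that owns it."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


def toy_run(graph, fetch, feed_dict):
    """Mimic the ownership check: feed keys and the fetch must belong to `graph`."""
    for node in [*feed_dict, fetch]:
        if node.graph is not graph:
            raise ValueError(f"Tensor {node.name!r} is not an element of this graph.")
    return f"would evaluate {fetch.name}"


def build_graph():
    g = ToyGraph()
    a = g.add_node("Placeholder:0")  # local variable: the caller never sees it
    b = g.add_node("Add:0")          # likewise local
    return g


# Assumed leftovers from earlier experiments in some other (e.g. default) graph.
old_graph = ToyGraph()
a = old_graph.add_node("Placeholder:0")
b = old_graph.add_node("Add:0")

graph = build_graph()
try:
    toy_run(graph, b, {a: 3})
except ValueError as err:
    print(err)  # Tensor 'Placeholder:0' is not an element of this graph.

# Handles that really belong to `graph` pass the check.
a_new, b_new = graph.nodes
print(toy_run(graph, b_new, {a_new: 3}))  # would evaluate Add:0
```

In other words, the session only knows which graph to run; it cannot reach into the function's local scope, which is why the answers focus on getting real handles (or names) out of build_graph().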


Answer 1:


There are several options.

Option 1: just pass the name of the tensor instead of the tensor itself.

with tf.Session(graph=graph) as sess:
    feed = {"Placeholder:0": 3}      
    print(sess.run("Add:0", feed_dict=feed))

In this case, it's probably best to give the nodes meaningful names, instead of using the default names as above:

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {"a:0": 3}
    print(sess.run("b:0", feed_dict=feed))

Recall that the outputs of an operation named "foo" are tensors named "foo:0", "foo:1", and so on. Most operations have just one output.
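The "op_name:output_index" convention is simple enough to handle with plain string operations. The helpers split_tensor_name and output_names in this sketch are hypothetical (they are not part of TensorFlow); they just spell out the rule above. rpartition is used so that scoped names such as "scope/foo:1" still split at the last colon.

```python
def split_tensor_name(tensor_name):
    """Split a tensor name such as "foo:1" into its op name and output index."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(f"expected '<op_name>:<output_index>', got {tensor_name!r}")
    return op_name, int(index)


def output_names(op_name, num_outputs):
    """Return the tensor names produced by an op with `num_outputs` outputs."""
    return [f"{op_name}:{i}" for i in range(num_outputs)]


print(split_tensor_name("b:0"))          # ('b', 0)
print(split_tensor_name("scope/foo:1"))  # ('scope/foo', 1)
print(output_names("foo", 3))            # ['foo:0', 'foo:1', 'foo:2']
```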

Option 2: make your build_graph() function return all the important nodes.

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

Option 3: add the important nodes to a collection.

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    for node in (a, b):
        g.add_to_collection("important_stuff", node)
    return g

graph = build_graph()
# get_collection() returns the items in the order they were added: a, then b
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

Option 4: as suggested by @pohe, you can use get_tensor_by_name().

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

I personally use option 2 most often: it's straightforward and doesn't require playing with names. I use option 3 when the graph is large and will live for a long time, because the collection gets saved along with the model, and it's a quick way to document what really matters. I don't really use option 1, because I prefer to have actual references to objects (not sure why). Option 4 is useful when you are working with a graph built by someone else who didn't give you direct references to the tensors.

Hope this helps!




Answer 2:


I'm looking for a better way as well, so my answer is probably not the best. Nevertheless, if you give a and b names, such as

a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')

Then you can do

graph = build_graph()

a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')

with tf.Session(graph=graph) as sess:
    feed = {a: 3}      
    print(sess.run(b, feed_dict=feed))

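P.S. Naming a and b is not strictly necessary; it just makes them easier to reference later (without names you would have to look them up by their default names, such as "Placeholder:0" and "Add:0"). Also, if you've found a better solution, please share it too.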



Source: https://stackoverflow.com/questions/44418442/building-tensorflow-graphs-inside-of-functions
