Simple way to visualize a TensorFlow graph in Jupyter?

滥情空心 2020-11-30 17:24

The official way to visualize a TensorFlow graph is with TensorBoard, but sometimes I just want a quick look at the graph when I'm working in Jupyter.

Is there a quick way to do this from inside Jupyter?

7 answers
  •  清歌不尽
    2020-11-30 17:39

    Here's a recipe I copied at some point from one of Alex Mordvintsev's DeepDream notebooks:

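    from IPython.display import display, HTML
    import numpy as np
    import tensorflow as tf  # TF 1.x API; in TF 2.x use tf.compat.v1.GraphDef

    def strip_consts(graph_def, max_const_size=32):
        """Strip large constant values from graph_def."""
        strip_def = tf.GraphDef()
        for n0 in graph_def.node:
            n = strip_def.node.add()
            n.MergeFrom(n0)
            if n.op == 'Const':
                tensor = n.attr['value'].tensor
                size = len(tensor.tensor_content)
                if size > max_const_size:
                    # tensor_content is a bytes field, so the placeholder must be bytes
                    tensor.tensor_content = b"<stripped %d bytes>" % size
        return strip_def

    def show_graph(graph_def, max_const_size=32):
        """Visualize TensorFlow graph."""
        if hasattr(graph_def, 'as_graph_def'):
            graph_def = graph_def.as_graph_def()
        strip_def = strip_consts(graph_def, max_const_size=max_const_size)
        # NOTE: this page relies on HTML Imports (<link rel="import">),
        # which current browsers no longer support.
        code = """
            <script>
              function load() {{
                document.getElementById("{id}").pbtxt = {data};
              }}
            </script>
            <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
            <div style="height:600px">
              <tf-graph-basic id="{id}"></tf-graph-basic>
            </div>
        """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

        # Embed the page via srcdoc; escape double quotes so they don't end the attribute
        iframe = """
            <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
        """.format(code.replace('"', '&quot;'))
        display(HTML(iframe))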

    Then, to visualize the current graph:

    show_graph(tf.get_default_graph().as_graph_def())
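Embedding a generated page in an iframe's `srcdoc` attribute, as the `iframe` string in `show_graph` does, means the page must be escaped as an attribute value. The recipe only replaces double quotes; a slightly more defensive variant, sketched here with the standard library's `html.escape`, also escapes `&`, `<`, `>` and `'`. The helper name `srcdoc_iframe` is hypothetical and its size defaults are arbitrary assumptions.

```python
import html


def srcdoc_iframe(inner_html: str, width: int = 1200, height: int = 620) -> str:
    """Wrap an HTML document in an <iframe srcdoc="..."> element.

    html.escape(quote=True) escapes &, <, >, " and ', so the embedded page
    cannot terminate the attribute early; the browser un-escapes it when
    it renders the frame.
    """
    escaped = html.escape(inner_html, quote=True)
    return (
        f'<iframe style="width:{width}px;height:{height}px;border:0" '
        f'srcdoc="{escaped}"></iframe>'
    )
```

In a notebook this would be used as `display(HTML(srcdoc_iframe(code)))` in place of the manual `replace` call.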
    

    If your graph is saved as a .pbtxt (text-format GraphDef) file, you can do:

    import tensorflow as tf
    from google.protobuf import text_format

    gdef = tf.GraphDef()  # tf.compat.v1.GraphDef() in TF 2.x
    with open("tf_persistent.pbtxt") as f:
        text_format.Merge(f.read(), gdef)
    show_graph(gdef)
    

    You'll see something like this: [screenshot of the rendered graph view]
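If TensorFlow isn't available and you only want to know which ops a saved `.pbtxt` contains, a rough, dependency-free alternative is to scan the text format directly. This sketch assumes the pretty-printed layout that text-format writers produce (each nested message opens on a line ending in `{` and closes with a lone `}`) and reads only top-level `node` blocks. `list_nodes` is a hypothetical helper, not a TensorFlow or protobuf API, and it is no substitute for `text_format.Merge` when you need the real `GraphDef`.

```python
import re
from typing import List, Tuple

# Top-level "node {" line, and a scalar string field such as:  name: "conv1/w"
_NODE_OPEN_RE = re.compile(r'^node\s*\{$')
_FIELD_RE = re.compile(r'^(name|op):\s*"((?:[^"\\]|\\.)*)"$')


def list_nodes(pbtxt: str) -> List[Tuple[str, str]]:
    """Return (name, op) pairs for each top-level `node { ... }` block.

    Assumes pretty-printed protobuf text format; nested messages such as
    `attr { ... }` are skipped by tracking brace depth line by line.
    """
    nodes = []
    depth = 0
    current = None  # fields collected for the node block being read
    for raw in pbtxt.splitlines():
        line = raw.strip()
        if line.endswith("{"):
            if depth == 0 and _NODE_OPEN_RE.match(line):
                current = {}
            depth += 1
        elif line == "}":
            depth -= 1
            if depth == 0 and current is not None:
                nodes.append((current.get("name", ""), current.get("op", "")))
                current = None
        elif depth == 1 and current is not None:
            m = _FIELD_RE.match(line)
            if m:
                current[m.group(1)] = m.group(2)
    return nodes
```

For a quick look, something like `for name, op in list_nodes(open("tf_persistent.pbtxt").read()): print(op, name)` prints one line per node.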
