How to graph tf.keras model in Tensorflow-2.0?

别那么骄傲 2020-12-14 19:11

I upgraded to TensorFlow 2.0, and tf.summary.FileWriter("tf_graphs", sess.graph) no longer exists. I was looking through some other StackOverflow questions on this.

4 Answers
  • 2020-12-14 19:47

    You can visualize the graph of any tf.function-decorated function, but first you have to trace its execution.

    Visualizing the graph of a Keras model means visualizing its call method.

    By default, this method is not decorated with tf.function, so you have to wrap the model call in a tf.function-decorated function and execute that function once.

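    import tensorflow as tf
    
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )
    
    
    # Wrap the model call in a tf.function so its execution can be traced.
    @tf.function
    def traceme(x):
        return model(x)
    
    
    logdir = "log"
    writer = tf.summary.create_file_writer(logdir)
    # Start recording the graph (and profiler data) of the next tf.function calls.
    tf.summary.trace_on(graph=True, profiler=True)
    # Forward pass; the input must match the model's (batch, 28, 28) input shape.
    traceme(tf.zeros((1, 28, 28)))
    with writer.as_default():
        # Write the recorded trace to logdir; then run: tensorboard --logdir log
        tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)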
    
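If you only want to confirm what the trace captures without opening TensorBoard, a minimal sketch (assuming a recent TF 2.x) is to trace the same kind of wrapper against a tf.TensorSpec with get_concrete_function and list the operations in the resulting graph. The model repeats the one above but is declared with tf.keras.Input instead of input_shape; the (None, 28, 28) spec is an assumption matching that model.

```python
import tensorflow as tf

# Same architecture as above; tf.keras.Input fixes the expected input shape.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation="softmax"),
])


@tf.function
def traceme(x):
    return model(x)


# Tracing against a TensorSpec builds the graph without feeding real data.
concrete = traceme.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32))

# The traced graph is an ordinary tf.Graph; collect the distinct op types in it.
op_types = sorted({op.type for op in concrete.graph.get_operations()})
print(op_types)
```

The list should contain matrix-multiplication ops from the Dense layers; exact op names can differ between TF/Keras versions.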
  • 2020-12-14 20:00

    Here's what works for me at the moment (TF 2.0.0). It is based on the source code of tf.keras.callbacks.TensorBoard:

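    # Run this after the model has been compiled.
    import tensorflow as tf
    # Private (non-public) TF 2.0 modules, the same ones tf.keras.callbacks.TensorBoard uses;
    # they may move or disappear in later releases.
    from tensorflow.python.ops import summary_ops_v2
    from tensorflow.python.keras.backend import get_graph
    
    tb_path = '/tmp/tensorboard/'
    tb_writer = tf.summary.create_file_writer(tb_path)
    with tb_writer.as_default():
        # Mirrors the callback: the graph is only written when the model is not run eagerly.
        if not model.run_eagerly:
            summary_ops_v2.graph(get_graph(), step=0)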
    
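On newer TF 2.x releases (an assumption: not 2.0.0, which this answer targets), the public tf.summary.graph function can write a tf.Graph directly, which avoids the private imports. A minimal sketch, assuming the graph comes from tracing a tf.function wrapper around the model; the temporary log directory is a stand-in for your own path:

```python
import os
import tempfile

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])


@tf.function
def forward(x):
    return model(x)


# Trace once to obtain a concrete tf.Graph for the forward pass.
graph = forward.get_concrete_function(tf.TensorSpec([None, 28, 28], tf.float32)).graph

logdir = tempfile.mkdtemp()  # stand-in log location; point TensorBoard here
writer = tf.summary.create_file_writer(logdir)
with writer.as_default():
    # Must be called eagerly with a default writer set; reports whether it wrote.
    written = tf.summary.graph(graph)
writer.flush()

event_files = [f for f in os.listdir(logdir) if f.startswith("events.out.tfevents")]
print(bool(written), event_files)
```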
  • 2020-12-14 20:01

    According to the docs, you can use TensorBoard to visualise graphs once your model has been trained.

    First, define your model and run it. Then, open TensorBoard and switch to the Graphs tab.


    Minimal Compilable Example

    This example is taken from the docs. First, define your model and data.

    # Relevant imports.
    %load_ext tensorboard
    
    from datetime import datetime
    
    import tensorflow as tf
    from tensorflow import keras
    
    # Define the model.
    model = keras.models.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    
    (train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
    # Scale pixel values to [0, 1].
    train_images = train_images / 255.0
    

    Next, train your model. You need to pass a TensorBoard callback, which logs the statistics and the graph for TensorBoard to display.

    # Define the Keras TensorBoard callback.
    logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
    
    # Train the model.
    model.fit(
        train_images,
        train_labels, 
        batch_size=64,
        epochs=5, 
        callbacks=[tensorboard_callback])
    

    After training, in your notebook, run

    %tensorboard --logdir logs
    

    Then switch to the Graphs tab in the navigation bar, where you will see a graph of the model.

    [Screenshot: the model graph in TensorBoard's Graphs tab]

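To check that the callback actually wrote event files before opening TensorBoard, here is a small sketch, assuming TF 2.x with NumPy; the random data is a hypothetical stand-in for Fashion-MNIST so that nothing has to be downloaded. write_graph=True is already the callback's default and is only spelled out for clarity; profile_batch=0 turns off profiling to keep the run minimal.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

model = keras.models.Sequential([
    keras.Input(shape=(28, 28)),
    keras.layers.Flatten(),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Hypothetical random stand-in data: 64 "images" and integer labels in [0, 10).
rng = np.random.default_rng(0)
x = rng.random((64, 28, 28), dtype=np.float32)
y = rng.integers(0, 10, size=64)

logdir = tempfile.mkdtemp()
tb = keras.callbacks.TensorBoard(log_dir=logdir, write_graph=True, profile_batch=0)
model.fit(x, y, batch_size=32, epochs=1, callbacks=[tb], verbose=0)

# Collect every event file the callback wrote under logdir (e.g. in its train/ subdirectory).
event_files = [
    os.path.join(root, name)
    for root, _, names in os.walk(logdir)
    for name in names
    if name.startswith("events.out.tfevents")
]
print(event_files)
```

Pointing %tensorboard --logdir at that directory should then show the run.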
  • 2020-12-14 20:01

    Another option is to use this website: https://lutzroeder.github.io/netron/

    which generates a graph from a .h5 or .tflite model file.

    The GitHub repository it is based on is here: https://github.com/lutzroeder/netron

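To produce a file that Netron can open, you can save the model in HDF5 format. A minimal sketch, assuming TF 2.x with h5py available (Keras 3 treats .h5 as a legacy format and warns, but still writes it); the temporary path is only for illustration:

```python
import os
import tempfile

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

# The ".h5" extension selects the HDF5 format when saving.
h5_path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(h5_path)
print(h5_path, os.path.getsize(h5_path))
```

Open the resulting model.h5 in Netron. For a .tflite file, tf.lite.TFLiteConverter.from_keras_model is the usual conversion entry point, though how well it handles a given model may depend on your TF version.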