TensorFlow 2.0 Keras: How to write image summaries for TensorBoard

后端 未结 2 582
悲哀的现实
悲哀的现实 2020-12-03 06:25

I'm trying to set up an image-recognition CNN with TensorFlow 2.0. To be able to analyze my image augmentation, I'd like to see the images I feed into the network in TensorBoard.

2条回答
  •  生来不讨喜
    2020-12-03 06:43

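    You could do something like this to add the input images to TensorBoard (the snippet assumes `io`, `matplotlib.pyplot as plt` and `tensorflow as tf` are imported, and that `logdir`, `x_train` and `y_train` are already defined):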

    def scale(image, label):
        """Cast *image* to float32 and divide it by 255; pass *label* through.

        Written as an (image, label) -> (image, label) mapping so it can be used
        with ``tf.data.Dataset.map``.
        """
        as_float = tf.cast(image, tf.float32)
        return as_float / 255.0, label
    
    
    def augment(image, label):
        """Placeholder augmentation step; currently returns its inputs unchanged.

        Kept as an (image, label) -> (image, label) mapping so real augmentation
        can be dropped in later without touching the pipeline.
        """
        unchanged = (image, label)
        return unchanged
    
    
    # Separate summary writer for image summaries, writing under "<logdir>/images"
    # (a subdirectory of the log dir also used by the TensorBoard callback).
    file_writer = tf.summary.create_file_writer(logdir=logdir + "/images")
    
    
    def plot_to_image(figure):
        """Render a matplotlib *figure* to PNG and return it as an image tensor.

        The figure is closed afterwards, even if saving fails.

        Args:
            figure: the matplotlib figure to convert.

        Returns:
            A tensor of shape (1, height, width, 4): a batch of one RGBA image,
            as expected by ``tf.summary.image``.
        """
        buf = io.BytesIO()
        try:
            # Save the figure that was passed in. ``plt.savefig`` would save
            # pyplot's *current* figure, which may be a different one.
            figure.savefig(buf, format='png')
        finally:
            # Release the figure so pyplot does not keep it open.
            plt.close(figure)
        # getvalue() returns the whole buffer regardless of stream position.
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # Add a batch dimension: tf.summary.image expects rank-4 input.
        image = tf.expand_dims(image, 0)
        return image
    
    
    def image_grid():
        """Return a 5x5 grid of the first 25 training images as a matplotlib figure.

        Each subplot is titled with the image's label from ``y_train``. Every
        image goes through the same ``scale`` -> ``augment`` steps as the
        training dataset pipeline, so the grid shows what the model is fed
        rather than the raw input.

        Relies on module-level ``x_train`` and ``y_train`` (presumably MNIST-like
        28x28 images, given the model's input shape - confirm).
        """
        figure = plt.figure(figsize=(10, 10))
        samples = zip(x_train[:25], y_train[:25])
        for i, (raw_image, label) in enumerate(samples):
            plt.subplot(5, 5, i + 1, title=str(label))
            # Hide ticks and grid lines; only the image itself matters here.
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            # Mirror dataset.map(scale).map(augment) and plot the *result*.
            image, _ = augment(*scale(raw_image, label))
            plt.imshow(image, cmap=plt.cm.binary)

        return figure
    
    
    # Build the 5x5 sample-image grid and log it once, at step 0, through the
    # image writer (which writes under "<logdir>/images").
    # NOTE(review): image_grid() reads x_train directly; it is not drawn from
    # `dataset` below, so it does not necessarily show the exact batches the
    # model trains on - confirm this is what you want to inspect.
    figure = image_grid()
    # Convert the figure to an image tensor and log it
    with file_writer.as_default():
        tf.summary.image("Training data", plot_to_image(figure), step=0)
    
    # Input pipeline: scale pixel values, apply augmentation, batch by 32.
    # NOTE(review): no .shuffle() is applied, so examples arrive in the same
    # order every epoch.
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.map(scale).map(augment).batch(32)
    
    # Small fully-connected classifier: flattened 28x28 input, one 128-unit
    # hidden layer with dropout, softmax over 10 classes.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    # sparse_categorical_crossentropy takes integer class labels (not one-hot).
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    # The TensorBoard callback writes training logs under `logdir`; the image
    # summary logged above lives in its "images" subdirectory.
    model.fit(dataset, epochs=5, callbacks=[tf.keras.callbacks.TensorBoard(log_dir=logdir)])
    

提交回复
热议问题