How to display custom images in TensorBoard using Keras?

暗喜 2020-12-02 12:16

I'm working on a segmentation problem in Keras and I want to display segmentation results at the end of every training epoch.

I want something similar to Tensorflow

8 Answers
  •  天命终不由人
    2020-12-02 13:16

    I'm trying to display matplotlib plots in TensorBoard (useful when plotting statistics, heatmaps, etc.). The same approach also works for the general case.

    import keras
    import tensorflow as tf  # TensorFlow 1.x API (tf.summary.FileWriter, sessions)
    import tfmpl  # tf-matplotlib
    from keras import backend as K

    class AttentionLogger(keras.callbacks.Callback):
        def __init__(self, val_data, logsdir):
            super(AttentionLogger, self).__init__()
            self.logsdir = logsdir  # where the event files will be written
            self.validation_data = val_data  # validation data generator
            self.writer = tf.summary.FileWriter(self.logsdir)  # creating the summary writer

        @tfmpl.figure_tensor
        def attention_matplotlib(self, gen_images):
            '''
            Creates a matplotlib figure and writes it to tensorboard using tf-matplotlib
            gen_images: The image tensor of shape (batchsize,width,height,channels) you want to write to tensorboard
            '''
            r, c = 5, 5  # write 25 images as a 5x5 matplotlib subplot in TensorBoard (needs at least 25 images)
            figs = tfmpl.create_figures(1, figsize=(15, 15))
            cnt = 0
            for idx, f in enumerate(figs):
                for i in range(r):
                    for j in range(c):
                        ax = f.add_subplot(r, c, cnt + 1)
                        ax.set_yticklabels([])
                        ax.set_xticklabels([])
                        ax.imshow(gen_images[cnt])  # draws the image at index cnt into the 5x5 grid
                        cnt += 1
                f.tight_layout()
            return figs

        def on_train_begin(self, logs=None):  # when the training begins (run only once)
            image_summary = []  # list of summaries needed (can be scalars, images, histograms, etc.)
            for index in range(len(self.model.output)):  # self.model is accessible within the callback
                img_sum = tf.summary.image('img{}'.format(index), self.attention_matplotlib(self.model.output[index]))
                image_summary.append(img_sum)
            self.total_summary = tf.summary.merge(image_summary)

        def on_epoch_end(self, epoch, logs=None):  # run this at the end of each epoch
            logs = logs or {}
            x, y = next(self.validation_data)  # get data from the generator
            # get the backend session and run the merged summary with the appropriate feed_dict
            sess_run_summary = K.get_session().run(self.total_summary, feed_dict={self.model.input: x['encoder_input']})
            self.writer.add_summary(sess_run_summary, global_step=epoch)  # finally write the summary!
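
If axes and colormaps are not needed, a lighter alternative to the matplotlib subplot grid is to tile the batch into one mosaic array and log that as a single image. The following is only a minimal NumPy sketch, assuming a batch laid out as (N, H, W, C); `tile_images` and `demo_batch` are hypothetical names for illustration and are not part of tf-matplotlib or Keras.

```python
import numpy as np


def tile_images(batch, rows, cols):
    """Tile the first rows*cols images of a (N, H, W, C) batch into one
    (rows*H, cols*W, C) mosaic, filled row by row."""
    batch = np.asarray(batch)
    n, h, w, ch = batch.shape
    if n < rows * cols:
        raise ValueError("need at least rows * cols images")
    grid = batch[: rows * cols].reshape(rows, cols, h, w, ch)
    # Reorder to (rows, H, cols, W, C) so that merging axes yields a 2-D layout.
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(rows * h, cols * w, ch)


# Hypothetical batch: 25 RGB images of 8x8 pixels, each filled with its own index.
demo_batch = np.stack([np.full((8, 8, 3), i, dtype=np.uint8) for i in range(25)])
mosaic = tile_images(demo_batch, rows=5, cols=5)
print(mosaic.shape)  # (40, 40, 3)
```

The mosaic can then be passed to an image summary with a leading batch axis of size 1 (for example `mosaic[None]`).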
    
    

    Then pass the callback to fit/fit_generator through the callbacks argument:

    # val_generator is the validation data generator
    callback_image = AttentionLogger(logsdir='./tensorboard', val_data=val_generator)
    ...  # define the model and generators

    # autoencoder is the model; note that callbacks must be a list
    autoencoder.fit_generator(generator=train_generator,
                              validation_data=val_generator,
                              callbacks=[callback_image])
    

    In my case, where I'm displaying attention maps (as heatmaps) in TensorBoard, this is the output:

    [image: tensorboard]
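
The callback above relies on the TensorFlow 1.x summary API (`tf.summary.FileWriter`, `K.get_session()`). Under TensorFlow 2.x the same idea can be sketched with `tf.summary.create_file_writer` and `tf.summary.image`. This is a minimal sketch and not the answerer's code: `ImageLogger`, the toy model and the random data are assumptions made for illustration, and it logs raw predictions rather than matplotlib figures.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ImageLogger(tf.keras.callbacks.Callback):
    """Hypothetical callback: logs model predictions as images after each epoch."""

    def __init__(self, val_images, logdir, max_images=4):
        super().__init__()
        self.val_images = val_images  # fixed batch of inputs, shape (N, H, W, C)
        self.max_images = max_images
        self.writer = tf.summary.create_file_writer(logdir)

    def on_epoch_end(self, epoch, logs=None):
        # Run the model on the fixed batch; for segmentation this is the predicted mask.
        preds = self.model.predict(self.val_images, verbose=0)
        # Float image summaries are expected to lie in [0, 1].
        preds = np.clip(preds, 0.0, 1.0).astype("float32")
        with self.writer.as_default():
            tf.summary.image("predictions", preds, step=epoch,
                             max_outputs=self.max_images)
        self.writer.flush()


# Toy stand-ins for a real segmentation model and dataset (assumptions for the demo).
x = np.random.rand(8, 16, 16, 3).astype("float32")
y = (x.mean(axis=-1, keepdims=True) > 0.5).astype("float32")
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16, 16, 3)),
    tf.keras.layers.Conv2D(1, 3, padding="same", activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

logdir = tempfile.mkdtemp()
model.fit(x, y, epochs=2, batch_size=4, verbose=0,
          callbacks=[ImageLogger(x[:4], logdir)])  # callbacks is a list

# Collect the event files that TensorBoard would read from logdir.
event_files = [f for f in os.listdir(logdir) if "tfevents" in f]
```

Pointing `tensorboard --logdir` at `logdir` should then show the predictions under the Images tab, one step per epoch.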
