Python/Keras - How to access each epoch prediction?

灰色年华 2020-12-31 03:30

I'm using Keras to predict a time series, training for 20 epochs as standard. I want to check whether my model is learning well by looking at its predictions after each of the 20 epochs.

2 Answers
  •  陌清茗 (OP)
    2020-12-31 04:08

    A custom callback that calls self.model.predict on the validation inputs in its on_epoch_end method does the job:

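    import tensorflow as tf

    # define a custom callback that predicts on the validation inputs after each epoch
    class PredictionCallback(tf.keras.callbacks.Callback):
      def __init__(self, x_valid):
        super().__init__()
        # recent tf.keras versions do not populate self.validation_data for callbacks,
        # so the validation inputs are passed in explicitly
        self.x_valid = x_valid

      def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.x_valid, verbose=0)
        # `epoch` is 0-based; +1 matches the "Epoch n/20" numbering in the fit log
        print('prediction: {} at epoch: {}'.format(y_pred, epoch + 1))

    # ... build and compile `model` here ...

    # register the callback before training starts
    model.fit(X_train, y_train, batch_size=32, epochs=20,
              validation_data=(X_valid, y_valid),
              callbacks=[PredictionCallback(X_valid)])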
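To judge whether the model is actually improving, it can be more useful to keep each epoch's predictions than to print them. The sketch below assumes TensorFlow 2.x; the class name PredictionHistory, the sine-wave toy series, the window size of 5 and the layer sizes are hypothetical choices made only to get a runnable example, not part of the original answer.

```python
import numpy as np
import tensorflow as tf


class PredictionHistory(tf.keras.callbacks.Callback):
    """Hypothetical callback: stores predictions on x_valid after every epoch."""

    def __init__(self, x_valid):
        super().__init__()
        self.x_valid = x_valid
        self.predictions = []  # one array per epoch, in training order

    def on_epoch_end(self, epoch, logs=None):
        self.predictions.append(self.model.predict(self.x_valid, verbose=0))


# Hypothetical toy time series: predict the next value from the previous 5.
series = np.sin(np.linspace(0, 20, 300)).astype("float32")
window = 5
X = np.stack([series[i:i + window] for i in range(len(series) - window)])
y = series[window:].reshape(-1, 1)
X_train, y_train = X[:200], y[:200]
X_valid, y_valid = X[200:], y[200:]

# Hypothetical small model, just large enough to show the mechanism.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(window,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

history_cb = PredictionHistory(X_valid)
model.fit(X_train, y_train, batch_size=32, epochs=20,
          validation_data=(X_valid, y_valid),
          callbacks=[history_cb], verbose=0)

# Per-epoch validation error computed from the stored predictions.
for epoch, y_pred in enumerate(history_cb.predictions, start=1):
    mse = float(np.mean((y_pred - y_valid) ** 2))
    print(f"epoch {epoch}: val MSE = {mse:.4f}")
```

Comparing history_cb.predictions[0] with history_cb.predictions[-1] (or plotting them against y_valid) shows how the forecasts change as training progresses. If only the loss curve is needed, the History object returned by model.fit already records val_loss per epoch when validation_data is given.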
    
