TensorFlow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  后悔当初
    2020-11-21 12:04

    Following @Vishnuvardhan Janapati's answer, here is another way to save and reload a model that has a custom layer, metric, or loss in TensorFlow 2.0.0:

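    import tensorflow as tf
    from tensorflow.keras.layers import Layer
    from tensorflow.keras.utils import get_custom_objects
    
    # custom loss (for example)
    def custom_loss(y_true, y_pred):
        return tf.reduce_mean(y_true - y_pred)
    get_custom_objects().update({'custom_loss': custom_loss})
    
    # custom layer (for example): scales its input by a trainable scalar
    class CustomLayer(Layer):
        def __init__(self, init_value=1.0, **kwargs):
            super().__init__(**kwargs)
            self.init_value = init_value
    
        def build(self, input_shape):
            self.scale = self.add_weight(
                name='scale', shape=(),
                initializer=tf.keras.initializers.Constant(self.init_value))
    
        def call(self, inputs):
            return inputs * self.scale
    
        # store constructor arguments so load_model can rebuild the layer
        def get_config(self):
            config = super().get_config()
            config.update({'init_value': self.init_value})
            return config
    
    get_custom_objects().update({'CustomLayer': CustomLayer})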
    

    Once this code has run, and you have saved your model with tf.keras.models.save_model, model.save, or a ModelCheckpoint callback, you can reload it without passing a custom_objects argument. The registration code must also run in the process that does the loading. Then reloading is as simple as:

    new_model = tf.keras.models.load_model("./model.h5")
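    For an end-to-end check, here is a minimal round-trip sketch. It assumes TensorFlow 2.x with h5py available for the .h5 format; CustomLayer is a hypothetical scalar-scaling layer standing in for your own, and the temporary file path is illustrative. It trains for one step, saves to HDF5, reloads with no custom_objects argument, and compares predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import get_custom_objects


def custom_loss(y_true, y_pred):
    # Toy loss from the answer; only here to show it survives save/load.
    return tf.reduce_mean(y_true - y_pred)


class CustomLayer(Layer):
    """Hypothetical example layer: multiplies its input by a trainable scalar."""

    def __init__(self, init_value=1.0, **kwargs):
        super().__init__(**kwargs)
        self.init_value = init_value

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale",
            shape=(),
            initializer=tf.keras.initializers.Constant(self.init_value),
            trainable=True,
        )

    def call(self, inputs):
        return inputs * self.scale

    def get_config(self):
        # Constructor args must be in the config so load_model can rebuild the layer.
        config = super().get_config()
        config.update({"init_value": self.init_value})
        return config


# Register both objects under the names Keras writes into the saved file.
get_custom_objects().update({"custom_loss": custom_loss, "CustomLayer": CustomLayer})

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    CustomLayer(init_value=2.0),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss=custom_loss)

x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0)  # one training step, so weights move off their initial values

path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)  # the .h5 suffix selects the HDF5 format

# No custom_objects argument: the global registry resolves both names.
new_model = tf.keras.models.load_model(path)
```

    If you would rather not register objects globally, load_model also accepts them explicitly, e.g. tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer, 'custom_loss': custom_loss}).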
    
