Tensorflow: how to save/restore a model?

迷失自我 2020-11-21 11:37

After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
26 answers
  •  借酒劲吻你
    2020-11-21 12:15

    My environment: Python 3.6, Tensorflow 1.3.0

    Many solutions have been posted, but most of them are based on tf.train.Saver. When you load a .ckpt saved by Saver, you have to either redefine the TensorFlow network or look up tensors by awkward, hard-to-remember names, e.g. 'placehold_0:0', 'dense/Adam/Weight:0'. I recommend using tf.saved_model instead. The simplest example is given below; you can learn more from Serving a TensorFlow Model:

    Save the model:

    import tensorflow as tf
    
    # define the TensorFlow graph (training steps are omitted here)
    x = tf.placeholder("float", name="x")
    w = tf.Variable(2.0, name="w")
    b = tf.Variable(0.0, name="bias")
    
    h = tf.multiply(x, w)
    y = tf.add(h, b, name="y")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    # save the model
    export_path =  './savedmodel'
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    
    prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'x_input': tensor_info_x},
          outputs={'y_output': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature,
        })
    builder.save()
    

    Load the model:

    import tensorflow as tf
    sess = tf.Session()
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    input_key = 'x_input'
    output_key = 'y_output'
    
    export_path = './savedmodel'
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        export_path)
    signature = meta_graph_def.signature_def
    
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name
    
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    
    y_out = sess.run(y, {x: 3.0})
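
For readers on TensorFlow 2.x, where tf.placeholder and tf.Session are no longer top-level symbols, the same save/load round trip can be sketched through the tf.compat.v1 module. This is a minimal sketch and not part of the original answer. It assumes TensorFlow 2.x is installed. The helper names save_linear_model and load_and_predict are made up for illustration. The deprecated tf.compat.v1.saved_model.simple_save stands in for the explicit SavedModelBuilder calls and stores the inputs and outputs under the default serving signature key.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API surface inside TensorFlow 2.x


def save_linear_model(export_dir):
    """Build y = x * w + b in a fresh graph and export it as a SavedModel.

    Hypothetical helper; export_dir must not exist yet.
    """
    graph = tf.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        x = tf1.placeholder(tf.float32, name="x")
        w = tf1.Variable(2.0, name="w")
        b = tf1.Variable(0.0, name="bias")
        y = tf.add(tf.multiply(x, w), b, name="y")
        sess.run(tf1.global_variables_initializer())
        # simple_save tags the graph as SERVING and registers the
        # signature under DEFAULT_SERVING_SIGNATURE_DEF_KEY.
        tf1.saved_model.simple_save(
            sess, export_dir,
            inputs={"x_input": x},
            outputs={"y_output": y})


def load_and_predict(export_dir, value):
    """Load the SavedModel into a new graph and evaluate y for one x value.

    Hypothetical helper; tensors are found through the signature,
    not through hard-coded tensor names.
    """
    with tf1.Session(graph=tf.Graph()) as sess:
        meta_graph_def = tf1.saved_model.loader.load(
            sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
        signature = meta_graph_def.signature_def[
            tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        x = sess.graph.get_tensor_by_name(signature.inputs["x_input"].name)
        y = sess.graph.get_tensor_by_name(signature.outputs["y_output"].name)
        return float(sess.run(y, {x: value}))


# Export into a fresh temporary location, then read the model back.
export_path = os.path.join(tempfile.mkdtemp(), "savedmodel")
save_linear_model(export_path)
y_out = load_and_predict(export_path, 3.0)
print(y_out)  # 6.0, since w = 2.0 and b = 0.0
```

Building the graph inside an explicit tf.Graph() keeps the sketch independent of whether eager execution is enabled globally, and loading into a second, empty graph shows that nothing from the defining code is needed at restore time.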
    
