TensorFlow: How to predict from a SavedModel?

Asked by 南方客 on 2020-12-25 15:41

I have exported a SavedModel and now I wish to load it back in and make a prediction. It was trained with the following features and labels:

F1          


        
4 Answers
  • 2020-12-25 15:45

    For anyone who needs a working example of saving a trained canned model and serving it without TensorFlow Serving, I have documented it here: https://github.com/tettusud/tensorflow-examples/tree/master/estimators

    1. You can create a predictor with tf.contrib.predictor.from_saved_model(exported_model_path)
    2. Prepare the input:

      tf.train.Example( 
          features= tf.train.Features(
              feature={
                  'x': tf.train.Feature(
                       float_list=tf.train.FloatList(value=[6.4, 3.2, 4.5, 1.5])
                  )     
              }
          )    
      )
      

    Here 'x' is the name of the input that was given in the input_receiver_function at the time of exporting, e.g.:

    feature_spec = {'x': tf.FixedLenFeature([4],tf.float32)}
    
    def serving_input_receiver_fn():
        serialized_tf_example = tf.placeholder(dtype=tf.string,
                                               shape=[None],
                                               name='input_tensors')
        receiver_tensors = {'inputs': serialized_tf_example}
        features = tf.parse_example(serialized_tf_example, feature_spec)
        return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
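
    Putting the two pieces together: the receiver above parses serialized tf.train.Example protos fed under the 'inputs' key, so the proto has to be serialized to bytes before it is handed to the predictor. A minimal sketch (the export path is assumed, and tf.contrib.predictor exists only in TensorFlow 1.x, so that part is left commented):

```python
import tensorflow as tf

# Build the same Example as above and serialize it to a bytes string
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'x': tf.train.Feature(
                float_list=tf.train.FloatList(value=[6.4, 3.2, 4.5, 1.5])
            )
        }
    )
)
serialized = example.SerializeToString()

# TF 1.x only -- feed the serialized bytes under the receiver's 'inputs' key:
# predict_fn = tf.contrib.predictor.from_saved_model(exported_model_path)
# predictions = predict_fn({'inputs': [serialized]})
```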
    
  • 2020-12-25 15:57

    The constructor of tf.estimator.DNNClassifier has an argument called warm_start_from. You can give it the SavedModel folder name and it will restore the trained weights into the new estimator.

  • 2020-12-25 16:05

    Assuming you want predictions in Python, SavedModelPredictor is probably the easiest way to load a SavedModel and get predictions. Suppose you save your model like so:

    # Build the graph
    f1 = tf.placeholder(shape=[], dtype=tf.float32)
    f2 = tf.placeholder(shape=[], dtype=tf.float32)
    f3 = tf.placeholder(shape=[], dtype=tf.float32)
    l1 = tf.placeholder(shape=[], dtype=tf.float32)
    output = build_graph(f1, f2, f3, l1)

    # Save the model (sess is the tf.Session holding the trained variables)
    inputs = {'F1': f1, 'F2': f2, 'F3': f3, 'L1': l1}
    outputs = {'output': output}
    tf.saved_model.simple_save(sess, export_dir, inputs, outputs)
    

    (The inputs can be any shape and don't even have to be placeholders or root nodes in the graph.)

    Then, in the Python program that will use the SavedModel, we can get predictions like so:

    from tensorflow.contrib import predictor
    
    predict_fn = predictor.from_saved_model(export_dir)
    predictions = predict_fn(
        {"F1": 1.0, "F2": 2.0, "F3": 3.0, "L1": 4.0})
    print(predictions)
    

    This answer shows how to get predictions in Java, C++, and Python (despite the fact that the question is focused on Estimators, the answer actually applies independently of how the SavedModel is created).
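
    For readers on TensorFlow 2.x, where tf.contrib (and with it predictor.from_saved_model) was removed, the same idea works through tf.saved_model.load and the model's serving signature. A self-contained sketch with a toy model (the class, input name F1, and path are illustrative, not from the question):

```python
import os
import tempfile

import tensorflow as tf

class Toy(tf.Module):
    # Stand-in for a real model: one serving function with a named input
    @tf.function(input_signature=[tf.TensorSpec([], tf.float32, name='F1')])
    def serve(self, F1):
        return {'output': F1 * 2.0}

export_dir = os.path.join(tempfile.mkdtemp(), 'model')
model = Toy()
tf.saved_model.save(model, export_dir,
                    signatures={'serving_default': model.serve})

# Reload and call the serving signature by its input name
loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures['serving_default']
print(infer(F1=tf.constant(3.0))['output'].numpy())  # 6.0
```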

  • 2020-12-25 16:06

    Once the graph is loaded, it is available in the current context and you can feed input data through it to obtain predictions. Each use-case is rather different, but the addition to your code will look something like this:

    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            "/job/export/Servo/1503723455"
        )
    
        prediction = sess.run(
            'prefix/predictions/Identity:0',
            feed_dict={
                'Placeholder:0': [20.9],
                'Placeholder_1:0': [1.8],
                'Placeholder_2:0': [0.9]
            }
        )
    
        print(prediction)
    

    Here, you need to know the names of what your prediction inputs will be. If you did not give them a name in your serving_fn, then they default to Placeholder_n, where n is the nth feature.

    The first string argument of sess.run is the name of the prediction target. This will vary based on your use case.
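
    If you don't know the tensor names, you can read them out of the SavedModel's signature instead of guessing at Placeholder_n. A self-contained round trip using the tf.compat.v1 spellings of the same APIs (the toy graph and scratch path are illustrative):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

export_dir = os.path.join(tempfile.mkdtemp(), 'model')  # illustrative path

# Export a toy graph so the round trip is self-contained
with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.float32, shape=[], name='x')
    w = tf.Variable(2.0, name='w')
    y = tf.identity(x * w, name='y')
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(sess, export_dir,
                               inputs={'x': x}, outputs={'y': y})

# Reload and read the tensor names from the signature instead of guessing
with tf.Session(graph=tf.Graph()) as sess:
    meta = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    sig = meta.signature_def['serving_default']
    in_name = sig.inputs['x'].name    # e.g. 'x:0'
    out_name = sig.outputs['y'].name  # e.g. 'y:0'
    print(sess.run(out_name, feed_dict={in_name: 10.0}))  # prints 20.0
```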
