TensorFlow: How to predict from a SavedModel?

南方客 2020-12-25 15:41

I have exported a SavedModel and now I wish to load it back in and make a prediction. It was trained with the following features and labels:

F1

  •  时光取名叫无心
    2020-12-25 16:05

    Assuming you want predictions in Python, SavedModelPredictor (in tf.contrib.predictor, TensorFlow 1.x only) is probably the easiest way to load a SavedModel and get predictions. Suppose you save your model like so:

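    import tensorflow as tf  # TensorFlow 1.x API

    export_dir = "/tmp/my_saved_model"  # example path; must not already exist

    # Build the graph: one scalar placeholder per feature/label
    f1 = tf.placeholder(shape=[], dtype=tf.float32)
    f2 = tf.placeholder(shape=[], dtype=tf.float32)
    f3 = tf.placeholder(shape=[], dtype=tf.float32)
    l1 = tf.placeholder(shape=[], dtype=tf.float32)
    output = build_graph(f1, f2, f3, l1)  # build_graph: your own model-building code

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... train the model here ...

        # Save the model; the dict keys become the signature's input/output names
        inputs = {'F1': f1, 'F2': f2, 'F3': f3, 'L1': l1}
        outputs = {'output': output}
        tf.saved_model.simple_save(sess, export_dir, inputs, outputs)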
    

    (The inputs can be tensors of any shape, and they don't have to be placeholders or root nodes in the graph.)
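
    The point about inputs can be checked with a minimal sketch, assuming TensorFlow 2.x and its tf.compat.v1 graph APIs (tf.contrib no longer exists there). The tiny model, the "features" input key, the variable w, and the temporary export path are all hypothetical. The signature's only input is an intermediate, batched tensor of shape [None, 3] rather than a placeholder, and the model is reloaded with the plain v1 loader by looking up tensor names in the signature:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (uses its tf.compat.v1 APIs)

# Hypothetical export path; SavedModelBuilder requires that it not exist yet
export_dir = os.path.join(tempfile.mkdtemp(), "model")

# --- Export: a TF1-style graph ---
graph = tf.Graph()
with graph.as_default():
    raw = tf.compat.v1.placeholder(shape=[None, 3], dtype=tf.float32)
    w = tf.compat.v1.Variable(10.0, name="w")
    # Intermediate (non-placeholder) tensor: one row of features per example
    scaled = tf.identity(raw * w, name="scaled")
    output = tf.reduce_sum(scaled, axis=1)  # one prediction per row
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(
            sess, export_dir,
            inputs={"features": scaled},  # not a placeholder
            outputs={"output": output})

# --- Predict: load into a fresh graph and feed via the signature's tensor names ---
with tf.Graph().as_default() as load_graph:
    with tf.compat.v1.Session(graph=load_graph) as sess:
        meta_graph = tf.compat.v1.saved_model.loader.load(
            sess, [tf.saved_model.SERVING], export_dir)
        sig = meta_graph.signature_def[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        batch = np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]], dtype=np.float32)
        preds = sess.run(sig.outputs["output"].name,
                         feed_dict={sig.inputs["features"].name: batch})

# Feeding "scaled" directly bypasses raw and the multiplication by w
print(preds)
```

    Because the fed tensor sits after the multiplication, each prediction is just the sum of its row. SavedModelPredictor does roughly this signature-to-tensor lookup for you.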

    Then, in the Python program that will use the SavedModel, we can get predictions like so:

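    from tensorflow.contrib import predictor  # TensorFlow 1.x only

    export_dir = "/tmp/my_saved_model"  # the directory passed to simple_save
    # Load the SavedModel's default serving signature as a callable
    predict_fn = predictor.from_saved_model(export_dir)
    # Keys must match the input names used when saving; returns a dict keyed by output name
    predictions = predict_fn(
        {"F1": 1.0, "F2": 2.0, "F3": 3.0, "L1": 4.0})
    print(predictions)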
    
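    tf.contrib (and with it predictor) was removed in TensorFlow 2.x. As a sketch of one alternative, assuming TensorFlow 2.x, tf.saved_model.load can also load a SavedModel exported by simple_save and exposes its signature as a callable that takes the input keys as keyword arguments. The toy model below (output = w * (F1 + F2 + F3) + L1 with w = 0.5) and the temporary export path are hypothetical stand-ins for the trained model:

```python
import os
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical export path; must not exist yet
export_dir = os.path.join(tempfile.mkdtemp(), "model")

# Export a toy TF1-style model with the same input keys as above.
# Hypothetical stand-in for build_graph: output = w * (F1 + F2 + F3) + L1
graph = tf.Graph()
with graph.as_default():
    f1, f2, f3, l1 = [tf.compat.v1.placeholder(shape=[], dtype=tf.float32)
                      for _ in range(4)]
    w = tf.compat.v1.Variable(0.5, name="w")
    output = w * (f1 + f2 + f3) + l1
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(
            sess, export_dir,
            inputs={"F1": f1, "F2": f2, "F3": f3, "L1": l1},
            outputs={"output": output})

# Load it eagerly; simple_save registers the signature as "serving_default"
loaded = tf.saved_model.load(export_dir)
predict_fn = loaded.signatures["serving_default"]
# Signature functions take keyword arguments named after the input keys
predictions = predict_fn(F1=tf.constant(1.0), F2=tf.constant(2.0),
                         F3=tf.constant(3.0), L1=tf.constant(4.0))
print(predictions["output"].numpy())
```

    The call returns a dict of tensors keyed by the output names given to simple_save, much like predict_fn in the TF 1.x version.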

    This answer shows how to get predictions in Java, C++, and Python. Although that question focuses on Estimators, the answer applies regardless of how the SavedModel was created.
