How to keep tensorflow session open between predictions? Loading from SavedModel

庸人自扰 2021-02-05 18:00

I trained a TensorFlow model that I'd like to run predictions on from numpy arrays. This is for image processing within videos. I will pass the images to the model as they happen...

3 Answers
  •  既然无缘 2021-02-05 18:22

    Others have explained why you can't put your session in a with statement in the constructor.

    The reason you see different behavior with the context manager vs. without it is that tf.saved_model.loader.load interacts in surprising ways with the default graph and the graph that belongs to the session.
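To see concretely why a with statement in the constructor fails, here is a minimal plain-Python sketch; FakeSession is a hypothetical stand-in for tf.Session (whose __exit__ likewise closes the session when the block ends):

```python
class FakeSession:
    """Hypothetical stand-in for tf.Session; its __exit__ closes it,
    just as tf.Session.__exit__ does."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # tf.Session.__exit__ closes the session here.
        self.closed = True

    def run(self, *args):
        if self.closed:
            raise RuntimeError("Attempted to use a closed Session.")
        return "prediction"


class BrokenModel:
    def __init__(self):
        # The session closes as soon as this with block ends,
        # i.e. before __init__ even returns.
        with FakeSession() as sess:
            self.session = sess


model = BrokenModel()
# model.session.run() now raises RuntimeError: the session was closed
# inside __init__, so later predict-style calls can never succeed.
```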

    The solution is simple: don't pass a graph to the Session constructor if you're not using the session in a with block:

    sess = tf.Session()
    tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.SERVING], "model")
    

    Here's some example code for a class to do predictions:

    class Model(object):
    
      def __init__(self, model_path):
        # Note, if you don't want to leak this, you'll want to turn Model into
        # a context manager. In practice, you probably don't have to worry
        # about it.
        self.session = tf.Session()
    
        tf.saved_model.loader.load(
            self.session,
            [tf.saved_model.tag_constants.SERVING],
            model_path)
    
        self.softmax_tensor = self.session.graph.get_tensor_by_name('final_ops/softmax:0')
    
      def predict(self, images):
        predictions = self.session.run(self.softmax_tensor, {'Placeholder:0': images})
        # TODO: convert to human-friendly labels
        return predictions
    
    
    images = [tf.gfile.FastGFile(f, 'rb').read() for f in glob.glob("*.jpg")]
    model = Model('model_path')
    print(model.predict(images))
    
    # Alternatively (uses less memory, but has lower throughput):
    for f in glob.glob("*.jpg"):
      print(model.predict([tf.gfile.FastGFile(f, 'rb').read()]))
    
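If you do want the session closed deterministically rather than leaked, here is a sketch of the context-manager variant hinted at in the comment above. It assumes TensorFlow 1.x; the tensor names ('final_ops/softmax:0', 'Placeholder:0') and the 'model_path' directory come from the answer's code and depend on your exported SavedModel:

```python
import glob

import tensorflow as tf  # assumes TensorFlow 1.x


class Model(object):

  def __init__(self, model_path):
    self.session = tf.Session()
    tf.saved_model.loader.load(
        self.session,
        [tf.saved_model.tag_constants.SERVING],
        model_path)
    self.softmax_tensor = self.session.graph.get_tensor_by_name(
        'final_ops/softmax:0')

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Close the session on leaving the with block, releasing its resources.
    self.session.close()

  def predict(self, images):
    return self.session.run(self.softmax_tensor, {'Placeholder:0': images})


with Model('model_path') as model:
  for f in glob.glob("*.jpg"):
    print(model.predict([tf.gfile.FastGFile(f, 'rb').read()]))
# The session is closed here; predictions after this point would fail.
```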
