Keras Tensorflow - Exception while predicting from multiple threads

小蘑菇 2020-12-05 15:46

I am using Keras 2.0.8 with the TensorFlow 1.3.0 backend.

I load a model in the class's __init__ and then use it to make predictions from multiple threads.



        
1 Answer
  • 2020-12-05 16:03

    Make sure graph construction is finished before you create the other threads.

    Calling finalize() on the graph helps enforce that.

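    import tensorflow as tf
    from keras import backend as K
    from keras.models import load_model

    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize()  # make the graph read-only before starting threads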
    

    Update 1: finalize() makes the graph read-only, so it can be used safely from multiple threads. As a side effect, it helps you catch unintended graph modifications (and sometimes memory leaks), because any later attempt to add an operation to the graph raises an exception.
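As a plain-Python illustration of what "read-only" means here, the sketch below uses a hypothetical ToyGraph class (not TensorFlow's API) that mimics the rule TF 1.x applies after finalize(): any attempt to add a node raises RuntimeError. It needs only the standard library.

```python
class ToyGraph:
    """Hypothetical stand-in for a computation graph (not TensorFlow's API)."""

    def __init__(self):
        self.nodes = []
        self.finalized = False

    def add_node(self, name):
        # Same idea as tf.Graph in TF 1.x: modifying a finalized graph is an error.
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.nodes.append(name)

    def finalize(self):
        # From now on the graph can only be read, never extended.
        self.finalized = True


graph = ToyGraph()
graph.add_node("dense/kernel")  # build everything up front...
graph.add_node("dense/bias")
graph.finalize()                # ...then freeze it

try:
    graph.add_node("one_hot")   # an unintended modification later on
except RuntimeError as err:
    print("caught:", err)       # caught: Graph is finalized and cannot be modified.
```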

    Imagine, for instance, a thread that one-hot encodes your inputs (bad example):

    def preprocessing(self, data):
        # creates a new one_hot op on every call, so the graph grows with each request
        one_hot_data = tf.one_hot(data, depth=self.num_classes)
        return self.session.run(one_hot_data)
    

    If you print the number of nodes in the graph, you will notice that it keeps growing over time:

    # number of nodes in the TF graph
    print(len(tf.get_default_graph().as_graph_def().node))
    

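    If you create the operation once, up front, the graph no longer grows (slightly better code):

    # e.g. in __init__, before self.graph.finalize():
    #     self.input = tf.placeholder(tf.int32, shape=[None])
    #     self.one_hot_data = tf.one_hot(self.input, depth=self.num_classes)
    def preprocessing(self, data):
        # run the pre-created operation, feeding data through the self.input placeholder
        return self.session.run(self.one_hot_data, feed_dict={self.input: data})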
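The difference between the two versions can be sketched without TensorFlow, assuming a hypothetical ToyGraph that only records node names: the hypothetical preprocess_bad adds a node on every call, like calling tf.one_hot inside preprocessing, while the hypothetical Preprocessor builds its node once in __init__ and reuses it.

```python
class ToyGraph:
    """Hypothetical stand-in for a graph that only records node names."""

    def __init__(self):
        self.nodes = []

    def add_node(self, name):
        self.nodes.append(name)
        return name


def preprocess_bad(graph, data):
    # Builds a new "one_hot" node on every call, so the graph grows per request.
    node = graph.add_node("one_hot_%d" % len(graph.nodes))
    return node, data


class Preprocessor:
    def __init__(self, graph):
        # Build the node once, up front (like creating the placeholder and the
        # one_hot op in __init__ before finalize()).
        self.graph = graph
        self.one_hot = graph.add_node("one_hot")

    def preprocess(self, data):
        # Reuses the pre-built node; the graph size stays constant.
        return self.one_hot, data


bad_graph = ToyGraph()
for batch in range(100):
    preprocess_bad(bad_graph, batch)

good_graph = ToyGraph()
pre = Preprocessor(good_graph)
for batch in range(100):
    pre.preprocess(batch)

print(len(bad_graph.nodes), len(good_graph.nodes))  # 100 1
```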
    

    Update 2: According to this thread, you need to call model._make_predict_function() on a Keras model before starting the threads.

    Keras builds its predict function lazily, the first time you call predict(). That way, if you never call predict(), you save some time and resources; the trade-off is that the first call is slower than the ones that follow. It also means that if several threads make their first predict() call at the same time, they may all try to build that function, and add it to the graph, at once.
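The sketch below mimics that lazy build with a hypothetical LazyModel (not Keras code) whose predict() builds its function on first use through an unguarded check-then-act. Started from several threads without a warm-up, more than one thread usually ends up doing the build; a single warm-up call on the main thread guarantees exactly one build.

```python
import threading
import time


class LazyModel:
    """Hypothetical model that builds its predict function on first use."""

    def __init__(self):
        self.predict_function = None
        self.build_count = 0

    def _make_predict_function(self):
        # Unsafe check-then-act: two threads can both see None and both build.
        if self.predict_function is None:
            time.sleep(0.01)  # simulate the (slow) build step
            self.build_count += 1
            self.predict_function = lambda xs: [2 * x for x in xs]

    def predict(self, xs):
        self._make_predict_function()
        return self.predict_function(xs)


def run_in_threads(model, n_threads=5):
    threads = [threading.Thread(target=model.predict, args=([1, 2, 3],))
               for _ in range(n_threads)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()


cold = LazyModel()
run_in_threads(cold)
print("builds without warm-up:", cold.build_count)  # usually more than 1

warm = LazyModel()
warm.predict([0])  # warm-up on the main thread, before any worker starts
run_in_threads(warm)
print("builds with warm-up:", warm.build_count)  # always 1
```

A lock around the build would also make it run once, but in the TensorFlow case the build still modifies the graph, so building eagerly before finalize() is the simpler fix.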

    The updated code:

    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.cnn_model._make_predict_function() # have to initialize before threading
        self.session = K.get_session()
        self.graph = tf.get_default_graph() 
        self.graph.finalize() # make graph read-only
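Why the order in this __init__ matters can be shown with the same kind of toy, assuming a hypothetical ToyGraph and ToyModel (neither is TensorFlow or Keras API): building the predict function first and then finalizing works, while finalizing first makes the lazy build inside the first predict() call raise RuntimeError.

```python
class ToyGraph:
    """Hypothetical finalizable graph (not TensorFlow's API)."""

    def __init__(self):
        self.nodes = []
        self.finalized = False

    def add_node(self, name):
        if self.finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.nodes.append(name)

    def finalize(self):
        self.finalized = True


class ToyModel:
    """Hypothetical model that adds its predict op to the graph on first use."""

    def __init__(self, graph):
        self.graph = graph
        self.predict_function = None

    def _make_predict_function(self):
        if self.predict_function is None:
            self.graph.add_node("predict_function")  # this modifies the graph
            self.predict_function = lambda xs: [x + 1 for x in xs]

    def predict(self, xs):
        self._make_predict_function()
        return self.predict_function(xs)


# Order used in the updated __init__: build first, then finalize.
good = ToyModel(ToyGraph())
good._make_predict_function()
good.graph.finalize()
print(good.predict([1, 2]))  # [2, 3]

# Finalizing first: the lazy build in the first predict() call fails.
bad = ToyModel(ToyGraph())
bad.graph.finalize()
try:
    bad.predict([1, 2])
except RuntimeError as err:
    print("caught:", err)
```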
    

    Update 3: Because _make_predict_function() doesn't seem to work as expected, I built a proof of concept that uses a warm-up predict() call instead. First, I created a dummy model:

    from keras.layers import Dense
    from keras.models import Sequential

    model = Sequential()
    model.add(Dense(256, input_shape=(2,)))
    # softmax over a single unit always outputs 1.0, so use sigmoid for a non-constant output
    model.add(Dense(1, activation='sigmoid'))

    model.compile(loss='mean_squared_error', optimizer='adam')

    model.save("dummymodel")
    

    Then, in a separate script, I loaded that model and ran predictions on it from multiple threads:

    import threading

    import numpy as np
    import tensorflow as tf
    from keras import backend as K
    from keras.models import load_model

    K.clear_session()

    class CNN:
        def __init__(self, model_path):
            self.cnn_model = load_model(model_path)
            self.cnn_model.predict(np.array([[0, 0]]))  # warm-up: builds the predict function
            self.session = K.get_session()
            self.graph = tf.get_default_graph()
            self.graph.finalize()  # make the graph read-only

        def preprocessing(self, data):
            # dummy: no preprocessing
            return data

        def query_cnn(self, data):
            X = self.preprocessing(data)
            # make the shared session and graph the defaults in this worker thread
            with self.session.as_default():
                with self.graph.as_default():
                    prediction = self.cnn_model.predict(X)
            print(prediction)
            return prediction


    cnn = CNN("dummymodel")

    # start five workers, each predicting on its own random batch of 500 samples
    threads = [threading.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
               for _ in range(5)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    

    With the warm-up and finalize() lines commented out, I was able to reproduce your original issue.
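Since threading.Thread discards query_cnn's return value, and an exception raised inside a worker thread is not re-raised by join(), the five workers could alternatively be run with concurrent.futures.ThreadPoolExecutor. The sketch below assumes a hypothetical query_stub in place of cnn.query_cnn and plain lists in place of NumPy batches so it runs without Keras; with the real class you would pass cnn.query_cnn and the NumPy arrays instead.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def query_stub(data):
    """Hypothetical stand-in for cnn.query_cnn: one 'prediction' per input row."""
    return [sum(row) for row in data]


# five batches of 500 two-feature samples, like the random batches in the script
batches = [[[random.random(), random.random()] for _ in range(500)] for _ in range(5)]

with ThreadPoolExecutor(max_workers=5) as pool:
    # map() yields results in input order and re-raises a worker's exception
    # in the main thread when that result is retrieved.
    results = list(pool.map(query_stub, batches))

print(len(results), len(results[0]))  # 5 500
```

Collecting the predictions this way also removes the need to print them from inside the workers.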
