Keras (Tensorflow backend) Error - Tensor input_1:0, specified in either feed_devices or fetch_devices was not found in the Graph

长情又很酷 2020-12-15 22:59

When I try to predict with a simple model I've previously trained, I get the following error:

Tensor input_1:0, specified in either feed_devices or fetch_devices was not found in the Graph

3 answers
  • 2020-12-15 23:27

    This is a very common issue when deploying multiple models, especially in Flask apps. The best way to deal with it is to create a session and save the graph before loading any Keras model. This helps in particular if you are trying to use pickled models to predict labels.

    Steps:

    • Create the session and save both the session and the graph before loading the model.
    • In the other thread (or Flask request handler), restore the saved session and graph, then call the model's predict function.

    Sample Full Code:

    Main Class

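    import pickle
    import tensorflow as tf  # TensorFlow 1.x APIs (tf.Session, tf.get_default_graph)
    from tensorflow.python.keras.backend import set_session
    
    # your other file/class (assumes it lives in UserDefinedClass.py)
    from UserDefinedClass import UserDefinedClass
    
    class Main(object):
    
        def __init__(self, model1_path, model2_path):
            self.model1_path = model1_path
            self.model2_path = model2_path
            self.userClassOBJ = None
    
        def load_models(self):
    
            # Loading a generic model
            with open(self.model1_path, "rb") as f:
                model1 = pickle.load(f)
    
            # Loading a Keras model: create the session and capture the graph first
            session = tf.Session()
            graph = tf.get_default_graph()
            set_session(session)
            with open(self.model2_path, "rb") as f:
                model2 = pickle.load(f)
    
            # Pass 'session' and 'graph' to the other class so it can restore them later
            self.userClassOBJ = UserDefinedClass(session, graph, model1, model2)
    
        def run(self, X):
            # X is input
            GenericLabels, KerasLabels = self.userClassOBJ.SomeFunction(X)
            return GenericLabels, KerasLabels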
    

    Some other file/class, used from a different thread or a Flask request handler:

    from tensorflow.python.keras.backend import set_session
    
    class UserDefinedClass(object):
    
        def __init__(self, session, graph, model1, model2):
            self.session = session
            self.graph = graph
            self.Generic_model = model1
            self.Keras_model = model2
    
        def SomeFunction(self, X):
    
            # Generic model prediction
            Generic_labels = self.Generic_model.predict(X)
            print("Generic model prediction done!!")
    
            # Keras model prediction: restore the graph and session captured at load time
            with self.graph.as_default():
                set_session(self.session)
                Keras_labels = self.Keras_model.predict(X, verbose=0)
                print("Keras model prediction done!!")
            return Generic_labels, Keras_labels
    
  • 2020-12-15 23:29

    OK, after a lot of pain and suffering and diving into the bowels of TensorFlow, I found the following:

    Although the model has its own Session and Graph, some TensorFlow methods use the default Session and Graph instead. To fix this I had to explicitly make both my Session and my Graph the defaults:

    with session.as_default():
        with session.graph.as_default():
            # Keras calls such as model.predict(...) go here
            pass
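    If you need those two nested `with` statements in several places, they can be folded into one context manager. This is a minimal sketch: `session_scope` is a hypothetical helper name, and it relies only on `session.as_default()` and `session.graph.as_default()`, which TF 1.x `tf.Session` and `tf.Graph` objects provide. The fake classes exist only so the sketch runs without TensorFlow installed.

```python
from contextlib import contextmanager


@contextmanager
def session_scope(session):
    """Make `session` and its graph the defaults for the current thread.

    Same effect as the two nested `with` statements:
        with session.as_default():
            with session.graph.as_default():
                ...
    """
    with session.as_default():
        with session.graph.as_default():
            yield session


# --- Hypothetical fakes that record enter/exit order (not TensorFlow) ---
events = []


class _FakeCtx:
    def __init__(self, label):
        self.label = label

    def __enter__(self):
        events.append("enter " + self.label)
        return self

    def __exit__(self, *exc):
        events.append("exit " + self.label)
        return False  # never swallow exceptions


class FakeGraph:
    def as_default(self):
        return _FakeCtx("graph")


class FakeSession:
    def __init__(self):
        self.graph = FakeGraph()

    def as_default(self):
        return _FakeCtx("session")


with session_scope(FakeSession()):
    events.append("predict")

print(events)
# ['enter session', 'enter graph', 'predict', 'exit graph', 'exit session']
```

    With a real session, `predict_seatbelt` could use a single `with session_scope(session):` in place of the two nested `with` lines; the graph is still exited before the session, as in the nested form.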
    

    Full Code:

    from tensorflow import keras
    import tensorflow as tf  # TensorFlow 1.x APIs (tf.ConfigProto, tf.Session)
    import numpy as np
    import logging
    
    logger = logging.getLogger(__name__)
    
    config = tf.ConfigProto(
        device_count={'GPU': 1},
        intra_op_parallelism_threads=1,
        allow_soft_placement=True
    )
    
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.6
    
    session = tf.Session(config=config)
    
    keras.backend.set_session(session)
    
    seatbelt_model = keras.models.load_model(filepath='./seatbelt.h5')
    
    SEATBELT_INPUT_SHAPE = (-1, 120, 160, 1)
    
    def predict_seatbelt(image_arr):
        try:
            # Make this session and its graph the defaults for the calling thread
            with session.as_default():
                with session.graph.as_default():
                    image_arr = np.array(image_arr).reshape(SEATBELT_INPUT_SHAPE)
                    predicted_labels = seatbelt_model.predict(image_arr, verbose=1)
                    return predicted_labels
        except Exception:
            # Logs the message plus the full traceback; the function then returns None
            logger.exception('Seatbelt Prediction Error')
    
  • 2020-12-15 23:37

    I faced the same issue. I was on TensorFlow 1.0, so I upgraded to the latest version (2.1), and after that my code worked perfectly.
