Question
I'm running a web server using Flask, and the error comes up when I try to use vgg16, a global variable holding Keras' pre-trained VGG16 model. I have no idea why this error is raised or whether it has anything to do with the TensorFlow backend. Here is my code:
vgg16 = VGG16(weights='imagenet', include_top=True)

def getVGG16Prediction(img_path):
    global vgg16
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    pred = vgg16.predict(x)
    return x, sort(decode_predictions(pred, top=3)[0])

@app.route("/uploadMultipleImages", methods=["POST"])
def uploadMultipleImages():
    uploaded_files = request.files.getlist("file[]")
    for file in uploaded_files:
        path = os.path.join(STATIC_PATH, file.filename)
        pInput, result = getVGG16Prediction(path)
Here is the full error:
Any comment or suggestion is greatly appreciated. Thank you.
Answer 1:
Take a look at avital's answer on this GitHub issue. Quoting the relevant part here:
Right after loading or constructing your model, save the TensorFlow graph:
graph = tf.get_default_graph()
In the other thread (or perhaps in an asynchronous event handler), do:
global graph
with graph.as_default():
    (... do inference here ...)
I modified this a bit and stored the graph in my app's config object instead of making it a global.
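One way to avoid a global altogether is to bind the model and its graph together once, right after the model is built. The sketch below is only an illustration of that idea: make_predictor is a hypothetical helper name, not part of Keras or TensorFlow, and it assumes nothing more than a model with a predict method and a graph with an as_default() context manager.

```python
def make_predictor(model, graph):
    """Bind a model to the graph it was built in (hypothetical helper)."""
    def predict(x):
        # Re-enter the saved graph in whichever thread handles the request,
        # then run inference inside it.
        with graph.as_default():
            return model.predict(x)
    return predict

# Assumed usage with the TensorFlow 1.x API quoted in the answer:
#   vgg16 = VGG16(weights='imagenet', include_top=True)
#   predict = make_predictor(vgg16, tf.get_default_graph())
#   ...
#   pred = predict(x)   # replaces vgg16.predict(x) in getVGG16Prediction
```

The returned function could then be stored wherever is convenient, for example on the app's config object as described above.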
The TensorFlow documentation for get_default_graph explains why this is necessary:
NOTE: The default graph is a property of the current thread. If you create a new thread, and wish to use the default graph in that thread, you must explicitly add a with g.as_default(): in that thread's function.
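To see what "a property of the current thread" means in practice, here is a standard-library-only analogy. It is not TensorFlow's implementation: get_default, set_default and the string placeholders are made up for illustration, with threading.local standing in for the per-thread default.

```python
import threading

_state = threading.local()  # each thread sees its own attributes

def get_default():
    # A thread that never set a default gets a fresh placeholder.
    return getattr(_state, "graph", "fresh-empty-graph")

def set_default(g):
    _state.graph = g

set_default("graph-with-model")  # done in the main thread

seen = {}

def worker_without_fix():
    # The main thread's default is not visible here.
    seen["without"] = get_default()

def worker_with_fix(g):
    # Analogous to entering `with g.as_default():` in the worker thread.
    set_default(g)
    seen["with"] = get_default()

t1 = threading.Thread(target=worker_without_fix)
t2 = threading.Thread(target=worker_with_fix, args=(get_default(),))
for t in (t1, t2):
    t.start()
    t.join()

print(seen["without"])  # fresh-empty-graph
print(seen["with"])     # graph-with-model
```

The first worker plays the role of a Flask request thread that never re-enters the saved graph; the second receives the graph explicitly and makes it the default before using it.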
Source: https://stackoverflow.com/questions/42013138/valueerror-tensor-tensor-is-not-an-element-of-this-graph-when-using-globa