How to predict with .meta and checkpoint files in TensorFlow?

时光怂恿深爱的人放手 submitted on 2019-11-30 10:36:15

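Load the graph structure from the .meta file with tf.train.import_meta_graph(), restore the trained variable values with the restore() method of the Saver it returns, and then fetch the input and output tensors by name with get_tensor_by_name(). You can try: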

import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TF 2.x)

model_path = "model.ckpt"  # checkpoint prefix; the graph itself is in model.ckpt.meta
detection_graph = tf.Graph()
with tf.Session(graph=detection_graph) as sess:
    # Rebuild the graph from the .meta file and restore the trained variable values
    loader = tf.train.import_meta_graph(model_path + '.meta')
    loader.restore(sess, model_path)

    # Get the tensors by their names ("<op_name>:<output_index>")
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    ...
    # Make predictions; image_np_expanded is the input image with a batch dimension added
    _boxes, _scores = sess.run([boxes, scores], feed_dict={image_tensor: image_np_expanded})
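Before calling get_tensor_by_name(), it can help to check that the ops you expect actually exist in the restored graph. Tensor names follow the pattern "<op_name>:<output_index>", so 'image_tensor:0' is output 0 of the op 'image_tensor'. The sketch below is plain Python and assumes you have already collected the op names with [op.name for op in detection_graph.get_operations()]; split_tensor_name and missing_ops are hypothetical helpers, and the op list is made-up sample data.

```python
def split_tensor_name(tensor_name):
    """Split a TensorFlow-style tensor name '<op_name>:<index>' into (op_name, index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError(f"expected '<op_name>:<index>', got {tensor_name!r}")
    return op_name, int(index)


def missing_ops(tensor_names, available_op_names):
    """Return the tensor names whose producing op is not in the graph."""
    available = set(available_op_names)
    return [name for name in tensor_names
            if split_tensor_name(name)[0] not in available]


# Hypothetical sample data; in TF 1.x you would collect it with
# [op.name for op in detection_graph.get_operations()]
ops = ["image_tensor", "detection_boxes", "num_detections"]
print(missing_ops(["image_tensor:0", "detection_boxes:0", "detection_scores:0"], ops))
# -> ['detection_scores:0']
```

Any name this reports would trigger the "does not exist in the graph" error if passed to get_tensor_by_name().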

Ren

For anyone who hits the same problem as wu ruize and CoupDeMistral:

But I got this error: "The name 'image_tensor:0' refers to a Tensor which does not exist. The operation, 'image_tensor', does not exist in the graph."

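The op must exist under that exact name in the saved graph. Give the tensor an explicit name with the name argument when you build the graph (before saving the checkpoint); otherwise TensorFlow assigns a default name such as 'Mean', and the lookup with detection_graph.get_tensor_by_name() fails.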

For example:

accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')

The name='accuracy' argument names the op 'accuracy', so its first output tensor is 'accuracy:0'.

After restoring the graph, you can retrieve it with:

detection_graph.get_tensor_by_name('accuracy:0')
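If the tensor was created inside a name scope (for example under tf.name_scope('eval')), TensorFlow prefixes the op name with the scope, so the lookup string becomes 'eval/accuracy:0' rather than 'accuracy:0'. Below is a minimal plain-Python sketch for finding the full name. It assumes the op names were collected with [op.name for op in graph.get_operations()]. find_tensor_names is a hypothetical helper, and the op list is sample data.

```python
def find_tensor_names(short_name, available_op_names, index=0):
    """List '<op_name>:<index>' for every op named short_name at any name-scope depth.

    Ops created inside a name scope get the scope as a prefix, e.g. an op
    named 'accuracy' inside scope 'eval' is stored as 'eval/accuracy'.
    """
    return [f"{op}:{index}" for op in available_op_names
            if op == short_name or op.endswith("/" + short_name)]


# Hypothetical sample data standing in for [op.name for op in graph.get_operations()]
ops = ["eval/accuracy", "Cast", "Mean", "train/Adam"]
print(find_tensor_names("accuracy", ops))  # -> ['eval/accuracy:0']
```

Matching on the last path component rather than a plain substring avoids false hits such as 'accuracy_op' when searching for 'accuracy'.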

Have fun now :P!
