I think it would be immensely helpful to the TensorFlow community if there were a well-documented solution to the crucial task of testing a single new image against the model.
Here's how I ran a single image at a time. I'll admit it seems a bit hacky, because it re-fetches the current variable scope and switches it to reuse mode on every image after the first.
This is a helper function:
def restore_vars(saver, sess, chkpt_dir):
    """Initialize all variables, then restore them from a saved checkpoint.

    If ``chkpt_dir`` does not exist yet, it is created so that checkpoints
    can be stored there later.

    Args:
        saver: Object whose ``restore(sess, path)`` method loads the variables.
        sess: Session used to run the initializer and to restore into.
        chkpt_dir: Directory that holds (or will hold) the checkpoints.

    Returns:
        True if a checkpoint state was found and restored, False otherwise.
    """
    # Initialize first; a successful restore below overwrites these values.
    sess.run(tf.initialize_all_variables())

    if not os.path.exists(chkpt_dir):
        try:
            os.makedirs(chkpt_dir)
        except OSError:
            # Best effort: if the directory cannot be created, the lookup
            # below simply finds no checkpoint.
            pass

    ckpt_state = tf.train.get_checkpoint_state(chkpt_dir)
    print(chkpt_dir,"path = ",ckpt_state)
    if ckpt_state is None:
        return False

    saver.restore(sess, ckpt_state.model_checkpoint_path)
    return True
Here is the main part of the code, which runs a single image at a time inside the for loop:
# Run the model on one image at a time.
#
# The model variables are created by the first call to inference(); every
# later call reuses them through the variable scope. The checkpoint is
# restored exactly once, after those variables exist.
#
# NOTE: each iteration still adds new preprocessing/inference ops to the
# graph, so the graph grows with the number of images. Feeding a placeholder
# avoids that.
to_restore = True
with tf.Session() as sess:
    for i in test_img_idx_set:
        # Load the image and scale it by 1/255.
        images = get_image(i)
        images = np.asarray(images, dtype=np.float32)
        images = tf.convert_to_tensor(images / 255.0)
        # Resize the image to whatever your model takes in.
        images = tf.image.resize_images(images, 256, 256)
        images = tf.reshape(images, (1, 256, 256, 3))
        images = tf.cast(images, tf.float32)

        with tf.variable_scope(tf.get_variable_scope()) as scope:
            if not to_restore:
                # Variables already exist from the first image: share them
                # instead of creating a second copy.
                scope.reuse_variables()
            logits = inference(images)

        if to_restore:
            # Build the saver only after inference() has created the model
            # variables; a Saver only covers variables that exist when it
            # is constructed. Building it once also avoids adding
            # save/restore ops to the graph on every iteration.
            saver = tf.train.Saver(max_to_keep=5,
                                   keep_checkpoint_every_n_hours=1)
            restored = restore_vars(saver, sess, FLAGS.train_dir)
            print("restored ", restored)
            to_restore = False

        logit_val = sess.run(logits)
        print(logit_val)
Here is an alternative implementation of the above using placeholders. It's a bit cleaner in my opinion, but I'll leave the above example for historical reasons.
# Build the graph once, with a placeholder for the input image, then feed one
# image per sess.run() call.
# `my_img_shape_put_here` must be replaced with the shape of the fed image.
imgs_place = tf.placeholder(tf.float32, shape=[my_img_shape_put_here])
images = tf.reshape(imgs_place, (1, 256, 256, 3))
logits = inference(images)
# Create the saver after inference() so it covers the model variables.
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

with tf.Session() as sess:
    # Restoring needs a live session, so it has to happen inside the block.
    restored = restore_vars(saver, sess, FLAGS.train_dir)
    print("restored ", restored)
    for i in test_img_idx_set:
        # Feed the image itself, not its index.
        # NOTE(review): assumes get_image(i) returns pixel values that should
        # be scaled by 1/255 and already match the placeholder shape - confirm.
        image = np.asarray(get_image(i), dtype=np.float32) / 255.0
        logit_val = sess.run(logits, feed_dict={imgs_place: image})
        print(logit_val)