Tensorflow: restoring a graph and model then running evaluation on a single image

前端 未结 4 1400
后悔当初
后悔当初 2020-12-07 10:32

I think it would be immensely helpful to the TensorFlow community if there were a well-documented solution to the crucial task of testing a single new image against a trained model.

4条回答
  •  时光取名叫无心
    2020-12-07 10:40

    Here's how I ran inference on a single image at a time. I'll admit that reusing the variable scope this way feels a bit hacky.

    This is a helper function

    def restore_vars(saver, sess, chkpt_dir):
        """Initialize every variable, then restore the latest checkpoint if one exists.

        The checkpoint directory is created (best effort) when it is missing,
        so a later training run has somewhere to save checkpoints.

        Args:
            saver: the ``tf.train.Saver`` used to load the checkpoint.
            sess: the ``tf.Session`` whose variables are initialized/restored.
            chkpt_dir: directory searched for a checkpoint state file.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        # Start from freshly initialized variables; a successful restore below
        # then loads the saved values for the variables the saver covers.
        sess.run(tf.initialize_all_variables())

        if not os.path.exists(chkpt_dir):
            # Best effort: any OSError (e.g. another process creating the
            # directory first) is deliberately ignored.
            try:
                os.makedirs(chkpt_dir)
            except OSError:
                pass

        ckpt_state = tf.train.get_checkpoint_state(chkpt_dir)
        print(chkpt_dir, "path = ", ckpt_state)
        if ckpt_state is not None:
            saver.restore(sess, ckpt_state.model_checkpoint_path)
            return True
        return False
    

    Here is the main part of the code that runs a single image at a time within the for loop.

    # Run inference on one image per loop iteration, restoring the trained
    # weights once (on the first iteration) after the model has been built.
    to_restore = True
    with tf.Session() as sess:

        for i in test_img_idx_set:

            # Load the image for index i and scale pixel values to [0, 1].
            images = get_image(i)
            images = np.asarray(images, dtype=np.float32)
            images = tf.convert_to_tensor(images / 255.0)
            # Resize the image to whatever your model takes in.
            images = tf.image.resize_images(images, 256, 256)
            images = tf.reshape(images, (1, 256, 256, 3))
            images = tf.cast(images, tf.float32)

            # The first call to inference() creates the model variables; every
            # later call must reuse them instead of creating new ones.
            with tf.variable_scope(tf.get_variable_scope()) as scope:
                if not to_restore:
                    scope.reuse_variables()
                logits = inference(images)

            if to_restore:
                # Build the Saver only now, after inference() has created the
                # variables: a Saver captures the variable list when it is
                # constructed, so one built earlier would not restore the model.
                # It is built once rather than once per image.
                saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
                restored = restore_vars(saver, sess, FLAGS.train_dir)
                print("restored ", restored)
                to_restore = False

            # NOTE: every iteration still adds new preprocessing and inference
            # ops to the graph, so the graph grows with the number of images;
            # feeding images through a tf.placeholder avoids this.
            logit_val = sess.run(logits)
            print(logit_val)
    

    Here is an alternative implementation of the above that uses placeholders; it's a bit cleaner in my opinion, because the graph is built only once and each image is simply fed in through the placeholder. I'll leave the above example for historical reasons.

    # Build the graph once; each image is fed through the placeholder, so no
    # new ops are added per image.
    # Replace my_img_shape_put_here with the dimensions of one image as
    # returned by get_image (the reshape below needs 256*256*3 elements).
    imgs_place = tf.placeholder(tf.float32, shape=[my_img_shape_put_here])
    images = tf.reshape(imgs_place, (1, 256, 256, 3))

    logits = inference(images)

    # Create the Saver after inference() has built the model's variables: a
    # Saver captures the variable list when it is constructed.
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

    with tf.Session() as sess:
        # Restoring needs a live session, so it happens inside the `with`.
        restored = restore_vars(saver, sess, FLAGS.train_dir)
        print("restored ", restored)

        for i in test_img_idx_set:
            # i is an image index: load the image and scale it to [0, 1],
            # matching the preprocessing of the loop-based version.
            img = np.asarray(get_image(i), dtype=np.float32) / 255.0
            logit_val = sess.run(logits, feed_dict={imgs_place: img})
            print(logit_val)
    

提交回复
热议问题