I am trying to save the model and then reuse it for classifying my images, but unfortunately I am getting errors when restoring the model that I have saved.
The error here is quite subtle. In `In[8]` you create a `tf.Graph` called `graph` and set it as the default for the `with graph.as_default():` block. This means that all of the variables are created in `graph`, and if you print `graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)` you should see a list of your variables.
However, you exit the `with` block before creating (i) the `tf.Session` and (ii) the `tf.train.Saver`. This means that the session and saver are created in a different graph (the global default `tf.Graph` that is used when you don't explicitly create one and set it as default), which doesn't contain any variables, or any nodes at all.
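A minimal sketch that reproduces the problem, assuming TensorFlow 2.x run through `tf.compat.v1` with v2 behavior disabled (under TF 1.x, `import tensorflow as tf` and drop the `disable_v2_behavior()` call). The single variable named `weights` is a hypothetical stand-in for the real model:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions under TF 2.x.

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-in for the model: one variable created in `graph`.
    weights = tf.Variable(tf.zeros([3]), name="weights")

# After the `with` block ends, the global default graph is active again.
print(weights.graph is graph)           # True
print(tf.get_default_graph() is graph)  # False
print(len(graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))  # 1

saver_error = None
try:
    # This saver is built in the global default graph, which has no variables.
    saver = tf.train.Saver()
except ValueError as err:
    saver_error = str(err)

print(saver_error)  # No variables to save
```

The variable lives in `graph`, but the saver looks in whatever graph is the default when it is constructed, finds nothing, and refuses to build.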
There are at least two solutions:
As Yaroslav suggests, you can write your program without using the `with graph.as_default():` block, which avoids the confusion with multiple graphs. However, this can lead to name collisions between different cells in your IPython notebook, which is awkward when using the `tf.train.Saver`, since it uses the `name` property of a `tf.Variable` as the key in the checkpoint file.
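To make the name-collision issue concrete, here is a small sketch (again assuming TF 2.x via `tf.compat.v1`; the variable name and temporary checkpoint path are hypothetical). It simulates running the same notebook cell twice without a dedicated graph, so both variables land in the one global graph:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions under TF 2.x.

# Running the same cell twice adds two variables with the same requested
# name to the global default graph; TensorFlow uniquifies the second one.
first = tf.Variable(tf.zeros([2]), name="layer1_weights")
second = tf.Variable(tf.zeros([2]), name="layer1_weights")

print(first.op.name)   # layer1_weights
print(second.op.name)  # layer1_weights_1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt_path = tf.train.Saver().save(
        sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

# The checkpoint is keyed by these names: values trained through `second`
# are stored under "layer1_weights_1", not "layer1_weights".
checkpoint_keys = [name for name, _ in tf.train.list_variables(ckpt_path)]
print(checkpoint_keys)
```

A fresh script that builds the model only once names its variable `layer1_weights`, so on restore it would pick up the values of the first, stale copy rather than the one you actually trained.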
You can create the saver inside the `with graph.as_default():` block, and create the `tf.Session` with an explicit graph, as follows:

```python
with graph.as_default():
    # [Variable and model creation goes here.]

    saver = tf.train.Saver()  # Gets all variables in `graph`.

with tf.Session(graph=graph) as sess:
    saver.restore(sess, "/tmp/model.ckpt")  # Replace with the path you saved to.
    # Do some work with the model....
```
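A runnable save-then-restore round trip following this pattern might look like the sketch below. It assumes TF 2.x via `tf.compat.v1`; `build_model`, the one-variable model and the temporary checkpoint path are hypothetical placeholders for your own network and file location:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions under TF 2.x.

CKPT_PATH = os.path.join(tempfile.mkdtemp(), "model.ckpt")  # hypothetical path


def build_model():
    """Builds a hypothetical one-variable model in its own graph."""
    graph = tf.Graph()
    with graph.as_default():
        weights = tf.Variable(tf.zeros([2]), name="layer1_weights")
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()  # Created inside the block, so it sees `weights`.
    return graph, weights, init_op, saver


# "Training" cell: initialize, change the value, and save.
graph, weights, init_op, saver = build_model()
with graph.as_default():
    set_op = weights.assign([1.0, 2.0])
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    sess.run(set_op)
    saver.save(sess, CKPT_PATH)

# "Classification" cell: rebuild the graph and restore instead of initializing.
graph, weights, _, saver = build_model()
with tf.Session(graph=graph) as sess:
    saver.restore(sess, CKPT_PATH)
    restored = sess.run(weights)

print(restored)  # [1. 2.]
```

Note that the restoring session never runs the initializer: `saver.restore` supplies the variable values itself.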
Alternatively, you can create the `tf.Session` inside the `with graph.as_default():` block, in which case it will use `graph` for all of its operations.
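A short sketch of that alternative, assuming TF 2.x via `tf.compat.v1` and a hypothetical one-variable model saved to a temporary directory. Because the sessions are opened while `graph` is the default, no `graph=` argument is needed:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions under TF 2.x.

graph = tf.Graph()
with graph.as_default():
    weights = tf.Variable(tf.zeros([2]), name="layer1_weights")  # hypothetical
    saver = tf.train.Saver()

    # No `graph=` argument: inside the block, tf.Session() picks up `graph`.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        session_uses_graph = sess.graph is graph
        ckpt_path = saver.save(sess, os.path.join(tempfile.mkdtemp(), "model.ckpt"))

    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)
        restored = sess.run(weights)

print(session_uses_graph)  # True
```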
You are creating a new session in `In[17]`, which wipes your variables. Also, you don't need to use `with` blocks if you only have one default graph and one default session; you can instead do something like this:
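```python
import tensorflow as tf

sess = tf.InteractiveSession()
# patch_size, num_channels and depth must match the values used when saving.
layer1_weights = tf.Variable(tf.truncated_normal(
    [patch_size, patch_size, num_channels, depth], stddev=0.1),
    name="layer1_weights")
tf.train.Saver().restore(sess, "/tmp/model.ckpt")
```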
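To see why a new session "wipes" the variables, the sketch below (assuming TF 2.x via `tf.compat.v1`; the variable and temporary checkpoint path are hypothetical) saves a value, opens a fresh `InteractiveSession`, and shows that the variable has no value there until it is restored:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions under TF 2.x.

# Hypothetical one-variable model in the global default graph.
weights = tf.Variable(tf.zeros([2]), name="layer1_weights")
set_op = weights.assign([1.0, 2.0])
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
saver = tf.train.Saver()

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(set_op)
saver.save(sess, ckpt_path)
sess.close()

# A fresh session starts with no variable values at all.
sess = tf.InteractiveSession()
try:
    weights.eval()
    was_uninitialized = False
except tf.errors.FailedPreconditionError:
    was_uninitialized = True

saver.restore(sess, ckpt_path)  # Restoring takes the place of initialization.
restored = weights.eval()
sess.close()

print(was_uninitialized, restored)  # True [1. 2.]
```

Variable values belong to a session, not to the graph, so every new session has to either run the initializer or restore from a checkpoint; running the initializer after restoring would overwrite the restored values.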