No variable to save error in Tensorflow

孤街浪徒 2020-12-24 08:09

I am trying to save a model and then reuse it to classify my images, but I get errors when restoring the model I saved.

The

2 answers
  • 2020-12-24 09:01

    The error here is quite subtle. In In[8] you create a tf.Graph called graph and set it as default for the with graph.as_default(): block. This means that all of the variables are created in graph, and if you print graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) you should see a list of your variables.

    However, you exit the with block before creating (i) the tf.Session, and (ii) the tf.train.Saver. This means that the session and saver are created in a different graph (the global default tf.Graph that is used when you don't explicitly create one and set it as default), which doesn't contain any variables—or any nodes at all.
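The mismatch can be reproduced in a few lines. This is a minimal sketch, not the asker's notebook: it assumes the TF 1.x-style API is reachable through `tensorflow.compat.v1`, and the variable name `weights` and its shape are made up for illustration.

```python
import tensorflow.compat.v1 as tf  # Assumption: TF 1.x-style API via the compat module.

tf.disable_eager_execution()
tf.reset_default_graph()  # Start from an empty global default graph.

graph = tf.Graph()
with graph.as_default():
    # The variable is registered in `graph`, not in the global default graph.
    weights = tf.Variable(tf.zeros([2, 2]), name="weights")  # Hypothetical variable.

# Outside the block, the global default graph is current again, and it is empty.
vars_in_graph = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
vars_in_default = tf.get_default_graph().get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES)

try:
    tf.train.Saver()  # Collects variables from the default graph and finds none.
    saver_error = None
except ValueError as err:
    saver_error = str(err)

print([v.name for v in vars_in_graph])
print(vars_in_default)
print(saver_error)
```

The saver raises because it is built against the empty default graph, which matches the error in the question title.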

    There are at least two solutions:

    1. As Yaroslav suggests, you can write your program without using the with graph.as_default(): block, which avoids the confusion with multiple graphs. However, this can lead to name collisions between different cells in your IPython notebook, which is awkward when using the tf.train.Saver, since it uses the name property of a tf.Variable as the key in the checkpoint file.

    2. You can create the saver inside the with graph.as_default(): block, and create the tf.Session with an explicit graph, as follows:

      with graph.as_default():
          # [Variable and model creation goes here.]
      
          saver = tf.train.Saver()  # Gets all variables in `graph`.
      
      with tf.Session(graph=graph) as sess:
          saver.restore(sess, "/tmp/model.ckpt")  # Pass the path the checkpoint was saved to.
          # Do some work with the model....
      

      Alternatively, you can create the tf.Session inside the with graph.as_default(): block, in which case it will use graph for all of its operations.
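A complete save-and-restore round trip along the lines of the second solution might look like the sketch below. It assumes `tensorflow.compat.v1`; the variable `weights`, the `set_weights` op that stands in for training, and the temporary checkpoint directory are all hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # Assumption: TF 1.x-style API via the compat module.

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    weights = tf.Variable(tf.zeros([2, 2]), name="weights")  # Hypothetical variable.
    set_weights = weights.assign(tf.fill([2, 2], 3.0))  # Stands in for training.
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()  # Built inside the block, so it sees `weights`.

# Hypothetical checkpoint location in a fresh temporary directory.
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# First session: initialize, "train", and save.
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    sess.run(set_weights)
    saver.save(sess, checkpoint_path)

# Second session on the same graph: restore instead of initializing.
with tf.Session(graph=graph) as sess:
    saver.restore(sess, checkpoint_path)
    restored = sess.run(weights)

print(restored)
```

Both sessions are given `graph` explicitly, so the saver and the sessions all refer to the same set of variables.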

  • 2020-12-24 09:04

    You are creating a new session in In[17], which wipes your variables. Also, you don't need to use with blocks if you only have one default graph and one default session. You can instead do something like this:

    sess = tf.InteractiveSession()
    layer1_weights = tf.Variable(tf.truncated_normal(
      [patch_size, patch_size, num_channels, depth], stddev=0.1),name="layer1_weights")
    tf.train.Saver().restore(sess, "/tmp/model.ckpt")
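For a self-contained version of this default-graph approach, the sketch below saves a checkpoint first and then restores it after rebuilding the variable under the same name. It assumes `tensorflow.compat.v1`; the shape values and the temporary checkpoint path are placeholders, not taken from the question.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # Assumption: TF 1.x-style API via the compat module.

tf.disable_eager_execution()
tf.reset_default_graph()

# Hypothetical hyperparameters and checkpoint location.
patch_size, num_channels, depth = 5, 1, 16
checkpoint_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build the variable in the default graph and save it.
layer1_weights = tf.Variable(tf.truncated_normal(
    [patch_size, patch_size, num_channels, depth], stddev=0.1),
    name="layer1_weights")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_values = sess.run(layer1_weights)
    tf.train.Saver().save(sess, checkpoint_path)

# Later (for example, in a fresh notebook kernel): rebuild the variable with
# the same name, because the name is the key in the checkpoint file.
tf.reset_default_graph()
layer1_weights = tf.Variable(tf.truncated_normal(
    [patch_size, patch_size, num_channels, depth], stddev=0.1),
    name="layer1_weights")
sess = tf.InteractiveSession()
tf.train.Saver().restore(sess, checkpoint_path)
restored_values = sess.run(layer1_weights)
sess.close()
```

Restoring overwrites the freshly created variable, so no initializer needs to run in the second session.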
    