I am training a Generative Adversarial Network (GAN) in TensorFlow, where there are basically two different networks, each with its own optimizer.
To restore a subset of variables, you must create a new tf.train.Saver and pass it a specific list of variables to restore in the optional var_list argument.
By default, a tf.train.Saver will create ops that (i) save every variable in your graph when you call saver.save() and (ii) look up (by name) every variable in the given checkpoint when you call saver.restore(). While this works for most common scenarios, you have to provide more information to work with specific subsets of the variables:
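For instance, the default save-everything/restore-everything behavior looks like this (a minimal TF 1.x graph-mode sketch; the variable names and checkpoint path are hypothetical):

```python
import tensorflow as tf

# Hypothetical variables standing in for a real model.
w = tf.get_variable("w", shape=[10, 10])
b = tf.get_variable("b", shape=[10])

saver = tf.train.Saver()  # by default, covers every variable in the graph

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    path = saver.save(sess, "/tmp/model.ckpt")  # writes w and b
    saver.restore(sess, path)                   # looks up w and b by name
```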
If you only want to restore a subset of the variables, you can get a list of these variables by calling tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX), assuming that you put the "g" network in a common with tf.name_scope(G_NETWORK_PREFIX): or with tf.variable_scope(G_NETWORK_PREFIX): block. You can then pass this list to the tf.train.Saver constructor.
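Here is a minimal sketch of that approach for the GAN case, again assuming TF 1.x graph mode; the scope names, variable shapes, and checkpoint path are hypothetical:

```python
import tensorflow as tf

G_NETWORK_PREFIX = "generator"  # hypothetical scope name

# Build each network inside its own variable scope.
with tf.variable_scope(G_NETWORK_PREFIX):
    g_w = tf.get_variable("w", shape=[100, 784])
with tf.variable_scope("discriminator"):
    d_w = tf.get_variable("w", shape=[784, 1])

# Collect only the generator's variables by scope prefix.
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                           scope=G_NETWORK_PREFIX)

# This saver saves/restores only the generator's variables.
g_saver = tf.train.Saver(var_list=g_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    g_saver.save(sess, "/tmp/generator.ckpt")     # saves only generator/w
    g_saver.restore(sess, "/tmp/generator.ckpt")  # restores only generator/w
```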
If you want to restore a subset of the variables and/or the variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph).
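A minimal sketch of the dictionary form, with hypothetical names and paths; suppose the checkpoint stored the variable under the plain key "w", but the current graph wraps it in a scope:

```python
import tensorflow as tf

# In the current graph, the variable lives under a new scope prefix...
with tf.variable_scope("generator"):
    w = tf.get_variable("w", shape=[10])  # lives under "generator/w" here

# ...but the checkpoint stored it under the plain key "w" (hypothetical).
# Keys are names in the checkpoint; values are tf.Variable objects in
# the target graph.
saver = tf.train.Saver(var_list={"w": w})

with tf.Session() as sess:
    saver.restore(sess, "/tmp/old_model.ckpt")  # hypothetical checkpoint
```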