I am training a Generative Adversarial Network (GAN) in TensorFlow, where we basically have two different networks, each with its own optimizer.
Inspired by @mrry, I propose a solution to this problem. To make it clear, I formulate the problem as restoring a subset of the variables from a checkpoint when the new model is built on top of a pre-trained model. First, we use the print_tensors_in_checkpoint_file function from the inspect_checkpoint module, or simply extract a modified version of it that returns the names of the variables stored in the checkpoint:
from tensorflow.python import pywrap_tensorflow

def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    # Collect the (sorted) names of all variables saved in the checkpoint.
    # tensor_name is kept to mirror the original signature but is unused here.
    varlist = []
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in sorted(var_to_shape_map):
            varlist.append(key)
    return varlist

# Replace the placeholder string with the path of your ckpt file.
varlist = print_tensors_in_checkpoint_file(file_name="<path of the ckpt file>", all_tensors=True, tensor_name=None)
Then we use tf.get_collection(), just as @mrry said:
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
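Finally, we create the saver from the first len(varlist) variables of the model:
# Note: the list is named `variables` (plural), as defined above.
saver = tf.train.Saver(variables[:len(varlist)])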
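Because tf.train.Saver restores variables by name, the slice only has to contain exactly the variables stored in the checkpoint; their order inside the slice does not matter. The slice can go wrong, though, if other variables (for example optimizer slot variables or a global step) happen to be created before the new layers. A minimal sanity check, sketched in plain Python on the name strings, compares the sliced names with the checkpoint keys. The helper names are hypothetical, and the sketch assumes graph variable names carry the usual ":0" output suffix that checkpoint keys lack.

```python
def tensor_to_op_name(name):
    """Strip the ':<output index>' suffix TensorFlow appends to tensor names."""
    return name.split(":", 1)[0]


def check_prefix_slice(model_var_names, ckpt_names):
    """Compare the first len(ckpt_names) model variables with the checkpoint keys.

    Returns (missing, unexpected):
      missing    - checkpoint keys not covered by the slice
      unexpected - sliced variables that the checkpoint does not contain
    """
    sliced = {tensor_to_op_name(n) for n in model_var_names[:len(ckpt_names)]}
    ckpt = set(ckpt_names)
    return sorted(ckpt - sliced), sorted(sliced - ckpt)


# Hypothetical names: two pre-trained variables followed by one new layer.
model = ["conv1/weights:0", "conv1/biases:0", "fc_new/weights:0"]
ckpt = ["conv1/biases", "conv1/weights"]  # sorted, like varlist
print(check_prefix_slice(model, ckpt))
```

In TensorFlow the inputs would be `[v.name for v in variables]` and `varlist`; two empty lists mean the slice covers exactly the checkpoint's variables.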
The complete version can be found on my GitHub: https://github.com/pobingwanghai/tensorflow_trick/blob/master/restore_from_checkpoint.py
In my case, the new variables are added at the end of the model, so I can simply slice with [:len(varlist)] to pick out the needed variables. In more complex situations, you may have to align the variables by hand or write a simple string-matching function to determine which variables to restore.
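One possible form of that string-matching function, sketched under the assumption that you can list each model variable's name and shape: keep a variable only if the checkpoint has a key with the same name (after dropping the ":0" suffix) and the same shape. The function name select_restorable is hypothetical, and the code is plain Python so it can be tried without building a graph.

```python
def select_restorable(model_vars, ckpt_shapes):
    """Pick the model variables that can be restored from a checkpoint.

    model_vars:  dict mapping graph variable names (e.g. "conv1/weights:0")
                 to shapes given as lists of ints, in creation order.
    ckpt_shapes: dict mapping checkpoint keys to shapes, e.g. the result of
                 reader.get_variable_to_shape_map().
    Returns the names (suffix stripped) that exist in the checkpoint with an
    identical shape, in the same order as model_vars.
    """
    selected = []
    for name, shape in model_vars.items():
        op_name = name.split(":", 1)[0]
        if op_name in ckpt_shapes and list(ckpt_shapes[op_name]) == list(shape):
            selected.append(op_name)
    return selected


# Hypothetical example: "fc/weights" keeps its name but changes shape.
ckpt = {"conv1/weights": [3, 3, 1, 32], "conv1/biases": [32], "fc/weights": [1024, 10]}
model = {
    "conv1/weights:0": [3, 3, 1, 32],
    "conv1/biases:0": [32],
    "fc/weights:0": [1024, 5],
    "fc_new/weights:0": [1024, 5],
}
print(select_restorable(model, ckpt))
```

With TensorFlow 1.x you could feed it `{v.name: v.get_shape().as_list() for v in variables}` and `reader.get_variable_to_shape_map()`, then build the saver with `tf.train.Saver([v for v in variables if v.op.name in set(keep)])`, where `keep` is the returned list. Checking shapes as well as names avoids a restore failure when a layer keeps its name but changes size.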