Restore subset of variables in Tensorflow

Backend · Open · 4 answers · 1633 views

名媛妹妹 · 2020-12-08 01:17

I am training a Generative Adversarial Network (GAN) in tensorflow, where basically we have two different networks each one with its own optimizer.

self.G, s         


        
4 Answers
  •  清歌不尽
    2020-12-08 01:47

    Inspired by @mrry, I propose a solution to this problem. To state it clearly, I formulate the problem as restoring a subset of variables from a checkpoint when the model is built on top of a pre-trained model. First, we can use the print_tensors_in_checkpoint_file function from the inspect_checkpoint tool, or simply adapt it as follows:

    from tensorflow.python import pywrap_tensorflow

    def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
        # Collect the names of all variables stored in the checkpoint.
        varlist = []
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in sorted(var_to_shape_map):
                varlist.append(key)
        return varlist

    varlist = print_tensors_in_checkpoint_file(file_name='/path/to/model.ckpt',
                                               tensor_name=None, all_tensors=True)


    Then we collect the graph's variables with tf.get_collection(), just as @mrry said:

    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    

    Finally, we can initialize the saver with only those variables:

    saver = tf.train.Saver(variables[:len(varlist)])


    The complete version can be found on my GitHub: https://github.com/pobingwanghai/tensorflow_trick/blob/master/restore_from_checkpoint.py

    In my situation, the new variables are added at the end of the model, so I can simply slice with [:len(varlist)] to select the needed variables. For a more complex situation, you might have to do some hand alignment, or write a simple string-matching function to determine the required variables.
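    For the string-matching case, a minimal sketch (pure Python; the helper name and example variable names are hypothetical, not from the answer's repo). Graph variable names carry a `:0` output suffix that checkpoint keys lack, so strip it before comparing:

    ```python
    def match_checkpoint_variables(graph_var_names, ckpt_var_names):
        """Return the graph variable names whose checkpoint key exists."""
        ckpt = set(ckpt_var_names)
        # A graph variable named 'G/dense/kernel:0' corresponds to the
        # checkpoint key 'G/dense/kernel', so drop everything after ':'.
        return [v for v in graph_var_names if v.split(':')[0] in ckpt]

    # Example: only the generator's variables exist in the checkpoint.
    print(match_checkpoint_variables(
        ['G/w:0', 'G/b:0', 'D/w:0'],   # variables in the current graph
        ['G/w', 'G/b']))               # keys stored in the checkpoint
    # → ['G/w:0', 'G/b:0']
    ```

    The matched names can then be mapped back to the actual variable objects to build the `tf.train.Saver`, which is more robust than relying on variable ordering.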
