Restore subset of variables in Tensorflow

Backend · Open · 4 answers · 1636 views
名媛妹妹 · 2020-12-08 01:17

I am training a Generative Adversarial Network (GAN) in TensorFlow, where we essentially have two different networks, each with its own optimizer.

self.G, s         
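(The snippet above is truncated in the source.) In a GAN setup like this, the usual way to restore or train only one network's weights is to build the `Saver`/optimizer from the variables under that network's scope. In TF1 you would pass the filtered list to `tf.train.Saver(var_list=...)`; the filtering itself is plain string matching on variable names. A minimal sketch, with hypothetical scope names `generator` and `discriminator`:

```python
def filter_by_scope(var_names, scope):
    """Keep only variable names that live under the given scope prefix."""
    return [name for name in var_names if name.startswith(scope + "/")]

# Hypothetical variable names as they would appear in a TF1 graph.
all_vars = [
    "generator/dense/kernel",
    "generator/dense/bias",
    "discriminator/conv/kernel",
    "beta1_power",  # optimizer slot variable, belongs to neither scope
]

g_vars = filter_by_scope(all_vars, "generator")
d_vars = filter_by_scope(all_vars, "discriminator")
```

With real TF1 variables you would filter `tf.global_variables()` by `v.name` the same way and hand the result to `tf.train.Saver`.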


        
4 answers
  •  暗喜 (OP)
     2020-12-08 02:09

    I had a similar problem when restoring only part of my variables from a checkpoint: some of the saved variables did not exist in the new model. Inspired by @Lidong's answer, I modified the reading function a little:

    from tensorflow.python import pywrap_tensorflow

    def get_tensors_in_checkpoint_file(file_name, all_tensors=True, tensor_name=None):
        varlist = []
        var_value = []
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in sorted(var_to_shape_map):
                varlist.append(key)
                var_value.append(reader.get_tensor(key))
        else:
            varlist.append(tensor_name)
            var_value.append(reader.get_tensor(tensor_name))
        return (varlist, var_value)
    

    and added a loading function:

    def build_tensors_in_checkpoint_file(loaded_tensors):
        full_var_list = []
        # Loop over all tensor names found in the checkpoint
        for tensor_name in loaded_tensors[0]:
            try:
                # Look up the matching tensor in the current graph
                tensor_aux = tf.get_default_graph().get_tensor_by_name(tensor_name + ":0")
            except KeyError:
                print('Not found: ' + tensor_name)
                continue  # skip variables that do not exist in the new model
            full_var_list.append(tensor_aux)
        return full_var_list
    

    Then you can simply load all common variables using:

    CHECKPOINT_NAME = 'path/to/checkpoint'  # path to the saved checkpoint file
    restored_vars = get_tensors_in_checkpoint_file(file_name=CHECKPOINT_NAME)
    tensors_to_load = build_tensors_in_checkpoint_file(restored_vars)
    loader = tf.train.Saver(tensors_to_load)
    loader.restore(sess, CHECKPOINT_NAME)
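    The essence of these snippets is an intersection between the checkpoint's variable names and the names present in the current graph. A TensorFlow-free sketch of that matching step, using hypothetical stand-in dicts for the checkpoint reader and the graph, looks like:

    ```python
    def match_checkpoint_to_graph(checkpoint_vars, graph_vars):
        """Split checkpoint entries into those the graph has and those it lacks.

        checkpoint_vars: dict mapping saved variable names to values
        graph_vars: set of variable names present in the current model
        """
        matched, missing = {}, []
        for name, value in checkpoint_vars.items():
            if name in graph_vars:
                matched[name] = value
            else:
                missing.append(name)  # the 'Not found: ...' case above
        return matched, missing

    # Hypothetical contents: one saved variable no longer exists in the new model.
    ckpt = {"G/w": 1.0, "G/b": 0.5, "old_layer/w": 2.0}
    graph = {"G/w", "G/b", "D/w"}
    matched, missing = match_checkpoint_to_graph(ckpt, graph)
    ```

    Only the `matched` names would be handed to `tf.train.Saver`, which is why the restore succeeds even when the checkpoint and the model do not agree exactly.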
    

    Edit: I am using TensorFlow 1.2.
