How do I find the variable names and values that are saved in a checkpoint?

执笔经年 2020-12-12 18:26

I want to see the variables that are saved in a TensorFlow checkpoint along with their values. How can I find the variable names that are saved in a TensorFlow checkpoint?

5 answers
  • 2020-12-12 18:43

    You can use the inspect_checkpoint.py tool.

    So, for example, if you stored the checkpoint in the current directory, you can print the variables and their values as follows:

    import tensorflow as tf
    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    
    
    latest_ckp = tf.train.latest_checkpoint('./')
    print_tensors_in_checkpoint_file(latest_ckp, all_tensors=True, tensor_name='')
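tf.train.latest_checkpoint returns None when the directory has no 'checkpoint' state file, and that None is then passed straight on to print_tensors_in_checkpoint_file. A minimal guarded sketch, assuming TensorFlow 2.x; print_latest_checkpoint is a hypothetical helper name, not part of TensorFlow:

```python
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


def print_latest_checkpoint(ckpt_dir="./"):
    """Print every tensor name and value in the newest checkpoint under ckpt_dir.

    Hypothetical helper for illustration; not a TensorFlow API.
    """
    # Reads the 'checkpoint' state file in ckpt_dir; returns None if it is missing.
    latest_ckp = tf.train.latest_checkpoint(ckpt_dir)
    if latest_ckp is None:
        raise FileNotFoundError(f"no checkpoint state file found in {ckpt_dir!r}")
    print_tensors_in_checkpoint_file(latest_ckp, tensor_name="", all_tensors=True)
    return latest_ckp
```

Calling print_latest_checkpoint('./') then behaves like the snippet above, but an empty directory fails with a clear FileNotFoundError instead of an error raised from inside TensorFlow.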
    
  • 2020-12-12 18:47

    Some more detail on the parameters of print_tensors_in_checkpoint_file:

    file_name: not a physical file, but the common prefix of the checkpoint's file names.

    If no tensor_name is provided, it prints the tensor names and shapes in the checkpoint file. If tensor_name is provided, it prints the content of that tensor. (See inspect_checkpoint.py.)

    If all_tensor_names is True, it prints all the tensor names.

    If all_tensors is True, it prints all the tensor names and their contents.

    N.B. all_tensors and all_tensor_names override tensor_name.
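One way to see how these flags interact is to capture what each call prints. A sketch assuming TensorFlow 2.x; capture_inspect is a hypothetical helper, and the checkpoint is assumed to be written with tf.train.Checkpoint, so variable names carry a /.ATTRIBUTES/VARIABLE_VALUE suffix:

```python
import contextlib
import io

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


def capture_inspect(prefix, tensor_name="", all_tensors=False, all_tensor_names=False):
    """Run print_tensors_in_checkpoint_file and return its printed output (hypothetical helper)."""
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        print_tensors_in_checkpoint_file(
            prefix,  # the checkpoint prefix, not one of its physical files
            tensor_name=tensor_name,
            all_tensors=all_tensors,
            all_tensor_names=all_tensor_names,
        )
    return buf.getvalue()
```

For example, capture_inspect(prefix, tensor_name='v/.ATTRIBUTES/VARIABLE_VALUE', all_tensor_names=True) still lists every name in the checkpoint, because all_tensor_names takes precedence over tensor_name; dropping all_tensor_names=True prints only that one tensor.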

  • 2020-12-12 18:51

    A few more details.

    If your model is saved in the V2 format, for example as the following files in the directory /my/dir/:

    model-10000.data-00000-of-00001
    model-10000.index
    model-10000.meta
    

    then the file_name parameter should be only the prefix, that is:

    print_tensors_in_checkpoint_file(file_name='/my/dir/model-10000', tensor_name='', all_tensors=True)
    

    See https://github.com/tensorflow/tensorflow/issues/7696 for a discussion.
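If you only have the path of one of those files (for example from a directory listing), the prefix can be recovered by stripping the file suffix. A small pure-Python sketch; checkpoint_prefix is a hypothetical helper, and the pattern assumes TensorFlow's usual .data-NNNNN-of-NNNNN shard naming:

```python
import re

# Suffixes of the files that share one V2 checkpoint prefix.
_V2_SUFFIX = re.compile(r"\.(?:index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Return the checkpoint prefix for a .index, .meta or .data-* file path."""
    return _V2_SUFFIX.sub("", path)


print(checkpoint_prefix("/my/dir/model-10000.data-00000-of-00001"))  # /my/dir/model-10000
```

The returned string is what print_tensors_in_checkpoint_file expects as file_name; a path that is already a prefix is returned unchanged.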

  • 2020-12-12 18:52

    An update to the answers above:

    For recent TensorFlow versions (verified on TF 1.13+), a cleaner way to do this is:

    import tensorflow as tf

    ckpt_reader = tf.train.load_checkpoint(ckpt_dir_or_file)
    value = ckpt_reader.get_tensor(name_of_the_tensor)
    

    The name_of_the_tensor should correspond to the variable name (the one whose value you're trying to inspect). To get a list of the variable names and shapes in a checkpoint, use:

    vars_list = tf.train.list_variables(ckpt_dir_or_file)  # list of (name, shape) tuples
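Putting the two calls together maps every variable name to its value. A minimal sketch assuming TensorFlow 2.x; checkpoint_to_dict is a hypothetical helper, and for an object-based tf.train.Checkpoint the result also contains bookkeeping entries such as _CHECKPOINTABLE_OBJECT_GRAPH (a serialized proto stored as bytes):

```python
import tensorflow as tf


def checkpoint_to_dict(ckpt_dir_or_file):
    """Map every tensor name stored in a checkpoint to its value (hypothetical helper)."""
    reader = tf.train.load_checkpoint(ckpt_dir_or_file)
    # tf.train.list_variables returns a list of (name, shape) tuples.
    return {
        name: reader.get_tensor(name)
        for name, _shape in tf.train.list_variables(ckpt_dir_or_file)
    }
```

Loading every value into memory is fine for small models; for large checkpoints, tf.train.list_variables alone is enough to see the names and shapes.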
    
  • 2020-12-12 19:00

    Example usage:

    import os
    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

    model_dir = "/path/to/model_dir"  # placeholder: the directory holding model.ckpt.*
    checkpoint_path = os.path.join(model_dir, "model.ckpt")

    # List the names, dtypes and shapes of ALL tensors. Example output: v0/Adam (DT_FLOAT) [3,3,1,80]
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='', all_tensors=False)

    # List contents of v0 tensor.
    # Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0', all_tensors=False)

    # List contents of v1 tensor.
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1', all_tensors=False)
    

    Update: print_tensors_in_checkpoint_file gained an all_tensors argument in TensorFlow 0.12.0-rc0, so in those versions you must pass all_tensors=False or all_tensors=True explicitly.
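The example usage above needs an existing model.ckpt. To try it end to end, the sketch below first writes a small TF1-style checkpoint with tf.compat.v1.train.Saver, which produces the same model.ckpt prefix and plain variable names such as v0 and v1. It assumes TensorFlow 2.x running in graph mode through tf.Graph().as_default(); save_v1_style_checkpoint and the variable values are made up for illustration:

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file


def save_v1_style_checkpoint(model_dir):
    """Write variables v0 and v1 with a TF1-style Saver and return the checkpoint prefix.

    Hypothetical helper; the variable values are illustrative only.
    """
    with tf.Graph().as_default():
        tf.compat.v1.get_variable("v0", initializer=tf.constant([1.0, 2.0]))
        tf.compat.v1.get_variable("v1", initializer=tf.constant([[3.0]]))
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Without global_step, the returned prefix is exactly model_dir/model.ckpt.
            return saver.save(sess, os.path.join(model_dir, "model.ckpt"))


model_dir = tempfile.mkdtemp()
checkpoint_path = save_v1_style_checkpoint(model_dir)
# Names, dtypes and shapes of all tensors, then the value of v0.
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name="", all_tensors=False)
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name="v0", all_tensors=False)
```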

    Alternative method:

    import os
    import tensorflow as tf

    model_dir = "/path/to/model_dir"  # placeholder: the directory holding model.ckpt.*
    checkpoint_path = os.path.join(model_dir, "model.ckpt")
    # Public alias; older TF 1.x code used tensorflow.python.pywrap_tensorflow.NewCheckpointReader.
    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))  # Remove this if you want to print only variable names
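The same reader also has get_variable_to_dtype_map(), so names can be listed with their dtype and shape in a stable order without loading any values. A sketch assuming TensorFlow 2.x, where tf.train.load_checkpoint returns this kind of reader; describe_checkpoint is a hypothetical helper:

```python
import tensorflow as tf


def describe_checkpoint(ckpt_dir_or_file):
    """Return sorted (name, dtype name, shape) rows without reading tensor values."""
    reader = tf.train.load_checkpoint(ckpt_dir_or_file)
    shapes = reader.get_variable_to_shape_map()  # name -> list of dims
    dtypes = reader.get_variable_to_dtype_map()  # name -> tf.DType
    return [(name, dtypes[name].name, shapes[name]) for name in sorted(shapes)]
```

For example, describe_checkpoint(checkpoint_path) returns one row per tensor, sorted by name; only get_tensor actually reads tensor data from disk.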
    

    Hope it helps.
