How to find the variable names that are saved in a tensorflow checkpoint?

Anonymous (unverified), submitted 2019-12-03 01:25:01

Question:

I want to see the variables that are saved in a tensorflow checkpoint along with their values. How can I find the variable names that are saved in a tensorflow checkpoint?

EDIT:

I used tf.train.NewCheckpointReader, which is explained here, but it does not appear in the TensorFlow documentation. Is there any other way?

    import os
    import tensorflow as tf

    v0 = tf.Variable([[1, 2, 3], [4, 5, 6]], dtype=tf.float32, name="v0")
    v1 = tf.Variable([[[1], [2]], [[3], [4]], [[5], [6]]], dtype=tf.float32,
                     name="v1")
    init_all_op = tf.initialize_all_variables()
    save = tf.train.Saver({"v0": v0, "v1": v1})
    # model_dir is not defined in the original post; it must name an existing directory.
    checkpoint_path = os.path.join(model_dir, "model.ckpt")

    with tf.Session() as sess:
        sess.run(init_all_op)
        # Saves a checkpoint.
        save.save(sess, checkpoint_path)

        # Creates a reader.
        reader = tf.train.NewCheckpointReader(checkpoint_path)
        print('reader:\n', reader)

        # Verifies that the tensors exist.
        print('is exist v0?', reader.has_tensor("v0"))
        print('is exist v1?', reader.has_tensor("v1"))

        # Verifies that debug string contains the right strings.
        debug_string = reader.debug_string()
        print('\n All Variables: \n', debug_string)

        # Verifies get_variable_to_shape_map() returns the correct information.
        var_map = reader.get_variable_to_shape_map()
        print('\n All Variables information :\n', var_map)

        # Verifies get_tensor() returns the tensor value.
        v0_tensor = reader.get_tensor("v0")
        v1_tensor = reader.get_tensor("v1")
        print('\n   returns the v0 tensor value:\n', v0_tensor)
        print('\n   returns the v1 tensor value:\n', v1_tensor)

Answer 1:

You can use the inspect_checkpoint.py tool.



Answer 2:

Example usage:

    import os
    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

    checkpoint_path = os.path.join(model_dir, "model.ckpt")

    # List ALL tensors.
    # Example output: v0/Adam (DT_FLOAT) [3,3,1,80]
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='')

    # List contents of v0 tensor.
    # Example output: tensor_name:  v0 [[[[  9.27958265e-02   7.40226209e-02   4.52989563e-02   3.15700471e-02
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v0')

    # List contents of v1 tensor.
    print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='v1')

Update: the all_tensors argument was added to print_tensors_in_checkpoint_file in TensorFlow 0.12.0-rc0, so on that version or later you may need to pass all_tensors=False or all_tensors=True as well.
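If a script has to run on TensorFlow versions both with and without that keyword, one option is to check the function's signature before passing it. The sketch below is only an illustration: call_with_supported_kwargs is a hypothetical helper (not part of TensorFlow), and old_style / new_style are stand-ins that mimic a signature without and with all_tensors, so the example runs without TensorFlow installed.

```python
import inspect


def call_with_supported_kwargs(func, *args, **kwargs):
    """Call func, dropping keyword arguments it does not accept.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    params = inspect.signature(func).parameters
    # A **kwargs parameter means func accepts any keyword, so pass everything.
    accepts_any = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_any:
        return func(*args, **kwargs)
    supported = {k: v for k, v in kwargs.items() if k in params}
    return func(*args, **supported)


# Stand-ins (assumptions) mimicking a signature without and with all_tensors.
def old_style(file_name, tensor_name):
    return (file_name, tensor_name)


def new_style(file_name, tensor_name, all_tensors):
    return (file_name, tensor_name, all_tensors)


old_result = call_with_supported_kwargs(
    old_style, file_name="model.ckpt", tensor_name="", all_tensors=True)
new_result = call_with_supported_kwargs(
    new_style, file_name="model.ckpt", tensor_name="", all_tensors=True)
print(old_result)
print(new_result)
```

With the real function, the call would presumably look like call_with_supported_kwargs(print_tensors_in_checkpoint_file, file_name=checkpoint_path, tensor_name='', all_tensors=True).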

Alternative method:

    import os
    from tensorflow.python import pywrap_tensorflow

    checkpoint_path = os.path.join(model_dir, "model.ckpt")
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)
        print(reader.get_tensor(key))  # Remove this if you want to print only variable names
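Printing every tensor in full can be unwieldy for large models. As a rough sketch, the same two reader methods used here (get_variable_to_shape_map and get_tensor) can be wrapped in a small function that returns a sorted name/shape listing. The names summarize_checkpoint and FakeReader are hypothetical and introduced only for this example; FakeReader is a stand-in so the code runs without a checkpoint on disk.

```python
def summarize_checkpoint(reader, with_values=False):
    """Return a sorted list of (name, shape) or (name, shape, value) tuples.

    `reader` is any object exposing get_variable_to_shape_map() and
    get_tensor(), such as the object returned by NewCheckpointReader.
    """
    shape_map = reader.get_variable_to_shape_map()
    rows = []
    for name in sorted(shape_map):
        shape = list(shape_map[name])
        if with_values:
            rows.append((name, shape, reader.get_tensor(name)))
        else:
            rows.append((name, shape))
    return rows


class FakeReader:
    """Stand-in for a checkpoint reader (assumption, for demonstration only)."""

    def __init__(self, tensors, shapes):
        self._tensors = tensors
        self._shapes = shapes

    def get_variable_to_shape_map(self):
        return dict(self._shapes)

    def get_tensor(self, name):
        return self._tensors[name]


fake = FakeReader(
    tensors={"v1": [[[1.0], [2.0]]], "v0": [[1.0, 2.0, 3.0]]},
    shapes={"v1": [1, 2, 1], "v0": [1, 3]},
)
summary = summarize_checkpoint(fake)
for name, shape in summary:
    print(name, shape)
```

With a real checkpoint you would pass the reader from pywrap_tensorflow.NewCheckpointReader(checkpoint_path) instead of FakeReader. Later TensorFlow releases also expose tf.train.list_variables and tf.train.load_checkpoint as public functions, which may be preferable if your version has them.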

Hope it helps.



Answer 3:

Adding to the above answer:

If the model is saved using the V2 format, the checkpoint consists of several files:

    model-10000.data-00000-of-00001
    model-10000.index
    model-10000.meta

The checkpoint name you pass as input should be only the prefix these files share:

    print_tensors_in_checkpoint_file(file_name='/home/RNN/models/model-10000', tensor_name='', all_tensors=True)
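Since the files of a V2 checkpoint share one prefix, that prefix can be derived from any of their names. The following is a minimal sketch using only the standard library; checkpoint_prefix is a hypothetical helper, and the regular expression assumes the .index, .meta, and .data-XXXXX-of-XXXXX suffixes shown above.

```python
import re

# Suffixes of the V2 checkpoint files listed above (assumed naming pattern).
_V2_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Strip a V2 checkpoint file suffix, leaving the prefix to pass as file_name."""
    return _V2_SUFFIX.sub("", path)


files = [
    "/home/RNN/models/model-10000.data-00000-of-00001",
    "/home/RNN/models/model-10000.index",
    "/home/RNN/models/model-10000.meta",
]
prefixes = {checkpoint_prefix(f) for f in files}
print(prefixes)
```

A path that carries none of these suffixes is returned unchanged, so the helper can also be applied to a name that is already a prefix.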

Source: @LingjiaDeng at https://github.com/tensorflow/tensorflow/issues/7696


