How to convert .ckpt to .pb?

渐次进展 2020-12-30 13:24

I am new to deep learning and want to serve a pretrained EAST model from AI Platform Serving. The developer has made these files available:

3 Answers
  • 2020-12-30 13:55

    Here's the code to convert the checkpoint to a SavedModel:

    import os
    import tensorflow as tf
    
    trained_checkpoint_prefix = 'models/model.ckpt-49491'
    export_dir = os.path.join('export_dir', '0')
    
    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
        # Restore from checkpoint
        loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
        loader.restore(sess, trained_checkpoint_prefix)
    
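        # Export checkpoint to SavedModel. The MetaGraph is tagged with both
        # TRAINING and SERVING, and loaders match the tag set exactly, so pass
        # only tf.saved_model.SERVING if the server requests the 'serve' tag alone.
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(sess,
                                             [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                             strip_default_attrs=True)
        builder.save()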
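
Before running the conversion, it can help to confirm that every part of the checkpoint is present, since `import_meta_graph` needs the `.meta` file and `restore` needs the variable files. The following is a minimal sketch using only the standard library. The helper name `missing_checkpoint_files` is hypothetical, and it assumes the checkpoint was saved in the `.index` / `.data-*` layout; an older single-file checkpoint would be reported as incomplete.

```python
import glob
import os


def missing_checkpoint_files(prefix):
    """Return the checkpoint parts that are missing for `prefix` (hypothetical helper)."""
    missing = []
    # import_meta_graph reads the graph definition from <prefix>.meta
    if not os.path.isfile(prefix + ".meta"):
        missing.append(prefix + ".meta")
    # restore() reads variable values from <prefix>.index and <prefix>.data-*
    # (assumption: the checkpoint uses the .index / .data-* layout)
    if not os.path.isfile(prefix + ".index"):
        missing.append(prefix + ".index")
    if not glob.glob(glob.escape(prefix) + ".data-*"):
        missing.append(prefix + ".data-*")
    return missing


if __name__ == "__main__":
    prefix = "models/model.ckpt-49491"  # same prefix as in the answer above
    problems = missing_checkpoint_files(prefix)
    print("ready to convert" if not problems else "missing: " + ", ".join(problems))
```

An empty list means the prefix points at a complete checkpoint; otherwise the list names the files to look for.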
    
  • 2020-12-30 14:01

    Set INPUT_TYPE to image_tensor and PIPELINE_CONFIG_PATH to your config file, then run this command:

    python object_detection/export_inference_graph.py \
    --input_type=${INPUT_TYPE} \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --trained_checkpoint_prefix=${TRAINED_CKPT_PREFIX} \
    --output_directory=${EXPORT_DIR}
    

    This writes your model to the export directory in three formats:

    • frozen_inference_graph.pb (frozen graph)
    • saved_model/saved_model.pb (SavedModel)
    • model.ckpt.* (checkpoint)

    For more information, see https://github.com/tensorflow/models/tree/master/research/object_detection
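
If you would rather drive the export from a Python script than from shell variables, the same command can be assembled as an argument list. This is only a sketch: the helper name `build_export_command` and the example paths are hypothetical placeholders, and the four flags are the ones shown in the command above.

```python
import shlex


def build_export_command(input_type, pipeline_config_path,
                         trained_checkpoint_prefix, output_directory):
    """Return the export_inference_graph.py call as an argument list (hypothetical helper)."""
    return [
        "python", "object_detection/export_inference_graph.py",
        "--input_type=" + input_type,
        "--pipeline_config_path=" + pipeline_config_path,
        "--trained_checkpoint_prefix=" + trained_checkpoint_prefix,
        "--output_directory=" + output_directory,
    ]


# Placeholder paths for illustration only; substitute your own.
cmd = build_export_command("image_tensor", "path/to/pipeline.config",
                           "path/to/model.ckpt", "exported_model")
print(" ".join(shlex.quote(arg) for arg in cmd))
```

The list can be passed to `subprocess.run(cmd, check=True)` from the directory that contains `object_detection/`.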

  • 2020-12-30 14:08

    Following the answer of @Puneith Kaul, here is the syntax for TensorFlow version 1.7:

    import os
    import tensorflow as tf
    
    export_dir = 'export_dir'
    trained_checkpoint_prefix = 'models/model.ckpt'

    # Import the graph structure into the default graph, then restore the weights
    loader = tf.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    with tf.Session() as sess:
        loader.restore(sess, trained_checkpoint_prefix)

        # Write the restored graph and variables out as a SavedModel
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.TRAINING, tf.saved_model.tag_constants.SERVING],
            strip_default_attrs=True)
        builder.save()
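
After `builder.save()` returns, a quick way to confirm the export worked is to check for the files a SavedModel directory normally contains: a `saved_model.pb` file and a `variables` subdirectory. Below is a minimal standard-library sketch; the helper name `looks_like_saved_model` is hypothetical, and it assumes the default binary format rather than a text-format (`.pbtxt`) export.

```python
import os


def looks_like_saved_model(export_dir):
    """Check for the files a SavedModel export writes (hypothetical helper)."""
    # Serialized graph and tags (assumption: default binary .pb format)
    has_graph = os.path.isfile(os.path.join(export_dir, "saved_model.pb"))
    # Restored variable values are stored under variables/
    has_variables = os.path.isdir(os.path.join(export_dir, "variables"))
    return has_graph and has_variables


if __name__ == "__main__":
    print(looks_like_saved_model("export_dir"))  # same export_dir as in the answer above
```

This only checks the directory layout; it does not load the model or validate its contents.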
    