How to get Graph (or GraphDef) from a given Model?

Submitted by 旧街凉风 on 2021-01-24 14:52:46

Question


I have a big model defined with TensorFlow 2 and Keras. The model works well in Python. Now I want to import it into a C++ project.

In my C++ project I use the TF_GraphImportGraphDef function. It works well if I prepare the *.pb file with the following code:

    import tensorflow as tf
    # Serialize the default (TF1-style) graph as a binary GraphDef
    with open('load_model.pb', 'wb') as f:
        f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())

I've tried this code on a simple network written with TensorFlow 1 (using tf.compat.v1.* functions), and it works well.

Now I want to export my big model (the one mentioned at the beginning, written with TensorFlow 2) to the C++ project. To do this, I need to get a Graph or GraphDef object from my model. The question is: how do I do this? I couldn't find any property or function that returns it.

I've also tried using tf.saved_model.save(model, 'model') to save the whole model. It generates a directory containing several files, including a saved_model.pb file. Unfortunately, when I try to load this file in C++ with the TF_GraphImportGraphDef function, the program throws an exception.
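One way to see why TF_GraphImportGraphDef rejects this file: a serialized GraphDef normally starts with its repeated `node` field (field 1, a length-delimited message), whereas the saved_model.pb written by tf.saved_model.save is a SavedModel message whose field 1 is the integer `saved_model_schema_version`. Below is a minimal heuristic sketch using only the Python standard library. It assumes those field numbers from TensorFlow's graph.proto and saved_model.proto, and the helper names `top_level_fields` and `guess_pb_kind` are hypothetical, not part of TensorFlow.

```python
# Heuristic sketch: distinguish a serialized GraphDef from a serialized SavedModel
# by scanning top-level protobuf fields. Assumes (from the .proto definitions):
#   GraphDef.node                         = field 1, repeated message -> wire type 2
#   SavedModel.saved_model_schema_version = field 1, int64            -> wire type 0


def _read_varint(buf, pos):
    """Decode a base-128 varint starting at pos; return (value, new_pos)."""
    value, shift = 0, 0
    while True:
        if pos >= len(buf):
            raise ValueError("truncated varint")
        byte = buf[pos]
        pos += 1
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return value, pos
        shift += 7


def top_level_fields(buf):
    """Yield (field_number, wire_type) for each top-level field, skipping payloads."""
    pos = 0
    while pos < len(buf):
        key, pos = _read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x7
        if wire_type == 0:  # varint
            _, pos = _read_varint(buf, pos)
        elif wire_type == 1:  # fixed 64-bit
            pos += 8
        elif wire_type == 2:  # length-delimited (strings, sub-messages)
            length, pos = _read_varint(buf, pos)
            pos += length
        elif wire_type == 5:  # fixed 32-bit
            pos += 4
        else:
            raise ValueError(f"unsupported wire type {wire_type}")
        if pos > len(buf):
            raise ValueError("truncated field")
        yield field_number, wire_type


def guess_pb_kind(buf):
    """Return 'GraphDef', 'SavedModel', or 'unknown' based on field 1's wire type."""
    for field_number, wire_type in top_level_fields(buf):
        if field_number == 1 and wire_type == 2:
            return "GraphDef"
        if field_number == 1 and wire_type == 0:
            return "SavedModel"
    return "unknown"
```

Called as `guess_pb_kind(open('model/saved_model.pb', 'rb').read())`, it should report `'SavedModel'` for the file written by tf.saved_model.save(model, 'model'), and `'GraphDef'` for load_model.pb. It only looks at the top-level wire format, so, for example, a GraphDef with no nodes is reported as `'unknown'`.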


Answer 1:


The protocol buffers file generated by tf.saved_model.save does not contain a GraphDef message but a SavedModel message. You could traverse that SavedModel in Python to get the graph(s) embedded in it, but those would not immediately work as a frozen graph, so getting it right would probably be difficult. Instead, the C++ API now includes a LoadSavedModel call that lets you load a whole saved model from a directory. It should look something like this:

#include <iostream>
#include <...>  // Add necessary TF include directives

using namespace std;
using namespace tensorflow;

int main()
{
    // Path to saved model directory
    const string export_dir = "...";
    // Load model
    Status s;
    SavedModelBundle bundle;
    SessionOptions session_options;
    RunOptions run_options;
    s = LoadSavedModel(session_options, run_options, export_dir,
                       // default "serve" tag set by tf.saved_model.save
                       {"serve"}, &bundle);
    if (!s.ok())
    {
        cerr << "Could not load model: " << s.error_message() << endl;
        return -1;
    }
    // Model is loaded
    // ...
    return 0;
}

From here, you could do different things. Maybe you would be most comfortable converting that saved model into a frozen graph, using FreezeSavedModel, which should allow you to do things pretty much as you were doing them before:

GraphDef frozen_graph_def;
std::unordered_set<string> inputs;
std::unordered_set<string> outputs;
s = FreezeSavedModel(bundle, &frozen_graph_def,
                     &inputs, &outputs);
if (!s.ok())
{
    cerr << "Could not freeze model: " << s.error_message() << endl;
    return -1;
}

Otherwise, you can work directly with the saved model object:

// Default "serving_default" signature name set by tf.saved_model.save
const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");
// Get input and output tensor names (different from layer names).
// The map keys are the input and output layer names; "my_input" and
// "my_output" are placeholders for your own names.
const string input_name = signature_def.inputs().at("my_input").name();
const string output_name = signature_def.outputs().at("my_output").name();
// Run model
Tensor input = ...;
std::vector<Tensor> outputs;
s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs);
if (!s.ok())
{
    cerr << "Error running model: " << s.error_message() << endl;
    return -1;
}
// Get result
Tensor& output = outputs[0];
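The input_name and output_name strings read from the signature are tensor names of the form `op_name:output_index`; for a Keras model exported with tf.saved_model.save they often look something like `serving_default_my_input:0`, but check your own signature rather than relying on that. If you freeze the model and go back to the C API as in the question, TF_GraphOperationByName expects only the operation name, and the output index goes into the `index` member of TF_Output. A minimal Python sketch of that split follows; `split_tensor_name` is a hypothetical helper, not a TensorFlow function.

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into (op_name, index).

    A bare name without a ':<digits>' suffix refers to output 0.
    Control-input names starting with '^' are not handled in this sketch.
    """
    op_name, sep, index = tensor_name.rpartition(":")
    if sep and index.isdigit():
        return op_name, int(index)
    return tensor_name, 0
```

For example, `split_tensor_name("StatefulPartitionedCall:1")` gives `("StatefulPartitionedCall", 1)`: the first element is what you would look up with TF_GraphOperationByName, the second is the output index.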



Answer 2:


I found the following solution to the question:

import tensorflow as tf

g = tf.Graph()
with g.as_default():

    # Create model
    inputs = tf.keras.Input(...) 
    x = tf.keras.layers.Conv2D(1, (1,1), padding='same')(inputs)
    # Done creating model

    # Optionally get graph operations
    ops = g.get_operations()
    for op in ops:
        print(op.name, op.type)

    # Save graph
    tf.io.write_graph(g.as_graph_def(), 'path', 'filename.pb', as_text=False)



Source: https://stackoverflow.com/questions/63181951/how-to-get-graph-or-graphdef-from-a-given-model
