TensorFlow 0.12 Model Files

自闭症患者 2021-01-20 01:18

I train a model and save it using:

saver = tf.train.Saver()
saver.save(session, './my_model_name')

Besides the checkpoint file,

  •  猫巷女王i
    2021-01-20 01:54

    What your saver creates is called "Checkpoint V2" and was introduced in TF 0.12.
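For reference, a Checkpoint V2 save with a prefix such as models/my-model typically leaves the files listed below, assuming the default single data shard and the default write_meta_graph=True. expected_checkpoint_files is a hypothetical helper that only builds the names; it does not touch the disk or TensorFlow.

```python
import os


def expected_checkpoint_files(prefix, num_shards=1):
    """Return the file names a Checkpoint V2 save is expected to produce.

    Hypothetical helper: `prefix` is the path given to Saver.save(),
    e.g. 'models/my-model'. Assumed layout:
      <prefix>.index                -- maps each variable to its shard, offset and size
      <prefix>.data-XXXXX-of-YYYYY  -- the variable values themselves
      <prefix>.meta                 -- the MetaGraphDef (graph plus SaverDef)
      <dir>/checkpoint              -- text file recording the latest prefix
    """
    files = [prefix + ".index"]
    for shard in range(num_shards):
        files.append("%s.data-%05d-of-%05d" % (prefix, shard, num_shards))
    files.append(prefix + ".meta")
    # The `checkpoint` state file is shared by every checkpoint in the directory.
    files.append(os.path.join(os.path.dirname(prefix), "checkpoint"))
    return files


for name in expected_checkpoint_files("models/my-model"):
    print(name)
```

Only the prefix (models/my-model) is ever passed to the save and restore calls; the suffixes are added by TensorFlow.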

    I got it working quite nicely (though the docs on the C++ part are horrible, so it took me a day to figure out). Some people suggest converting all variables to constants or freezing the graph, but neither is actually needed.

    Python part (saving)

    import tensorflow as tf

    with tf.Session() as sess:
        # ... build the graph, then initialize and train the variables ...
        tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')
    

    If you create the Saver with tf.trainable_variables(), you can save yourself some headache and storage space. Some more complicated models need all of their data saved, though (batch-normalization moving averages, for example, are not trainable variables); in that case, drop this argument to Saver. Either way, make sure you create the Saver after your graph has been built. It is also wise to give all variables/layers unique names; otherwise you can run into various problems.

    C++ part (inference)

    Note that checkpointPath isn't a path to any of the existing files, just their common prefix. If you mistakenly put the path to the .index file there, TF won't tell you it is wrong, but inference will fail later due to uninitialized variables.
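Because the restore op expects the prefix, code that may be handed one of the actual files can strip the suffix before passing the path on. A minimal sketch, assuming the .index / .meta / .data-NNNNN-of-NNNNN naming; checkpoint_prefix is a hypothetical helper, not part of TensorFlow.

```python
import re

# Suffixes assumed to be appended to the common prefix by a Checkpoint V2 save.
_CKPT_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(path):
    """Strip a Checkpoint V2 file suffix and return the common prefix.

    Hypothetical helper: 'models/my-model.index' becomes 'models/my-model',
    the form the restore op expects. Paths that are already prefixes are
    returned unchanged.
    """
    return _CKPT_SUFFIX.sub("", path)


print(checkpoint_prefix("models/my-model.index"))                # models/my-model
print(checkpoint_prefix("models/my-model.data-00000-of-00001"))  # models/my-model
print(checkpoint_prefix("models/my-model"))                      # models/my-model
```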

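    #include <stdexcept>
    #include <string>
    #include <utility>
    #include <vector>

    #include <tensorflow/core/platform/env.h>
    #include <tensorflow/core/protobuf/meta_graph.pb.h>
    #include <tensorflow/core/public/session.h>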
    
    using namespace std;
    using namespace tensorflow;
    
    ...
    // set up your input paths
    const string pathToGraph = "models/my-model.meta";
    const string checkpointPath = "models/my-model";
    ...
    
    auto session = NewSession(SessionOptions());
    if (session == nullptr) {
        throw runtime_error("Could not create Tensorflow session.");
    }
    
    Status status;
    
    // Read in the protobuf graph we exported
    MetaGraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
    if (!status.ok()) {
        throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
    }
    
    // Add the graph to the session
    status = session->Create(graph_def.graph_def());
    if (!status.ok()) {
        throw runtime_error("Error creating graph: " + status.ToString());
    }
    
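    // Read weights from the saved checkpoint: feed the checkpoint prefix into the
    // saver's filename tensor, then run its restore op (both names come from the SaverDef)
    Tensor checkpointPathTensor(DT_STRING, TensorShape());
    checkpointPathTensor.scalar<string>()() = checkpointPath;  // recent TF 2.x releases use scalar<tstring>()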
    status = session->Run(
            {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
            {},
            {graph_def.saver_def().restore_op_name()},
            nullptr);
    if (!status.ok()) {
        throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
    }
    
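    // and run the inference to your liking
    std::vector<std::pair<string, Tensor>> feedDict = ...;  // (input tensor name, value) pairs
    std::vector<string> outputOps = ...;                    // names of the tensors to fetch
    std::vector<Tensor> outputTensors;
    status = session->Run(feedDict, outputOps, {}, &outputTensors);
    if (!status.ok()) {
        throw runtime_error("Error running inference: " + status.ToString());
    }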
    

    For completeness, here's the Python equivalent:

    Inference in Python

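    import tensorflow as tf

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('models/my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('models/'))
        # outputOps / feedDict: the tensors to fetch and the input feed, as in the C++ version
        outputTensors = sess.run(outputOps, feed_dict=feedDict)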
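tf.train.latest_checkpoint('models/') does not scan for data files: it reads the small text file named checkpoint in that directory and returns the prefix recorded in its model_checkpoint_path field (TensorFlow also checks that the files exist). The stdlib-only sketch below approximates that lookup, assuming the usual one-field-per-line text format; it skips the existence check and proto escape sequences, and latest_checkpoint_prefix is a hypothetical name, not a TensorFlow API. The value it returns is the same prefix the C++ code needs as checkpointPath.

```python
import os
import re
import tempfile


def latest_checkpoint_prefix(checkpoint_dir):
    """Rough stand-in for tf.train.latest_checkpoint (hypothetical, not TF code).

    Reads `<checkpoint_dir>/checkpoint` and returns the prefix named by its
    `model_checkpoint_path` field, resolving relative paths against the
    directory. Returns None if the file or the field is missing.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"\s*$', line)
            if match:
                path = match.group(1)
                if not os.path.isabs(path):
                    path = os.path.join(checkpoint_dir, path)
                return path
    return None


# Usage: fake a `checkpoint` file like the one Saver.save() writes.
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "my-model"\n'
                'all_model_checkpoint_paths: "my-model"\n')
    print(latest_checkpoint_prefix(d) == os.path.join(d, "my-model"))  # True
```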
    
