How can I convert a trained Tensorflow model to Keras?

Anonymous (unverified), submitted 2019-12-03 02:56:01

Question:

I have a trained TensorFlow model and a weights vector, which have been exported to a protobuf file and a weights file respectively.

How can I convert these to JSON or YAML and HDF5 files which can be used by Keras?

I have the code for the Tensorflow model, so it would also be acceptable to convert the tf.Session to a keras model and save that in code.

Answer 1:

Currently, neither TensorFlow nor Keras has built-in support for converting a frozen model or a checkpoint file to HDF5 format.

But since you mentioned that you have the code of the TensorFlow model, you will have to rewrite that model's code in Keras. Then read the values of your variables from the checkpoint file and assign them to the corresponding Keras layers using the layer.set_weights(weights) method, where weights is a list of NumPy arrays in the order the layer's get_weights() returns them.
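The rewrite-and-copy approach can be sketched as follows, assuming TensorFlow 2.x with tf.keras. The helpers write_tf1_checkpoint, build_keras_model and copy_checkpoint_into, and the variable names "dense/kernel" and "dense/bias", are hypothetical: the first helper only creates a small name-based (TF1-style) checkpoint so the sketch runs on its own. For a real checkpoint, list its variable names with tf.train.list_variables(ckpt_path) and build the name map from that.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def write_tf1_checkpoint(prefix):
    """Hypothetical stand-in for the real trained model: writes a TF1-style
    (name-based) checkpoint containing dense/kernel and dense/bias."""
    graph = tf.Graph()
    with graph.as_default():
        with tf.compat.v1.variable_scope("dense"):
            tf.compat.v1.get_variable(
                "kernel",
                initializer=tf.constant(np.arange(6, dtype=np.float32).reshape(3, 2)))
            tf.compat.v1.get_variable(
                "bias", initializer=tf.constant([0.5, -0.5], dtype=tf.float32))
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            return saver.save(sess, prefix)  # returns the checkpoint prefix


def build_keras_model():
    # Keras re-implementation of the (hypothetical) TF graph: one Dense layer.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(3,)),
        tf.keras.layers.Dense(2, name="dense"),
    ])


def copy_checkpoint_into(model, ckpt_path, name_map):
    """Assign checkpoint tensors to Keras layers.

    name_map (an assumption you must fill in): {keras_layer_name: [checkpoint
    variable names, in the order layer.get_weights() returns them]}.
    """
    reader = tf.train.load_checkpoint(ckpt_path)
    for layer_name, var_names in name_map.items():
        layer = model.get_layer(layer_name)
        layer.set_weights([reader.get_tensor(name) for name in var_names])
```

After copying, model.save("converted.h5") writes the whole model to a single HDF5 file that Keras can load.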

Rather than this approach, I would suggest doing the training directly in Keras, since it has been claimed that Keras' optimizers are 5-10% faster than TensorFlow's. Another option is to write your code in TensorFlow using the tf.contrib.keras module and save the model directly in HDF5 format.
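Training in Keras and saving straight to HDF5 might look like this, assuming TensorFlow 2.x, where tf.keras replaces tf.contrib.keras. The toy data, the function names train_and_save and load_for_inference, and the file name are illustrative only. The ".h5" suffix is what selects the HDF5 format.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def train_and_save(h5_path):
    # Toy regression data (illustrative): y = 2x + 1.
    x = np.linspace(-1.0, 1.0, 32, dtype=np.float32).reshape(-1, 1)
    y = 2.0 * x + 1.0
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(1,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="sgd", loss="mse")
    model.fit(x, y, epochs=2, verbose=0)
    model.save(h5_path)  # ".h5" suffix -> single HDF5 file (architecture + weights)
    return model, x


def load_for_inference(h5_path):
    # compile=False skips restoring the optimizer/loss; enough for prediction.
    return tf.keras.models.load_model(h5_path, compile=False)
```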



Answer 2:

I think a Keras callback is also a solution.

The checkpoint (.ckpt) file can be saved from TensorFlow with:

```python
import tensorflow as tf  # TF 1.x API

# Inside your TF 1.x session, after training:
saver = tf.train.Saver()
saver.save(sess, checkpoint_name)
```

To load the checkpoint in Keras, you need a callback class like the following:

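```python
import keras
import tensorflow as tf  # TF 1.x backend with standalone Keras


class MyCallbacks(keras.callbacks.Callback):
    def __init__(self, pretrained_file):
        super(MyCallbacks, self).__init__()
        self.pretrained_file = pretrained_file
        self.sess = keras.backend.get_session()
        # Create the callback after the model is built, so this Saver
        # picks up the Keras model's variables.
        self.saver = tf.train.Saver()

    def on_train_begin(self, logs=None):
        if self.pretrained_file:
            self.saver.restore(self.sess, self.pretrained_file)
            print('load weights: OK.')
```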

Then, in your Keras script:

```python
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
testCallBack = MyCallbacks(pretrained_file='./XXXX.ckpt')
model.fit(x_train, y_train, batch_size=128, epochs=20, callbacks=[testCallBack])
```

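That should work. Note that Saver.restore matches variables by name, so it only succeeds if the Keras model's variable names match those stored in the checkpoint. It is easy to implement; I hope it helps.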
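On TensorFlow 2.x, tf.compat.v1.train.Saver.restore needs a graph-mode session, so a rough adaptation of the same callback idea could read the checkpoint with tf.train.load_checkpoint and assign values with Layer.set_weights instead. This swaps in a different loading mechanism than the answer's Saver; the class name CheckpointWeightLoader and its name_map argument are hypothetical, and the name map must be filled in from tf.train.list_variables(ckpt_path).

```python
import numpy as np
import tensorflow as tf


class CheckpointWeightLoader(tf.keras.callbacks.Callback):
    """Copy checkpoint values into self.model when training starts.

    name_map (hypothetical): {keras_layer_name: [checkpoint variable names,
    in the order the layer's get_weights() returns them]}.
    """

    def __init__(self, ckpt_path, name_map):
        super().__init__()
        self.ckpt_path = ckpt_path
        self.name_map = name_map

    def on_train_begin(self, logs=None):
        if not self.ckpt_path:
            return
        reader = tf.train.load_checkpoint(self.ckpt_path)
        for layer_name, var_names in self.name_map.items():
            # Read each tensor by name and assign it to the matching layer.
            self.model.get_layer(layer_name).set_weights(
                [reader.get_tensor(name) for name in var_names])
        print("load weights: OK.")
```

Pass it as callbacks=[CheckpointWeightLoader(ckpt_path, name_map)] to model.fit; because the values are assigned by explicit name mapping, the Keras variable names no longer have to match the checkpoint.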



Answer 3:

I'm not sure this is what you are looking for, but I just did the same thing with the newly released Keras support in TF 1.2. You can find more about the API here: https://www.tensorflow.org/api_docs/python/tf/contrib/keras

To save you a little time: I also found that I had to import the Keras modules as shown below, with an additional python.keras inserted into the module path shown in the API docs.

```python
from tensorflow.contrib.keras.python.keras.models import Sequential
```

Hope that helps you get where you want to go. Essentially, once your model is built with this module, you just handle your model/weight export as usual.
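"Export as usual" can be sketched for TensorFlow 2.x, where tf.contrib is gone and the same API lives under tensorflow.keras. This writes the architecture as JSON and the weights as HDF5, which is the file pair the question asks for. The function names and file names are hypothetical; the ".weights.h5" suffix is required by Keras 3 and also selects HDF5 in older tf.keras. YAML is not shown because Model.to_yaml is disabled in recent TF 2.x releases.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras  # TF 2.x home of what TF 1.2 shipped as tf.contrib.keras


def export_json_and_weights(model, out_dir):
    """Write the architecture to JSON and the weights to an HDF5 file."""
    json_path = os.path.join(out_dir, "model.json")
    weights_path = os.path.join(out_dir, "model.weights.h5")
    with open(json_path, "w") as f:
        f.write(model.to_json())
    model.save_weights(weights_path)
    return json_path, weights_path


def import_json_and_weights(json_path, weights_path):
    """Rebuild the model from JSON, then load the HDF5 weights into it."""
    with open(json_path) as f:
        model = keras.models.model_from_json(f.read())
    model.load_weights(weights_path)
    return model
```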


