Question:
I have a trained TensorFlow model and a weights vector, which have been exported to protobuf and weights files respectively.
How can I convert these to JSON or YAML and HDF5 files that can be used by Keras?
I have the code for the TensorFlow model, so it would also be acceptable to convert the tf.Session to a Keras model in code and save that.
Answer 1:
Currently, there is no direct built-in support in TensorFlow or Keras for converting a frozen model or a checkpoint file to HDF5 format.
But since you mention that you have the code for the TensorFlow model, you can rewrite that model in Keras. Then read the values of your variables from the checkpoint file and assign them to the Keras model, layer by layer, using the layer.set_weights(weights) method.
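The bookkeeping part of that step can be sketched in plain Python. This is only an illustration under assumptions: the helper names group_by_layer and assign_weights are hypothetical, and the checkpoint is assumed to name its variables "<layer_name>/kernel" and "<layer_name>/bias", which will not hold for every model. A Keras layer's set_weights expects a list of arrays in a fixed order (kernel before bias for a dense layer), so the sketch groups checkpoint variable names by layer in that order and then assigns them.

```python
from collections import defaultdict

# Assumed naming scheme: "<layer_name>/kernel" and "<layer_name>/bias".
# Real checkpoints may use other suffixes (for example "weights"/"biases").
WEIGHT_ORDER = ("kernel", "bias")


def group_by_layer(variable_names, order=WEIGHT_ORDER):
    """Map layer name -> checkpoint variable names, in set_weights() order."""
    grouped = defaultdict(dict)
    for name in variable_names:
        layer, _, kind = name.rpartition("/")
        # Skip anything that is not a recognised weight, such as
        # "global_step" or optimizer slots like "dense_1/kernel/Adam".
        if layer and kind in order:
            grouped[layer][kind] = name
    return {
        layer: [kinds[k] for k in order if k in kinds]
        for layer, kinds in grouped.items()
    }


def assign_weights(model_layers, grouped, load_variable):
    """Call set_weights() on every layer whose name appears in `grouped`.

    `load_variable` is any callable that maps a checkpoint variable name
    to its value. Returns the names of the layers that were assigned.
    """
    assigned = []
    for layer in model_layers:
        names = grouped.get(layer.name)
        if names:
            layer.set_weights([load_variable(n) for n in names])
            assigned.append(layer.name)
    return assigned
```

With TensorFlow available, tf.train.list_variables(checkpoint_path) should give the (name, shape) pairs to feed into group_by_layer, and a small wrapper around tf.train.load_variable(checkpoint_path, name) could serve as load_variable, with model.layers passed as model_layers. Layers whose names or shapes differ between the two models would still need a manual mapping.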
Rather than converting, I would suggest doing the training directly in Keras, as it has been claimed that Keras' optimizers are 5-10% faster than TensorFlow's optimizers. Another option is to write your code in TensorFlow with the tf.contrib.keras module and save the model directly in HDF5 format.
Answer 2:
I think a Keras callback is also a solution.
The ckpt file can be saved by TF with:
    saver = tf.train.Saver()
    saver.save(sess, checkpoint_name)
and to load the checkpoint in Keras, you need a callback class as follows (the Saver restores variables by name, so the variable names in the Keras model must match those in the checkpoint):

    import keras
    import tensorflow as tf

    class MyCallbacks(keras.callbacks.Callback):
        def __init__(self, pretrained_model_path):
            super(MyCallbacks, self).__init__()
            self.pretrained_model_path = pretrained_model_path
            self.sess = keras.backend.get_session()
            self.saver = tf.train.Saver()

        def on_train_begin(self, logs=None):
            # Restore the TF checkpoint into the session Keras is using.
            if self.pretrained_model_path:
                self.saver.restore(self.sess, self.pretrained_model_path)
                print('load weights: OK.')

Then in your Keras script:

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    testCallBack = MyCallbacks(pretrained_model_path='./XXXX.ckpt')
    model.fit(x_train, y_train, batch_size=128, epochs=20, callbacks=[testCallBack])
That will be fine. I think it is easy to implement and hope it helps.
Answer 3:
I'm not sure if this is what you are looking for, but I happened to just do the same with the newly released Keras support in TF 1.2. You can find more on the API here: https://www.tensorflow.org/api_docs/python/tf/contrib/keras
To save you a little time, I also found that I had to include keras modules as shown below with the additional python.keras appended to what is shown in the API docs.
from tensorflow.contrib.keras.python.keras.models import Sequential
Hope that helps get you where you want to go. Essentially, once Keras is integrated this way, you just handle your model and weight export as usual.