How can I convert a trained TensorFlow model to Keras?

甜味超标 2020-12-13 06:43

I have a trained TensorFlow model and a weights vector, which have been exported to protobuf and weights files, respectively.

How can I convert these to JSON or YAML and …

4 Answers
  •  醉话见心
    2020-12-13 07:28

    Francois Chollet, the creator of Keras, stated in April 2017: "You cannot turn an arbitrary TensorFlow checkpoint into a Keras model. What you can do, however, is build an equivalent Keras model then load into this Keras model the weights." See https://github.com/keras-team/keras/issues/5273. To my knowledge, this hasn't changed.

    A small example:

    First, extract the weights from a TensorFlow checkpoint like this:

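    import tensorflow as tf

    # Uses the TF 1.x graph/session API through tf.compat.v1 so that it also
    # runs under TensorFlow 2.x, where eager execution must be disabled first.
    tf1 = tf.compat.v1
    tf1.disable_eager_execution()

    PATH_REL_META = r'checkpoint1.meta'

    # start a session; inside the with-block it is the default session
    with tf1.Session() as sess:

        # import the graph structure from the .meta file
        saver = tf1.train.import_meta_graph(PATH_REL_META)

        # restore the weights; the checkpoint prefix is the path without '.meta'
        saver.restore(sess, PATH_REL_META[:-5])

        # get all global variables (model variables plus e.g. optimizer slots)
        vars_global = tf1.global_variables()

        # map each variable name (e.g. 'conv7x7x1_1/kernel:0') to its NumPy value
        model_vars = {}
        for var in vars_global:
            try:
                model_vars[var.name] = sess.run(var)
            except Exception as e:
                print("For var={}, an exception occurred: {}".format(var.name, e))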
    
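Because tf.global_variables() also returns optimizer and bookkeeping variables (for example the slot variables of Adam and global_step), it can help to filter model_vars before mapping it onto Keras layers. A minimal sketch: model_weights_only and the marker list are hypothetical, and the markers assume the default variable naming of the TF 1.x optimizers; adjust them to the optimizer your model was trained with.

```python
# Hypothetical helper, not part of TensorFlow or Keras. The marker strings are
# assumptions based on how TF 1.x optimizers name their extra variables.
NON_MODEL_MARKERS = ("/Adam", "/Momentum", "/RMSProp", "global_step",
                     "beta1_power", "beta2_power")


def model_weights_only(model_vars, markers=NON_MODEL_MARKERS):
    """Drop optimizer/bookkeeping variables and strip the ':0' output suffix."""
    cleaned = {}
    for name, value in model_vars.items():
        if any(marker in name for marker in markers):
            continue  # e.g. 'conv7x7x1_1/kernel/Adam:0' or 'global_step:0'
        # 'conv7x7x1_1/kernel:0' -> 'conv7x7x1_1/kernel'
        cleaned[name.split(":")[0]] = value
    return cleaned
```

For example, weights = model_weights_only(model_vars) yields keys such as 'conv7x7x1_1/kernel', so a lookup like the one in the third step would then omit the ':0' suffix.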

    It may also be useful to export the TensorFlow model to TensorBoard, e.g. to inspect the graph and find the variable names; see https://stackoverflow.com/a/43569991/2135504.
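If you only need the variable names and values, a lighter alternative to importing the graph (a swapped-in technique, not the one used above) is to read the checkpoint directly with tf.train.list_variables and tf.train.load_checkpoint. A minimal sketch, assuming TensorFlow 2.x; the two function names are hypothetical. Note that the keys are the stored variable names without the ':0' suffix, e.g. 'conv7x7x1_1/kernel' for a checkpoint written by a TF 1.x Saver.

```python
import tensorflow as tf


def print_checkpoint_variables(ckpt_prefix):
    """Print the name and shape of every tensor stored in a checkpoint."""
    for name, shape in tf.train.list_variables(ckpt_prefix):
        print(name, shape)


def read_checkpoint_weights(ckpt_prefix):
    """Return {variable_name: NumPy value} without importing the graph."""
    reader = tf.train.load_checkpoint(ckpt_prefix)
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map()}
```

With the files from the question, you would pass the checkpoint prefix, i.e. the .meta path without its extension: print_checkpoint_variables('checkpoint1') and weights = read_checkpoint_weights('checkpoint1').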

    Second, build your Keras model as usual and finalize it with "model.compile". Note that you need to define each layer as a named variable first and only then add it to the model, so that you can refer to it later when setting the weights, e.g.

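    from tensorflow import keras

    net = keras.Sequential()
    # keep a Python reference to each layer so its weights can be set later
    layer_1 = keras.layers.Conv2D(6, (7, 7), activation='relu', input_shape=(48, 48, 1))
    net.add(layer_1)
    ...
    net.compile(...)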
    
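Instead of keeping a Python reference to every layer, you can also give each layer an explicit name= and retrieve it later with net.get_layer(). A minimal sketch, assuming TensorFlow 2.x; reusing the checkpoint scope 'conv7x7x1_1' from the set_weights example as the layer name is an illustrative choice. It also prints the shapes that set_weights will expect, in order: kernel first, then bias.

```python
from tensorflow import keras

# Name the layer after its TensorFlow variable scope (an assumption about the
# checkpoint) so it can be looked up by name later.
net = keras.Sequential([
    keras.Input(shape=(48, 48, 1)),
    keras.layers.Conv2D(6, (7, 7), activation='relu', name='conv7x7x1_1'),
])

layer = net.get_layer('conv7x7x1_1')
# get_weights() lists the arrays in the order set_weights() expects them.
for w in layer.get_weights():
    print(w.shape)  # (7, 7, 1, 6), then (6,)
```

For a channels_last Conv2D the kernel layout is (height, width, in_channels, filters), the same layout TensorFlow uses for tf.nn.conv2d filters, so such kernels can usually be copied without transposing.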

    Third, set each layer's weights from the TensorFlow values, e.g.

    # Conv2D expects its weights as [kernel, bias], in that order
    layer_1.set_weights([model_vars['conv7x7x1_1/kernel:0'], model_vars['conv7x7x1_1/bias:0']])
    
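Setting every layer by hand gets tedious for larger networks. A minimal sketch of automating the third step, assuming each Keras layer was created with name= set to its (flat) TensorFlow variable scope, e.g. name='conv7x7x1_1', and that those layers store only '<scope>/kernel' and an optional '<scope>/bias'. group_by_layer and load_into_keras are hypothetical helpers; scopes without a kernel (optimizer slots, global_step, batch-norm statistics) are skipped rather than handled.

```python
from collections import defaultdict


def group_by_layer(weights):
    """Group {'<scope>/kernel[:0]': k, '<scope>/bias[:0]': b} into {'<scope>': [k, b]}.

    Only scopes that contain a 'kernel' are kept, so names such as
    '<scope>/kernel/Adam:0' or 'global_step:0' are ignored.
    """
    grouped = defaultdict(dict)
    for name, value in weights.items():
        # 'conv7x7x1_1/kernel:0' -> scope 'conv7x7x1_1', kind 'kernel'
        scope, _, kind = name.split(":")[0].rpartition("/")
        grouped[scope][kind] = value
    # Keras Dense/Conv2D layers expect [kernel, bias]; bias is absent if use_bias=False.
    return {scope: [parts[k] for k in ("kernel", "bias") if k in parts]
            for scope, parts in grouped.items() if "kernel" in parts}


def load_into_keras(model, weights):
    """Copy each grouped weight list into the Keras layer of the same name."""
    for scope, layer_weights in group_by_layer(weights).items():
        model.get_layer(scope).set_weights(layer_weights)
```

With the dictionary from the first step, the whole third step would become load_into_keras(net, model_vars); set_weights still raises a ValueError if a shape does not match.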
