Reset all weights of Keras model

Submitted by 元气小坏坏 on 2021-01-01 04:16:30

Question


I would like to be able to reset the weights of my entire Keras model so that I do not have to compile it again. Compiling the model is currently the main bottleneck of my code. Here is an example of what I mean:

import tensorflow as tf  

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

data = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = data.load_data()

model.fit(x=x_train, y=y_train, epochs=10)

# Reset all weights of model here
# model.reset_all_weights() <----- something like that

model.fit(x=x_train, y=y_train, epochs=10)

Answer 1:


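You can re-initialize the layers in place with this loop, which draws new values from each layer's own kernel and bias initializers, so the model does not need to be compiled again: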

for layer in model.layers:
    # Only layers that expose both initializers (e.g. Dense) are reset.
    if hasattr(layer, 'kernel_initializer') and \
            hasattr(layer, 'bias_initializer'):
        # Assumes the layer holds exactly two arrays: kernel and bias.
        old_weights, old_biases = layer.get_weights()

        # Draw fresh values of the same shapes from the layer's initializers.
        layer.set_weights([
            layer.kernel_initializer(shape=old_weights.shape),
            layer.bias_initializer(shape=old_biases.shape)])

Original weights (first row of the first Dense layer's kernel):

model.layers[1].get_weights()[0][0]
array([ 0.4450057 , -0.13564804,  0.35884023,  0.41411972,  0.24866664,
        0.07641453,  0.45726687, -0.04410008,  0.33194816, -0.1965386 ,
       -0.38438258, -0.13263905, -0.23807487,  0.40130925, -0.07339832,
        0.20535922], dtype=float32)

New weights (same row after running the loop):

model.layers[1].get_weights()[0][0]
array([-0.4607593 , -0.13104361, -0.0372932 , -0.34242013,  0.12066692,
       -0.39146423,  0.3247317 ,  0.2635846 , -0.10496247, -0.40134245,
        0.19276887,  0.2652442 , -0.18802321, -0.18488845,  0.0826562 ,
       -0.23322225], dtype=float32)



Source: https://stackoverflow.com/questions/63435679/reset-all-weights-of-keras-model
