How to load only specific weights on Keras

Submitted by 南楼画角 on 2019-12-08 22:55:45

Question


I have a trained model whose weights I've exported, and I want to load some of them into another model. My model is built in Keras with TensorFlow as the backend.

Right now I'm doing the following:

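from keras.models import Sequential
from keras.layers import Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout

# input_shape must be defined beforehand, e.g. (height, width, channels)
model = Sequential()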
model.add(Conv2D(32, (3, 3), input_shape=input_shape, trainable=False))
model.add(Activation('relu', trainable=False))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3), trainable=False))
model.add(Activation('relu', trainable=False))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3), trainable=True))
model.add(Activation('relu', trainable=True))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])


model.load_weights("image_500.h5")
model.pop()
model.pop()
model.pop()
model.pop()
model.pop()
model.pop()


model.add(Conv2D(1, (6, 6),strides=(1, 1), trainable=True))
model.add(Activation('relu', trainable=True))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

I'm sure it's a terrible way to do it, although it works.

How do I load the weights of just the first 9 layers?
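For reference, the six pop() calls in the code above remove Flatten, Dense(64), its relu Activation, Dropout, Dense(1) and the final sigmoid Activation, so the "first 9 layers" are the three Conv2D/Activation/MaxPooling2D blocks. A quick check of that count, assuming TensorFlow 2.x's tf.keras and a placeholder input shape of (150, 150, 3), since the question does not give input_shape:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Rebuild the question's architecture; (150, 150, 3) is a placeholder shape.
model = keras.Sequential([keras.Input(shape=(150, 150, 3))])
for filters in (32, 32, 64):
    model.add(layers.Conv2D(filters, (3, 3)))
    model.add(layers.Activation("relu"))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
for layer in (layers.Flatten(), layers.Dense(64), layers.Activation("relu"),
              layers.Dropout(0.5), layers.Dense(1), layers.Activation("sigmoid")):
    model.add(layer)

# Same six pops as in the question: they strip the whole dense head.
for _ in range(6):
    model.pop()

print(len(model.layers))  # 9
print([type(layer).__name__ for layer in model.layers])
```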


Answer 1:


If your first 9 layers are consistently named between your original trained model and the new model, then you can use model.load_weights() with by_name=True. This will update weights only in the layers of your new model that have an identically named layer found in the original trained model.

The name of a layer is set with the name keyword argument, for example:

model.add(Dense(8, activation='relu', name='dens_1'))
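load_weights(..., by_name=True) reads from an HDF5 weight file such as image_500.h5. To keep the following sketch self-contained, it applies the same name-matching rule in memory with get_layer(), get_weights() and set_weights() instead of going through a file. It assumes TensorFlow 2.x's tf.keras; the helper copy_weights_by_name() and all layer names are hypothetical, and this illustrates the rule rather than what load_weights does internally.

```python
from tensorflow import keras
from tensorflow.keras import layers

INPUT_SHAPE = (32, 32, 3)  # assumption: any small image shape works here

# Stand-in for the original trained model; all layer names are hypothetical.
old_model = keras.Sequential([
    keras.Input(shape=INPUT_SHAPE),
    layers.Conv2D(32, (3, 3), activation="relu", name="conv_1"),
    layers.MaxPooling2D((2, 2), name="pool_1"),
    layers.Flatten(name="flatten_1"),
    layers.Dense(1, activation="sigmoid", name="head_dense"),
])

# New model: layers to reuse keep their names, the new head gets new names.
new_model = keras.Sequential([
    keras.Input(shape=INPUT_SHAPE),
    layers.Conv2D(32, (3, 3), activation="relu", name="conv_1"),
    layers.MaxPooling2D((2, 2), name="pool_1"),
    layers.Conv2D(1, (6, 6), activation="relu", name="new_conv"),
])


def copy_weights_by_name(src, dst):
    """Copy weights into each layer of dst whose name also exists in src.

    Hypothetical helper mimicking the name-matching rule of
    load_weights(..., by_name=True); layers without weights are skipped.
    Returns the names of the layers that received weights.
    """
    src_names = {layer.name for layer in src.layers}
    copied = []
    for layer in dst.layers:
        if layer.name in src_names and layer.weights:
            layer.set_weights(src.get_layer(layer.name).get_weights())
            copied.append(layer.name)
    return copied


copied = copy_weights_by_name(old_model, new_model)
print(copied)  # ['conv_1']
```

Only conv_1 receives weights: pool_1 matches by name but has nothing to copy, and new_conv has no counterpart in the old model, so it keeps its fresh initialisation.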



Answer 2:


This call:

weights_list = model.get_weights()

returns a single flat list of all the weight arrays in the model, as NumPy arrays.
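The difference between this flat list and the per-layer weights can be seen on a small model, assuming TensorFlow 2.x's tf.keras and an arbitrary 32x32x3 input (neither comes from the answer):

```python
from tensorflow import keras
from tensorflow.keras import layers

# Small illustrative model; the input shape is an arbitrary assumption.
model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    layers.Conv2D(32, (3, 3)),    # kernel (3, 3, 3, 32) + bias (32,)
    layers.Activation("relu"),    # no weights
    layers.MaxPooling2D((2, 2)),  # no weights
    layers.Flatten(),             # no weights
    layers.Dense(1),              # kernel (7200, 1) + bias (1,)
])

weights_list = model.get_weights()  # one flat list for the whole model
print(len(model.layers), len(weights_list))  # 5 4

# Per-layer view: this is what set_weights() on a single layer expects.
per_layer_counts = [len(layer.get_weights()) for layer in model.layers]
print(per_layer_counts)  # [2, 0, 0, 0, 2]
```

Five layers produce only four arrays, so weights_list[i] and model.layers[i] refer to different things for most values of i.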

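This list is not grouped by layer, so its indices do not line up with model.layers: each Conv2D or Dense layer contributes two arrays (kernel and bias), while Activation, MaxPooling2D, Flatten and Dropout layers contribute none. To reload the first 9 layers, copy the weights layer by layer instead, passing each layer's own list of arrays to set_weights():

for i in range(9):
    # each layer receives exactly the list of arrays it expects
    model.layers[i].set_weights(trained_model.layers[i].get_weights())

where model.layers is a flattened list of the layers comprising the model, and trained_model is the original model with the weights from image_500.h5 loaded (for example, the original architecture followed by trained_model.load_weights("image_500.h5")). The first 9 layers of both models must have matching weight shapes.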
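A runnable version of this layer-by-layer copy, assuming TensorFlow 2.x's tf.keras, a placeholder input shape of (150, 150, 3), and a freshly initialised model standing in for the trained one (in practice you would rebuild the original architecture and call load_weights("image_500.h5") on it). The helper conv_trunk() and the names INPUT_SHAPE and trained_model are illustrative, not part of the original answer.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

INPUT_SHAPE = (150, 150, 3)  # assumption: the question never states input_shape


def conv_trunk():
    """Fresh copies of the 9 layers (3 x Conv2D/Activation/MaxPooling2D) both models share."""
    trunk = []
    for filters in (32, 32, 64):
        trunk += [layers.Conv2D(filters, (3, 3)),
                  layers.Activation("relu"),
                  layers.MaxPooling2D(pool_size=(2, 2))]
    return trunk


# Stand-in for the trained model (same architecture as in the question).
trained_model = keras.Sequential(
    [keras.Input(shape=INPUT_SHAPE)] + conv_trunk() + [
        layers.Flatten(), layers.Dense(64), layers.Activation("relu"),
        layers.Dropout(0.5), layers.Dense(1), layers.Activation("sigmoid"),
    ])

# New model: the same 9-layer trunk with the question's new Conv2D head.
new_model = keras.Sequential(
    [keras.Input(shape=INPUT_SHAPE)] + conv_trunk() + [
        layers.Conv2D(1, (6, 6), strides=(1, 1)), layers.Activation("relu"),
    ])

# Copy layer by layer; weightless layers simply receive an empty list.
for src, dst in zip(trained_model.layers[:9], new_model.layers[:9]):
    dst.set_weights(src.get_weights())

# Freeze the first two conv blocks, mirroring trainable=False in the question.
for layer in new_model.layers[:6]:
    layer.trainable = False

# Compile after changing trainable so the change takes effect in training.
new_model.compile(loss="binary_crossentropy", optimizer="rmsprop", metrics=["accuracy"])

out = new_model(np.zeros((1,) + INPUT_SHAPE, dtype="float32"))
print(tuple(out.shape))  # (1, 12, 12, 1)
```

Copying per layer avoids both the index mismatch of the flat list and the need for matching layer names, but it relies on both models having the same layer order for the first 9 positions.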

More information is available here:

https://keras.io/layers/about-keras-layers/

https://keras.io/models/about-keras-models/



Source: https://stackoverflow.com/questions/43702323/how-to-load-only-specific-weights-on-keras
