What does “model.trainable = False” mean in Keras?

烈酒焚心 submitted on 2020-05-29 05:23:55

Question


I want to freeze a pre-trained network in Keras. I found base.trainable = False in the documentation. But I didn't understand how it works. With len(model.trainable_weights) I found out that I have 30 trainable weights. How can that be? The network shows total trainable params: 16,812,353. After freezing I have 4 trainable weights. Maybe I don't understand the difference between params and weights. Unfortunately I am a beginner in Deep Learning. Maybe someone can help me.


Answer 1:


A Keras Model is trainable by default. There are two ways to freeze all of its weights:

  1. model.trainable = False, set before compiling the model
  2. for layer in model.layers: layer.trainable = False, which works both before and after compiling

(1) must be done before compilation because Keras reads model.trainable as a boolean flag at compile time and then performs (2) under the hood. After doing either of the above, you should see:

print(model.trainable_weights)
# [] 
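
On the params-versus-weights part of the question: each entry in model.trainable_weights is a whole weight tensor (a kernel or a bias), while the "trainable params" figure in the summary counts every scalar inside those tensors. A minimal sketch, assuming TensorFlow's bundled Keras and a hypothetical toy network in place of the real pre-trained one:

```python
from tensorflow import keras

# Hypothetical toy network, used only to make the counts easy to verify by hand.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])

# Each Dense layer owns two weight tensors: a kernel and a bias.
n_tensors = len(model.trainable_weights)        # 2 layers * 2 tensors = 4

# count_params() sums the scalar entries of all weight tensors.
n_params = model.count_params()                 # (4*8 + 8) + (8*1 + 1) = 49

# Freeze every layer (option 2 from the answer).
for layer in model.layers:
    layer.trainable = False

n_tensors_frozen = len(model.trainable_weights)     # 0
n_non_trainable = len(model.non_trainable_weights)  # 4

print(n_tensors, n_params, n_tensors_frozen, n_non_trainable)
```

By the same logic, 30 trainable weights holding 16,812,353 params simply means 30 tensors with that many scalar entries in total.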

As for the docs, they are likely outdated; the linked source code above is up to date.
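
The question reports 4 trainable weights remaining after freezing. One plausible explanation, and it is only an assumption about the asker's setup, is that only the pre-trained base was frozen while a small head with two weighted layers stayed trainable. The sketch below reproduces that pattern; `base` and the head layers are hypothetical stand-ins, not the asker's actual network:

```python
from tensorflow import keras

# Hypothetical stand-in for the pre-trained base network.
base = keras.Sequential([
    keras.Input(shape=(16,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(32, activation="relu"),
], name="base")

# Freeze the base before compiling the full model.
base.trainable = False

# Hypothetical classification head stacked on the frozen base.
inputs = keras.Input(shape=(16,))
x = base(inputs)
x = keras.layers.Dense(8, activation="relu")(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy")

# Only the head's tensors remain trainable: 2 Dense layers * (kernel + bias).
head_tensors = len(model.trainable_weights)        # 4
# The base's tensors are now reported as non-trainable.
frozen_tensors = len(model.non_trainable_weights)  # 4

print(head_tensors, frozen_tensors)
```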



Source: https://stackoverflow.com/questions/58224816/what-does-model-trainable-false-mean-in-keras
