What is the definition of a non-trainable parameter?

Backend · Unresolved · 4 · 1197
北荒 2020-12-13 18:51

What is the definition of non-trainable parameter in a model?

For example, when you build your own model, the count is 0 by default, but

4 answers
  •  余生分开走
    2020-12-13 19:34

    There are some details that other answers do not cover.

    In Keras, non-trainable parameters are the ones that are not updated by gradient descent during training. This is controlled by the trainable argument of each layer, for example:

    from keras.layers import Dense, BatchNormalization
    from keras.models import Sequential

    model = Sequential()
    model.add(Dense(10, trainable=False, input_shape=(100,)))
    model.summary()
    

    This prints zero trainable parameters, and 1010 non-trainable parameters.

    _________________________________________________________________    
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_1 (Dense)              (None, 10)                1010      
    =================================================================
    Total params: 1,010
    Trainable params: 0
    Non-trainable params: 1,010
    _________________________________________________________________
    
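    As a quick sanity check (not part of the original answer), the 1,010 figure can be reproduced by hand: a Dense layer holds one weight per input-output pair plus one bias per output unit.

    ```python
    # A Dense layer maps input_dim inputs to `units` outputs:
    # it stores an (input_dim x units) weight matrix plus one bias per unit.
    def dense_param_count(input_dim, units):
        return input_dim * units + units

    # The layer above: Dense(10, input_shape=(100,))
    print(dense_param_count(100, 10))  # 100*10 + 10 = 1010
    ```

    Freezing the layer with trainable=False does not change this total; it only moves all 1,010 parameters into the non-trainable bucket.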

    Now if you set the layer as trainable with model.layers[0].trainable = True (and recompile the model so the change takes effect), then it prints:

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_1 (Dense)              (None, 10)                1010      
    =================================================================
    Total params: 1,010
    Trainable params: 1,010
    Non-trainable params: 0
    _________________________________________________________________
    

    Now all parameters are trainable and there are zero non-trainable parameters. But some layers have both trainable and non-trainable parameters; one example is the BatchNormalization layer, which stores the mean and standard deviation of its activations for use at test time. For example:

    model.add(BatchNormalization())
    model.summary()
    
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_1 (Dense)              (None, 10)                1010      
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 10)                40        
    =================================================================
    Total params: 1,050
    Trainable params: 1,030
    Non-trainable params: 20
    _________________________________________________________________
    

    This specific case of BatchNormalization has 40 parameters in total: 20 trainable (the scale and shift, gamma and beta) and 20 non-trainable. The 20 non-trainable parameters are the moving statistics of the activations used at test time; they are updated with a running average during training, are never trained by gradient descent, and are not affected by the trainable flag.
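    The 40-parameter figure can also be checked by hand. BatchNormalization keeps four vectors, each of length equal to the number of units; note that Keras actually stores a moving variance internally rather than a standard deviation. A minimal sketch:

    ```python
    # BatchNormalization keeps four vectors, each of length `units`:
    # gamma (scale) and beta (shift) are trainable;
    # moving_mean and moving_variance are non-trainable running statistics.
    def bn_param_counts(units):
        trainable = 2 * units      # gamma, beta
        non_trainable = 2 * units  # moving_mean, moving_variance
        return trainable, non_trainable

    print(bn_param_counts(10))  # (20, 20) -> 40 parameters in total
    ```

    For the Dense(10) layer above this gives exactly the 20/20 split shown in the summary.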
