What is the difference between register_parameter and register_buffer in PyTorch?

感动是毒 2020-12-30 06:12

A Module's parameters get changed during training; that is, they are what is learnt during the training of a neural network. But what is a buffer, and is it learnt during training as well?

2 Answers
  •  旧巷少年郎
    2020-12-30 06:54

    The PyTorch doc for the register_buffer() method reads:

    This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm's running_mean is not a parameter, but is part of the persistent state.

    As you already observed, model parameters are learned and updated using SGD during the training process.
    However, sometimes there are other quantities that are part of a model's "state" and should be
    - saved as part of the state_dict,
    - moved to cuda() or cpu() with the rest of the model's parameters, and
    - cast to float/half/double with the rest of the model's parameters.
    Registering these quantities as buffers lets PyTorch track them and save them like regular parameters, but prevents PyTorch from updating them via the SGD mechanism.
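    The difference can be seen in a minimal sketch (the module name `MyModule` and the buffer name `running_sum` are made up for illustration): a registered parameter shows up in `parameters()` and therefore receives gradients, while a registered buffer does not, yet both end up in the `state_dict`.

    ```python
    import torch
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # learnable: appears in parameters(), so optimizers will update it
            self.register_parameter("weight", nn.Parameter(torch.randn(3)))
            # part of the state but not learnable: invisible to optimizers
            self.register_buffer("running_sum", torch.zeros(3))

    m = MyModule()
    print([name for name, _ in m.named_parameters()])  # ['weight']
    print([name for name, _ in m.named_buffers()])     # ['running_sum']
    print(sorted(m.state_dict().keys()))               # ['running_sum', 'weight']
    ```

    Passing `m.parameters()` to an optimizer therefore only ever touches `weight`; `running_sum` still follows the module through `m.cuda()`, `m.half()` and `torch.save(m.state_dict(), ...)`.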

    An example of a buffer can be found in the _BatchNorm module, where running_mean, running_var and num_batches_tracked are registered as buffers and updated by accumulating statistics of the data forwarded through the layer. This is in contrast to the weight and bias parameters, which learn an affine transformation of the data using regular SGD optimization.
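    You can verify this split directly on a stock BatchNorm layer: in training mode, a forward pass alone changes the running statistics, with no optimizer involved (a small sketch using nn.BatchNorm1d):

    ```python
    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(4)
    # weight and bias are the learnable parameters of the affine transform
    print(sorted(name for name, _ in bn.named_parameters()))
    # ['bias', 'weight']
    # the running statistics are buffers, updated inside forward()
    print(sorted(name for name, _ in bn.named_buffers()))
    # ['num_batches_tracked', 'running_mean', 'running_var']

    bn.train()
    before = bn.running_mean.clone()
    bn(torch.randn(8, 4))  # forward pass in training mode
    # the buffer changed without any optimizer step
    assert not torch.equal(before, bn.running_mean)
    ```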
