Understanding torch.nn.Parameter

广开言路 2020-12-04 09:59

I am new to PyTorch and have difficulty understanding how torch.nn.Parameter() works.

I have gone through the documentation at https://pytorch.org/

2 Answers
  •  盖世英雄少女心
    2020-12-04 10:19

    Recent PyTorch releases have only Tensors; the concept of a Variable has been deprecated (Variable and Tensor were merged in PyTorch 0.4).
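    A minimal sketch of what "just Tensors" means in practice: a plain tensor created with requires_grad=True tracks gradients by itself, with no Variable wrapper. The values are arbitrary.

```python
import torch

# A plain tensor records operations for autograd; no Variable wrapper is needed.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()  # y = x1^2 + x2^2 + x3^2
y.backward()        # fills x.grad with dy/dx = 2 * x
print(x.grad)       # tensor([2., 4., 6.])
```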

    Parameters are just Tensors that are registered to the module they are defined in (typically by assigning them as attributes in the module's constructor, the __init__ method).

    They will appear in module.parameters(). This comes in handy when you build custom modules that learn these parameters through gradient descent.
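    A minimal sketch, assuming a hypothetical module named Scale: a tensor wrapped in nn.Parameter and assigned as an attribute in __init__ shows up in parameters() and is updated by an optimizer, while a plain tensor attribute is ignored.

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Hypothetical module that learns one weight per feature."""

    def __init__(self, num_features: int):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it with the module.
        self.weight = nn.Parameter(torch.ones(num_features))
        # A plain tensor attribute is NOT registered as a parameter.
        self.not_a_param = torch.zeros(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight


model = Scale(3)
names = [name for name, _ in model.named_parameters()]
print(names)  # ['weight']

# Only registered parameters are handed to the optimizer for gradient descent.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(3)).sum()  # d(loss)/d(weight) = 1 for every entry
loss.backward()
optimizer.step()                   # each weight becomes 1 - 0.1 * 1 = 0.9
print(model.weight)
```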

    Anything that is true for PyTorch tensors is true for parameters, since they are tensors: nn.Parameter is a subclass of torch.Tensor.
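    A short check, with arbitrary shapes, that a parameter is an ordinary tensor which requires gradients by default:

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.randn(2, 2))
print(isinstance(p, torch.Tensor))  # True: nn.Parameter subclasses torch.Tensor
print(p.requires_grad)              # True: parameters require gradients by default

# Ordinary tensor operations work on a parameter as on any tensor.
q = p @ p.T + 1
print(q.shape)                      # torch.Size([2, 2])

# Gradient tracking can be switched off at construction time.
frozen = nn.Parameter(torch.zeros(3), requires_grad=False)
print(frozen.requires_grad)         # False
```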

    Additionally, if the module is moved to the GPU, its parameters move as well, and if the module is saved, its parameters are saved too.
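    A minimal sketch using the built-in nn.Linear: its parameters appear in state_dict(), follow the module when it is moved or cast with .to(), and survive a save/load round trip. The in-memory io.BytesIO stream stands in for a file and is only an illustrative choice; the code falls back to the CPU when no GPU is available.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# state_dict() collects the parameters, so saving it saves them.
print(list(model.state_dict().keys()))  # ['weight', 'bias']

# Moving or casting the module moves or casts its parameters as well.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device, dtype=torch.float64)
print(model.weight.device, model.weight.dtype)

# Save to an in-memory stream and load into a fresh module of the same shape.
stream = io.BytesIO()
torch.save(model.state_dict(), stream)
stream.seek(0)
restored = nn.Linear(4, 2).to(device=device, dtype=torch.float64)
restored.load_state_dict(torch.load(stream))
print(torch.equal(restored.weight, model.weight))  # True
```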

    There is a concept similar to model parameters called buffers.

    Buffers are tensors registered under a name inside the module (with register_buffer()), but they are not meant to be learned via gradient descent; instead, you can think of them as state variables. You can update a module's buffers inside its forward() as you like.
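    A minimal sketch, assuming a hypothetical RunningMean module (loosely modeled on the running statistics that batch normalization keeps): the buffer is updated inside forward() under torch.no_grad() and never appears in parameters(). The momentum value is arbitrary.

```python
import torch
import torch.nn as nn


class RunningMean(nn.Module):
    """Hypothetical module that tracks a running mean of its inputs in a buffer."""

    def __init__(self, num_features: int, momentum: float = 0.1):
        super().__init__()
        self.momentum = momentum
        # register_buffer stores a named tensor that is not a learnable parameter.
        self.register_buffer("running_mean", torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Update the buffer in place without recording the ops in autograd.
        with torch.no_grad():
            batch_mean = x.mean(dim=0)
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
        return x - self.running_mean


m = RunningMean(2)
m(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))  # batch mean is [2., 3.]
print(m.running_mean)        # tensor([0.2000, 0.3000])
print(list(m.parameters()))  # []: buffers are not parameters
```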

    Buffers, too, move to the GPU with the module and are saved together with it (unless registered with persistent=False, which keeps them out of the state_dict).
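    A minimal sketch with a hypothetical WithBuffers module: both buffers follow the module through .to(), but only the persistent one is saved in state_dict(). Note that casting a module to a floating-point dtype affects only its floating-point tensors, so the integer buffer keeps its dtype.

```python
import torch
import torch.nn as nn


class WithBuffers(nn.Module):
    """Hypothetical module showing persistent and non-persistent buffers."""

    def __init__(self):
        super().__init__()
        # Persistent (default): saved in state_dict().
        self.register_buffer("step_count", torch.zeros((), dtype=torch.long))
        # persistent=False: still moved with the module, but not saved.
        self.register_buffer("scratch", torch.zeros(3), persistent=False)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = WithBuffers().to(device=device, dtype=torch.float64)
print(list(m.state_dict().keys()))          # ['step_count']
print(m.step_count.dtype, m.scratch.dtype)  # torch.int64 torch.float64
print(m.scratch.device)
```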
