How to initialize weights in PyTorch?

暗喜 2020-11-28 01:10

How to initialize the weights and biases (for example, with He or Xavier initialization) in a network in PyTorch?

9 answers
  •  悲&欢浪女
    2020-11-28 01:30

    Single layer

    To initialize the weights of a single layer, use a function from torch.nn.init. For instance:

    import torch

    conv1 = torch.nn.Conv2d(...)
    torch.nn.init.xavier_uniform_(conv1.weight)  # in-place; xavier_uniform without "_" is deprecated
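
The question also asks about He initialization, which torch.nn.init provides as kaiming_normal_ and kaiming_uniform_. A minimal sketch, assuming a concrete Conv2d with 3 input channels, 16 output channels and a 3x3 kernel (sizes chosen only for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical layer sizes, chosen only for illustration.
conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

# He (Kaiming) normal initialization, the usual choice before a ReLU.
# mode="fan_in" preserves the variance of activations in the forward pass.
nn.init.kaiming_normal_(conv1.weight, mode="fan_in", nonlinearity="relu")

# Set the biases to zero.
nn.init.zeros_(conv1.bias)
```

With mode="fan_in" and nonlinearity="relu", the weights are drawn from a normal distribution with standard deviation sqrt(2 / fan_in), where fan_in = 3 * 3 * 3 = 27 for this layer.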
    

    Alternatively, you can modify the parameters in place by writing to conv1.weight.data (a torch.Tensor whose changes autograd does not track). Example:

    conv1.weight.data.fill_(0.01)
    

    The same applies for biases:

    conv1.bias.data.fill_(0.01)
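
Because changes made through .data are not tracked by autograd, a commonly used alternative is to do the same in-place update inside torch.no_grad(). A sketch, again assuming hypothetical layer sizes:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(3, 16, kernel_size=3)  # hypothetical sizes

# Same effect as conv1.weight.data.fill_(0.01) and conv1.bias.data.fill_(0.01),
# but the in-place update goes through the parameters themselves.
with torch.no_grad():
    conv1.weight.fill_(0.01)
    conv1.bias.fill_(0.01)

# The parameters are still leaves that require gradients, so training is unaffected.
print(conv1.weight.requires_grad)  # True
```

torch.nn.init.constant_(conv1.weight, 0.01) is the built-in equivalent of the fill_ call above.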
    

    nn.Sequential or custom nn.Module

    Pass an initialization function to torch.nn.Module.apply. apply calls it recursively on every submodule (and on the module itself), so the function can initialize the weights of the entire nn.Module.

    apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self. Typical use includes initializing the parameters of a model (see also torch.nn.init).
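
A small sketch of the order in which apply visits modules, using a hypothetical nested model: each child's subtree is handled first, and fn is called on self last.

```python
import torch.nn as nn

# Hypothetical nested model: an inner Sequential inside an outer one.
net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sequential(nn.ReLU(), nn.Linear(2, 2)),
)

visited = []
# apply ignores fn's return value, so a lambda that records the type is enough.
net.apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential', 'Sequential']
```

The inner Sequential is visited after its own children, and the outer Sequential comes last, which is why an init function passed to apply should simply ignore module types it does not handle.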

    Example:

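    import torch
    import torch.nn as nn

    # Called once for every module in net, including net itself.
    def init_weights(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
    
    net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    net.apply(init_weights)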
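
The init_weights function above only handles nn.Linear and assumes every layer has a bias. A sketch of one way to extend it, assuming you want He initialization for convolutions and Xavier for linear layers (an illustrative choice, not the only one), and guarding against layers built with bias=False, whose bias attribute is None. The model and its sizes are hypothetical:

```python
import torch
import torch.nn as nn

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        # He initialization, typical for layers followed by ReLU
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        # Xavier initialization
        nn.init.xavier_uniform_(m.weight)
    else:
        # ReLU, Flatten and the Sequential container have no weights to set
        return
    # Layers created with bias=False have m.bias set to None
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)

# Hypothetical model: 28x28 single-channel input, 3x3 conv without padding -> 26x26
net = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, bias=False),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 26 * 26, 10),
)
net.apply(init_weights)

out = net(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```

Modules that match neither branch, such as nn.ReLU, nn.Flatten and the nn.Sequential container itself, are left untouched.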
    
