How to initialize weights in PyTorch?

暗喜 2020-11-28 01:10

How to initialize the weights and biases (for example, with He or Xavier initialization) in a network in PyTorch?

9 answers
  •  小蘑菇
    小蘑菇 (OP)
    2020-11-28 01:23

    If you want some extra flexibility, you can also set the weights manually.

    Say you have an input of all ones:

    import torch
    import torch.nn as nn
    
    input = torch.ones((8, 8))
    print(input)
    
    tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.]])
    

    Now create a dense (fully connected) layer with no bias, so the effect of the weights is easy to see:

    d = nn.Linear(8, 8, bias=False)
    

    Set all the weights to 0.5 (or anything else):

    # Overwrite the weights in place; no_grad keeps autograd from tracking the edit
    with torch.no_grad():
        d.weight.fill_(0.5)
    print(d.weight.data)
    

    The weights:

    Out[14]: 
    tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])
    

    All your weights are now 0.5. Pass the data through:

    d(input)
    
    Out[13]: 
    tensor([[4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.],
            [4., 4., 4., 4., 4., 4., 4., 4.]], grad_fn=<MmBackward>)
    

    Each output neuron receives 8 inputs, each with value 1 and weight 0.5, and there is no bias, so every output is 8 × 1 × 0.5 = 4.
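
The arithmetic above can be checked directly. The following is a minimal sketch, assuming a recent PyTorch; the names `x`, `out`, and `manual` are chosen here for illustration and are not from the original answer. It compares the layer output against an explicit matrix product with the transposed weight matrix.

```python
import torch
import torch.nn as nn

x = torch.ones((8, 8))                 # same all-ones input as in the answer
d = nn.Linear(8, 8, bias=False)

with torch.no_grad():                  # in-place edit, not tracked by autograd
    d.weight.fill_(0.5)

out = d(x)

# Without a bias, nn.Linear computes x @ W^T, so each output
# is the sum of 8 products of 1 * 0.5.
manual = x @ d.weight.T

print(torch.allclose(out, manual))                    # True
print(torch.allclose(out, torch.full((8, 8), 4.0)))   # True
```
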

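For the He and Xavier schemes mentioned in the question, `torch.nn.init` provides ready-made initializers, so the weights do not have to be filled by hand. The sketch below is one possible way to use them, not the only one: the helper names `init_xavier` and `init_he` and the two small example networks are hypothetical, made up for illustration. `Module.apply` calls the helper on every submodule.

```python
import torch
import torch.nn as nn

def init_xavier(m):
    # Hypothetical helper: Xavier (Glorot) uniform weights, zero biases.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

def init_he(m):
    # Hypothetical helper: He (Kaiming) normal weights, zero biases.
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Example networks (layer sizes are arbitrary assumptions).
xavier_net = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 4))
xavier_net.apply(init_xavier)

he_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
he_net.apply(init_he)
```

Xavier initialization is usually paired with tanh or sigmoid activations and He initialization with ReLU, which is why the two example networks use different activations.
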