How to initialize weights in PyTorch?

暗喜 2020-11-28 01:10

How to initialize the weights and biases (for example, with He or Xavier initialization) in a network in PyTorch?

9 answers
  •  隐瞒了意图╮
    2020-11-28 01:44

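        import torch.nn as nn

        # example sizes (placeholder values chosen for illustration;
        # the original snippet leaves them undefined)
        in_features, h_size = 10, 32

        # a simple network
        rand_net = nn.Sequential(nn.Linear(in_features, h_size),
                                 nn.BatchNorm1d(h_size),
                                 nn.ReLU(),
                                 nn.Linear(h_size, h_size),
                                 nn.BatchNorm1d(h_size),
                                 nn.ReLU(),
                                 nn.Linear(h_size, 1),
                                 nn.ReLU())

        # initialization function: first checks the module type,
        # then applies the desired changes to the weights
        def init_uniform(m):
            # isinstance() also matches subclasses of nn.Linear
            if isinstance(m, nn.Linear):
                # fills the weight tensor in place with samples from U(0, 1)
                nn.init.uniform_(m.weight)

        # use the module's apply() method to recursively apply the
        # initialization to every submodule (and to rand_net itself)
        rand_net.apply(init_uniform)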
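The question asks specifically for He or Xavier initialization. The same `apply()` pattern works with PyTorch's built-in `nn.init.xavier_uniform_` and `nn.init.kaiming_normal_` in place of `nn.init.uniform_`. The sketch below also zeroes the biases, which the snippet above leaves at their defaults. It assumes the same network shape with placeholder sizes (`in_features=10`, `h_size=32`); the helpers `make_net`, `init_xavier` and `init_he` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn

# Placeholder sizes (assumed for illustration).
in_features, h_size = 10, 32


def make_net() -> nn.Sequential:
    """Hypothetical helper that rebuilds the network from the answer."""
    return nn.Sequential(nn.Linear(in_features, h_size),
                         nn.BatchNorm1d(h_size),
                         nn.ReLU(),
                         nn.Linear(h_size, h_size),
                         nn.BatchNorm1d(h_size),
                         nn.ReLU(),
                         nn.Linear(h_size, 1),
                         nn.ReLU())


def init_xavier(m: nn.Module) -> None:
    """Xavier/Glorot uniform weights and zero biases for every nn.Linear."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def init_he(m: nn.Module) -> None:
    """He/Kaiming normal weights (gain for ReLU, fan_in mode) and zero biases."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)


xavier_net = make_net()
xavier_net.apply(init_xavier)

he_net = make_net()
he_net.apply(init_he)

# Sanity check: Xavier-uniform samples lie in [-a, a] with
# a = sqrt(6 / (fan_in + fan_out)); a small tolerance covers float32 rounding.
bound = (6.0 / (in_features + h_size)) ** 0.5
print(xavier_net[0].weight.abs().max().item() <= bound + 1e-6)
```

Xavier initialization is commonly paired with tanh or sigmoid activations and He initialization with ReLU, which is why `nonlinearity="relu"` is passed to `kaiming_normal_`. The `BatchNorm1d` layers are skipped by the type check and keep their default weight of 1 and bias of 0.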
    
