Adding L1/L2 regularization in PyTorch?

南旧 2020-12-24 01:20

Is there any way I can add simple L1/L2 regularization in PyTorch? We can probably compute the regularized loss by simply adding the data_loss to the regularization loss, but is there an explicit way to do it?
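For reference, a minimal sketch of both approaches (PyTorch optimizers implement L2 regularization via the `weight_decay` argument; an L1 term can be added to the loss by hand — the model, data, and coefficient values below are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)                      # toy model for illustration
x, y = torch.randn(32, 10), torch.randn(32, 1)

# L2 regularization: built into the optimizer via weight_decay
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)

# L1 regularization: add the penalty to the data loss manually
l1_lambda = 1e-4
criterion = nn.MSELoss()

optimizer.zero_grad()
data_loss = criterion(model(x), y)
l1_penalty = sum(p.abs().sum() for p in model.parameters())
loss = data_loss + l1_lambda * l1_penalty     # regularized loss
loss.backward()
optimizer.step()
```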

5 Answers
  •  攒了一身酷
    2020-12-24 01:47

    Interestingly, torch.norm is slower on CPU but faster on GPU than the direct approach.

    import torch
    x = torch.randn(1024,100)
    y = torch.randn(1024,100)
    
    %timeit torch.sqrt((x - y).pow(2).sum(1))
    %timeit torch.norm(x - y, 2, 1)
    

    Out:

    1000 loops, best of 3: 910 µs per loop
    1000 loops, best of 3: 1.76 ms per loop
    

    On the other hand:

    import torch
    x = torch.randn(1024,100).cuda()
    y = torch.randn(1024,100).cuda()
    
    %timeit torch.sqrt((x - y).pow(2).sum(1))
    %timeit torch.norm(x - y, 2, 1)
    

    Out:

    10000 loops, best of 3: 50 µs per loop
    10000 loops, best of 3: 26 µs per loop
    

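    Tying this back to the question, the regularization penalty itself can be written with torch.norm (a sketch; the model and coefficient are illustrative):

    ```python
    import torch
    import torch.nn as nn

    model = nn.Linear(10, 1)      # toy model for illustration
    l2_lambda = 1e-3

    # squared L2 penalty over all parameters, computed with torch.norm
    l2_penalty = sum(torch.norm(p, 2) ** 2 for p in model.parameters())
    reg_loss = l2_lambda * l2_penalty
    ```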