Why do we need to call zero_grad() in PyTorch?

清歌不尽 2020-12-02 05:03

The method zero_grad() needs to be called during training, but the documentation is not very helpful:

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.
2 Answers
小蘑菇 (OP)
    2020-12-02 05:11

    In PyTorch, gradients accumulate by default: every call to backward() adds to the values already stored in each parameter's .grad. Calling zero_grad() clears those buffers, so when you use gradient descent to decrease the error (or loss), each step restarts from a clean slate instead of carrying over the gradients from the last step.

    If you don't call zero_grad(), the stale gradients from previous steps pile up, so the loss will increase rather than decrease as required. A minimal training loop showing where the call belongs is sketched below.
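
    This accumulation is why optimizer.zero_grad() sits at the top of almost every PyTorch training loop. Here is a minimal sketch of that pattern; the tiny linear model, random data, and hyperparameters are made-up placeholders for illustration, not anything from the original post:

    import torch
    import torch.nn as nn

    # Hypothetical toy setup: a small linear model on random data.
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    for epoch in range(4):
        optimizer.zero_grad()   # clear gradients left over from the previous step
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()         # compute fresh gradients for this step only
        optimizer.step()        # update parameters using those gradients
        print(f"model training loss is {loss.item():.2f}")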

    For example, if you call zero_grad() you will see output like the following:

    model training loss is 1.5
    model training loss is 1.4
    model training loss is 1.3
    model training loss is 1.2
    

    If you don't call zero_grad(), you will see output like the following:

    model training loss is 1.4
    model training loss is 1.9
    model training loss is 2
    model training loss is 2.8
    model training loss is 3.5
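
    The accumulation is easy to verify directly on a single tensor. This is a small sketch of my own, not from the original answer, showing that repeated backward() calls add into .grad until it is zeroed:

    import torch

    x = torch.tensor(2.0, requires_grad=True)

    for _ in range(3):
        y = x * x       # dy/dx = 2x = 4.0
        y.backward()
        print(x.grad)   # tensor(4.), tensor(8.), tensor(12.) -- gradients accumulate

    x.grad.zero_()      # roughly what optimizer.zero_grad() does for every parameter
    y = x * x
    y.backward()
    print(x.grad)       # tensor(4.) again, the correct gradient for a single step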
    
