What does model.eval() do in pytorch?

孤街醉人 submitted on 2020-12-24 04:00:07

Question


I am using this code, and saw model.eval() in some cases.

I understand it is supposed to allow me to "evaluate my model", but I don't understand when I should and shouldn't use it, or how to turn it off.

I would like to run the above code to train the network and also run validation every epoch, but I still haven't been able to do it.


Answer 1:


model.eval() is a kind of switch for specific layers/parts of the model that behave differently during training and inference (evaluation), for example Dropout and BatchNorm layers. During evaluation, Dropout has to be turned off and BatchNorm has to use its running statistics instead of the statistics of the current batch; .eval() does both for you. In addition, the common practice for evaluation/validation is to use torch.no_grad() together with model.eval() to turn off gradient computation:

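import torch

# evaluate model: Dropout off, BatchNorm uses its running statistics
model.eval()

# no gradients are tracked inside this block
with torch.no_grad():
    ...
    out_data = model(data)
    ...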
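To see what the switch actually changes, here is a small self-contained sketch. The toy model and the random `data` tensor are hypothetical stand-ins, not the asker's code. It runs the same input twice in each mode: in training mode Dropout draws a new random mask on every forward pass, so the two outputs differ; in eval mode Dropout passes inputs through unchanged and BatchNorm uses its stored running statistics, so the two outputs are identical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model containing both layer types that eval() affects.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(8, 2),
)
data = torch.randn(16, 4)  # hypothetical batch of 16 samples

# Training mode (the default): Dropout samples a fresh random mask on every
# forward pass, and BatchNorm normalizes with the current batch's statistics
# while also updating its running estimates.
model.train()
out_a = model(data)
out_b = model(data)

# Eval mode: Dropout is an identity, and BatchNorm uses the running
# estimates accumulated so far instead of batch statistics.
model.eval()
with torch.no_grad():  # additionally skip building the autograd graph
    out_c = model(data)
    out_d = model(data)

print(torch.equal(out_a, out_b))  # different dropout masks, so outputs differ
print(torch.equal(out_c, out_d))  # True: eval mode is deterministic here
print(out_c.requires_grad)        # False: computed under torch.no_grad()
```

Note that model.eval() by itself does not disable gradients; that is why it is paired with torch.no_grad().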

But don't forget to switch back to training mode after the evaluation step:

# training step
...
model.train()  # restore training behavior: Dropout active, BatchNorm uses batch statistics
...
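Applied to the question's goal (train, then validate every epoch), a minimal sketch could look like the following. The synthetic dataset, the model, and the names `train_loader`, `val_loader` and `val_history` are hypothetical placeholders, not the asker's code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic data standing in for a real dataset.
x = torch.randn(200, 10)
y = (x.sum(dim=1) > 0).long()
train_loader = DataLoader(TensorDataset(x[:160], y[:160]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x[160:], y[160:]), batch_size=32)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Dropout(0.2), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

val_history = []
for epoch in range(5):
    model.train()  # (re)enable training behavior at the start of every epoch
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

    model.eval()  # inference behavior for validation
    correct, total = 0, 0
    with torch.no_grad():  # validation needs no gradients
        for inputs, targets in val_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.numel()
    val_acc = correct / total
    val_history.append(val_acc)
    print(f"epoch {epoch}: val accuracy {val_acc:.3f}")
```

Calling model.train() at the top of each epoch, rather than only after validation, guarantees the model is in training mode for every training pass, including the first one.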



Answer 2:


model.eval() is a method of torch.nn.Module.

The opposite method is model.train(), explained nicely by Umang Gupta.
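Since both are ordinary nn.Module methods, all they do is set a boolean `training` attribute on the module and, recursively, on every submodule; layers such as Dropout and BatchNorm read that flag in their forward pass. PyTorch documents eval() as equivalent to train(False). A short sketch with a hypothetical two-layer model:

```python
import torch.nn as nn

# Hypothetical model; any nn.Module behaves the same way.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

# Newly constructed modules start in training mode.
print(model.training)  # True

returned = model.eval()  # same as model.train(False); returns the module itself
print(returned is model)  # True

# The flag is set recursively: Sequential, Linear and Dropout are all in eval mode.
print([m.training for m in model.modules()])  # [False, False, False]

model.train()  # the opposite switch; mode=True by default
print(all(m.training for m in model.modules()))  # True
```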



Source: https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch
