What does model.eval() do in PyTorch?

轮回少年 2020-12-24 11:55

I am using this code, and saw model.eval() in some cases.

I understand it is supposed to allow me to "evaluate my model", but I don't understand...

2 Answers
  • 2020-12-24 12:28

    model.eval() is a method of torch.nn.Module.

    The opposite method is model.train(), explained nicely by Umang Gupta.
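As a minimal sketch of what this method actually toggles: eval() sets the `training` flag to False on the module and all of its submodules, and train() sets it back to True. The toy model below is a hypothetical example, not the asker's code.

```python
import torch.nn as nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# eval() returns the module itself and flips `training` on every submodule.
returned = model.eval()
flags_after_eval = [m.training for m in model.modules()]

# train() switches everything back to training mode.
model.train()
flags_after_train = [m.training for m in model.modules()]

# train(False) has the same effect as eval().
model.train(False)
flags_after_train_false = [m.training for m in model.modules()]

print(flags_after_eval, flags_after_train, flags_after_train_false)
```

Because eval() returns the module, it can be chained, e.g. `model = model.eval()`.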

  • 2020-12-24 12:36

    model.eval() is a kind of switch for specific layers/parts of the model that behave differently during training and inference (evaluation) time, for example Dropout layers, BatchNorm layers, etc. In evaluation mode Dropout stops zeroing activations and BatchNorm uses its stored running statistics instead of per-batch statistics, and .eval() switches them over for you. In addition, the common practice for evaluation/validation is to use torch.no_grad() together with model.eval() to turn off gradient computation:

    # evaluate model:
    model.eval()
    
    with torch.no_grad():
        ...
        out_data = model(data)
        ...
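To make the difference concrete, here is a small sketch (standalone layers and made-up input tensors, assumed only for illustration) showing how Dropout and BatchNorm behave in each mode, and that eval() by itself does not disable gradient tracking, which is why it is paired with torch.no_grad().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.ones(8, 4)

# Dropout: random zeroing in train mode, identity in eval mode.
drop = nn.Dropout(p=0.5)
drop.train()
train_out = drop(x)   # entries are either 0 or scaled by 1 / (1 - p)
drop.eval()
eval_out = drop(x)    # input passes through unchanged

# BatchNorm: running statistics are updated only in train mode.
bn = nn.BatchNorm1d(4)
bn.train()
_ = bn(torch.randn(8, 4) * 3 + 5)
mean_after_train = bn.running_mean.clone()
bn.eval()
_ = bn(torch.randn(8, 4) * 3 + 5)
mean_after_eval = bn.running_mean.clone()   # not updated in eval mode

# eval() alone still builds the autograd graph; no_grad() turns that off.
lin = nn.Linear(4, 1)
lin.eval()
grad_tracked = lin(x).sum().requires_grad
with torch.no_grad():
    grad_tracked_no_grad = lin(x).sum().requires_grad

print(grad_tracked, grad_tracked_no_grad)
```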
    

    But don't forget to switch back to training mode after the eval step:

    # training step
    ...
    model.train()
    ...
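One way to avoid forgetting the switch back is to wrap evaluation in a context manager. This is only a sketch: `evaluating` is a hypothetical helper name, not part of PyTorch, and the toy model is assumed for illustration.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def evaluating(model):
    """Hypothetical helper: eval mode and no_grad inside the block,
    previous top-level mode restored afterwards."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield model
    finally:
        # Restores the mode even if the evaluation code raises.
        model.train(was_training)


# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
model.train()

with evaluating(model):
    mode_inside = model.training
    out = model(torch.ones(3, 4))

mode_after = model.training
print(mode_inside, mode_after)
```

Note that model.train(was_training) applies one mode to all submodules, so this sketch does not preserve a setup where some submodules were deliberately left in a different mode.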
    