Treat the loss as a standalone layer: subclass nn.Module and write the loss computation in the forward method. There is no need to define backward — autograd derives it automatically.
class MyLoss(nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, pred, truth):
        # mean over the feature dimension (dim 1), then over the batch dimension (dim 0)
        return torch.mean(torch.mean((pred - truth) ** 2, 1), 0)
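To confirm that autograd supplies the backward pass for such a module, here is a minimal self-contained check (the class is redefined for completeness; the shapes and data are illustrative):

```python
import torch
import torch.nn as nn

class MyLoss(nn.Module):
    """Custom MSE-style loss; no backward method is needed."""
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, pred, truth):
        # mean over features, then over the batch
        return torch.mean(torch.mean((pred - truth) ** 2, 1), 0)

criterion = MyLoss()
pred = torch.randn(4, 3, requires_grad=True)
truth = torch.randn(4, 3)
loss = criterion(pred, truth)
loss.backward()  # gradients flow through the custom loss automatically
```

After backward, pred.grad holds the gradient of the loss with respect to pred, exactly as with a built-in loss such as nn.MSELoss.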
If the call to super(MyLoss, self).__init__() is omitted, instantiating the module fails with:
AttributeError: 'MyLoss' object has no attribute '_forward_pre_hooks'
To print the loss value, use loss.item() in current PyTorch. The form loss.data.cpu().numpy()[0] is the pre-0.4 equivalent and no longer works on a 0-dimensional scalar tensor.
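A short sketch of reading out a scalar loss (the tensor here is a stand-in for a real loss value):

```python
import torch

loss = torch.mean(torch.randn(4, 3) ** 2)  # stand-in scalar loss tensor
value = loss.item()                         # version-safe: returns a Python float
print(value)
```

loss.item() works on CPU and GPU tensors alike and detaches the value from the autograd graph, so it is safe to use inside a training loop for logging.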
Source: Pytorch如何自定义Loss (How to define a custom Loss in PyTorch)