写给程序员的机器学习入门 (四)
这篇将会着重介绍使用 pytorch 进行机器学习训练过程中的一些常见技巧,掌握它们可以让你事半功倍。 使用的代码大部分会基于上一篇最后一个例子,即根据码农条件预测工资🙀,如果你没看上一篇请点击 这里 查看。 保存和读取模型状态 在 pytorch 中各种操作都是围绕 tensor 对象来的,模型的参数也是 tensor,如果我们把训练好的 tensor 保存到硬盘然后下次再从硬盘读取就可以直接使用了。 我们先来看看如何保存单个 tensor,以下代码运行在 python 的 REPL 中: # 引用 pytorch >>> import torch # 新建一个 tensor 对象 >>> a = torch.tensor([1, 2, 3], dtype=torch.float) # 保存 tensor 到文件 1.pt >>> torch.save(a, "1.pt") # 从文件 1.pt 读取 tensor >>> b = torch.load("1.pt") >>> b tensor([1., 2., 3.]) torch.save 保存 tensor 的时候会使用 python 的 pickle 格式,这个格式保证在不同的 python 版本间兼容,但不支持压缩内容,所以如果 tensor 非常大保存的文件将会占用很多空间,我们可以在保存前压缩