Unexpected increase in validation error in MNIST PyTorch

Submitted by 蓝咒 on 2019-12-02 05:13:34

Long story short: you need to change item = self.X[idx] to item = self.X[idx].copy().
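To make the fix concrete, here is a minimal sketch of a numpy-backed dataset with the corrected __getitem__. The class name, fields, and transform are illustrative stand-ins for the asker's code, not the original:

```python
import numpy as np
from torch.utils.data import Dataset

class NumpyMNIST(Dataset):
    """Illustrative stand-in for a numpy-backed dataset like the asker's."""
    def __init__(self, X, y, transform=None):
        self.X = X              # e.g. a (N, 28, 28) numpy array
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # .copy() breaks the aliasing: an in-place transform now writes
        # into a throwaway buffer instead of mutating self.X
        item = self.X[idx].copy()
        if self.transform is not None:
            item = self.transform(item)
        return item, self.y[idx]
```

Without the .copy(), any transform that mutates its input would write straight through to self.X.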

Long story long: T.ToTensor() calls torch.from_numpy, which returns a tensor that aliases the memory of your numpy array dataset.X. And T.Normalize() works in place, so each time a sample is drawn, the mean is subtracted and the result divided by std directly in that shared storage, progressively corrupting your dataset.
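You can see the aliasing with torch alone. The sub_/div_ pair below mimics what an in-place Normalize does (the mean and std values are arbitrary, chosen so the effect is visible):

```python
import numpy as np
import torch

X = np.ones((8, 8), dtype=np.float32)  # stand-in for one image in dataset.X
t = torch.from_numpy(X)                # no copy: t and X share memory
t.sub_(0.25).div_(0.5)                 # what an in-place Normalize amounts to
print(X[0, 0])                         # 1.5 -- the numpy array was mutated
```

Drawing the same sample again would normalize the already-normalized values, which is exactly the degradation described above.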

Edit: as for why it works with the original MNIST loader, the rabbit hole goes even deeper. The key step in MNIST's __getitem__ is that the image is converted to a PIL.Image instance. That operation claims to copy only when the buffer is not contiguous (ours is contiguous), but under the hood it checks whether the array is strided instead (which it is), and therefore copies. So, by luck, the default torchvision pipeline involves a copy, and the in-place T.Normalize() does not corrupt the in-memory self.data of our MNIST instance.
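A quick sanity check of that decoupling, using Pillow directly (a sketch, not the loader's exact code; the real MNIST loader calls Image.fromarray on each sample). Whether the copy happens inside fromarray or lazily on first write, writes made through the PIL image never reach the source array:

```python
import numpy as np
from PIL import Image

X = np.zeros((28, 28), dtype=np.uint8)  # stand-in for one image in self.data
img = Image.fromarray(X, mode="L")      # the PIL round-trip from the loader
img.putpixel((0, 0), 255)               # mutate through the PIL image
print(X[0, 0])                          # still 0: the buffer was copied
```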
