How to find built-in function source code in pytorch

妖精的绣舞 submitted on 2020-12-11 07:57:33

Question


I am doing research on batch normalization and need to make some modifications to the PyTorch BN code. I dug into the PyTorch source and got stuck at torch.nn.functional.batch_norm, which calls torch.batch_norm.

The problem is that I cannot find a definition of torch.batch_norm anywhere in the Python part of the torch library. Is there any way to find the source code of this built-in function so that I can re-implement it? Thanks!


Answer 1:


It's there, but it's not defined in Python. It is implemented in C++, under the aten/ directories of the PyTorch repository.
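A quick way to check whether a callable has Python source at all is the standard inspect module. The sketch below uses a hypothetical helper, locate_source, and is demonstrated only on standard-library objects. The expectation that torch.batch_norm lands in the "compiled" branch is an assumption based on it being a built-in function, and is not verified here.

```python
import inspect
import json


def locate_source(obj):
    """Describe where obj's source lives (hypothetical helper, for illustration)."""
    try:
        path = inspect.getsourcefile(obj)
    except TypeError:
        # inspect raises TypeError for built-in (compiled) functions and classes.
        return "compiled: no Python source"
    if path is None:
        return "no source file found"
    return f"Python source: {path}"


print(locate_source(json.dumps))  # a function written in Python
print(locate_source(len))         # a built-in implemented in C
```

If torch.batch_norm is reported as compiled, its definition has to be looked up in the C++ sources rather than in the installed Python package.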

For CPU, one of the implementations (which one is used depends on whether the input is contiguous) is here: https://github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization.cpp#L61-L126

For CUDA, the cuDNN-backed implementation is here: https://github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143
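For experimenting with modifications, it may be easier to start from a pure-Python version than from the C++ kernels. The following is a minimal sketch, not the code from the linked files: batch_norm_reference is a hypothetical name, and it assumes an input of shape (N, C, ...), the biased variance for normalization, and the unbiased variance for the running-statistics update.

```python
import torch
import torch.nn.functional as F


def batch_norm_reference(x, running_mean, running_var, weight=None, bias=None,
                         training=False, momentum=0.1, eps=1e-5):
    """Hypothetical pure-PyTorch batch norm for inputs shaped (N, C, ...)."""
    reduce_dims = [0] + list(range(2, x.dim()))   # every dim except channels
    stat_shape = [1, -1] + [1] * (x.dim() - 2)    # broadcast shape (1, C, 1, ...)

    if training:
        mean = x.mean(dim=reduce_dims)
        var = x.var(dim=reduce_dims, unbiased=False)  # biased, used to normalize
        n = x.numel() // x.size(1)                    # samples per channel (needs n > 1)
        with torch.no_grad():
            # Running stats use the unbiased variance estimate.
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        mean, var = running_mean, running_var

    y = (x - mean.view(stat_shape)) / torch.sqrt(var.view(stat_shape) + eps)
    if weight is not None:
        y = y * weight.view(stat_shape)
    if bias is not None:
        y = y + bias.view(stat_shape)
    return y


# Compare against the built-in on random data.
x = torch.randn(8, 3, 4, 4)
rm_a, rv_a = torch.zeros(3), torch.ones(3)
rm_b, rv_b = torch.zeros(3), torch.ones(3)
out_ref = batch_norm_reference(x, rm_a, rv_a, training=True)
out_builtin = F.batch_norm(x, rm_b, rv_b, training=True)
print(torch.allclose(out_ref, out_builtin, atol=1e-5))
```

Because the sketch uses only differentiable tensor operations, autograd handles the backward pass, so a modified variant can be dropped into a model without writing a custom gradient.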



Source: https://stackoverflow.com/questions/58193798/how-to-find-built-in-function-source-code-in-pytorch
