Question
I am doing research on batch normalization and need to make some modifications to PyTorch's BN code. I dug into the PyTorch source and got stuck at torch.nn.functional.batch_norm, which calls torch.batch_norm.
The problem is that I cannot find the definition of torch.batch_norm anywhere in the torch library. Is there any way to find the source code of this built-in function so that I can re-implement it? Thanks!
Answer 1:
It's there, but it isn't defined in Python. Functions like this are implemented in C++ under the aten/ directory of the PyTorch repository (the files below live in aten/src/ATen/native).
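A quick way to confirm this from Python (a small sketch using only the standard inspect module): torch.nn.functional.batch_norm is an ordinary Python function whose source can be printed, while torch.batch_norm is a builtin bound from C++, so inspect has no Python source to show for it.

```python
import inspect

import torch
import torch.nn.functional as F

# F.batch_norm is a plain Python function, so its source is available.
print(inspect.isfunction(F.batch_norm))
print(inspect.getsource(F.batch_norm)[:200])

# torch.batch_norm is a builtin exposed from C++ bindings.
print(inspect.isbuiltin(torch.batch_norm))

try:
    inspect.getsource(torch.batch_norm)
except TypeError as err:
    # inspect can only locate source for Python-level objects.
    print("no Python source:", err)
```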
For CPU, one of the implementations (which one is used depends on whether or not the input is contiguous) is here: https://github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization.cpp#L61-L126
For CUDA, the cuDNN-based implementation is here: https://github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143
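If the goal is to modify batch norm for research, it is often easier to re-implement the forward computation in Python on top of ordinary tensor ops than to patch the C++. The sketch below is a hypothetical my_batch_norm, not a transcription of the linked C++ code. It assumes an input of shape (N, C, ...) and follows the behavior documented for torch.nn.BatchNorm: it normalizes with the biased batch variance but updates running_var with the unbiased estimate. Autograd derives the backward pass automatically.

```python
import torch
import torch.nn.functional as F


def my_batch_norm(x, running_mean, running_var, weight, bias,
                  training=True, momentum=0.1, eps=1e-5):
    """Hypothetical pure-PyTorch batch norm over dim 1 of an (N, C, ...) input."""
    # Reduce over every dimension except the channel dimension.
    dims = [d for d in range(x.dim()) if d != 1]
    # Shape used to broadcast per-channel vectors back over x.
    shape = [1, -1] + [1] * (x.dim() - 2)

    if training:
        mean = x.mean(dim=dims)
        # The output is normalized with the biased (population) variance...
        var = x.var(dim=dims, unbiased=False)
        # ...but the running estimate stores the unbiased variance.
        n = x.numel() // x.size(1)
        with torch.no_grad():
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        # In eval mode, use the stored running statistics.
        mean, var = running_mean, running_var

    x_hat = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
    return x_hat * weight.view(shape) + bias.view(shape)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 3, 4, 4)
    w, b = torch.rand(3), torch.rand(3)
    rm_ref, rv_ref = torch.zeros(3), torch.ones(3)
    rm_mine, rv_mine = rm_ref.clone(), rv_ref.clone()

    ref = F.batch_norm(x, rm_ref, rv_ref, w, b, training=True)
    mine = my_batch_norm(x, rm_mine, rv_mine, w, b, training=True)
    # Compare outputs and updated running statistics with the built-in.
    print(torch.allclose(ref, mine, atol=1e-5),
          torch.allclose(rm_ref, rm_mine, atol=1e-6),
          torch.allclose(rv_ref, rv_mine, atol=1e-6))
```

Comparing against F.batch_norm, as in the demo, is a cheap way to check that a modified version still matches the built-in before you start changing it.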
Source: https://stackoverflow.com/questions/58193798/how-to-find-built-in-function-source-code-in-pytorch