PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx
在写 PyTorch 代码时,我们会发现一些功能重复的操作,比如卷积、激活、池化等操作。这些操作分别可以通过 torch.nn.xxx 和 torch.nn.functional.xxx 来实现。 首先可以观察源码: eg:torch.nn.Conv2d CLASS torch.nn.Conv2d( in_channels , out_channels , kernel_size , stride=1 , padding=0 , dilation=1 , groups=1 , bias=True , padding_mode='zeros' ) eg:torch.nn.functional torch.nn.functional.conv2d( input , weight , bias=None , stride=1 , padding=0 , dilation=1 , groups=1 ) → Tensor 从中,我们可以发现,nn.Conv2d 是一个类,而 nn.functional.conv2d是一个函数。 换言之: nn.Module 实现的 layer 是由 class Layer(nn.Module) 定义的特殊类 nn.functional 中的函数更像是纯函数,由 def function(input) 定义 此外: 两者的调用方式不同:调用 nn.xxx