动手学深度学习PyTorch版——Task06学习笔记
批量归一化和残差网络 批量归一化 从零开始 import time import torch from torch import nn , optim import torch . nn . functional as F import torchvision import sys sys . path . append ( "/home/kesci/input/" ) import d2lzh1981 as d2l device = torch . device ( 'cuda' if torch . cuda . is_available ( ) else 'cpu' ) def batch_norm ( is_training , X , gamma , beta , moving_mean , moving_var , eps , momentum ) : # 判断当前模式是训练模式还是预测模式 if not is_training : # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差 X_hat = ( X - moving_mean ) / torch . sqrt ( moving_var + eps ) else : assert len ( X . shape ) in ( 2 , 4 ) if len ( X . shape ) == 2 : #