模型蒸馏(Distil)及mnist实践
结论:蒸馏是个好方法。 模型压缩/蒸馏在论文《Model Compression》及《Distilling the Knowledge in a Neural Network》提及,下面介绍后者及使用keras测试mnist数据集。 蒸馏:使用小模型模拟大模型的泛性。 通常,我们训练mnist时,target是分类标签,在蒸馏模型时,使用的是教师模型的输出概率分布作为“soft target”。也即损失为学生网络与教师网络输出的交叉熵(这里采用DistilBert论文中的策略,此论文不同)。 当训练好教师网络后,我们可以不再需要分类标签,只需要比较2个网络的输出概率分布。当然可以在损失里再加上学生网络的分类损失,论文也提到可以进一步优化。 如图,将softmax公式稍微变换一下,目的是使得输出更小,softmax后就更为平滑。 论文的损失定义 本文代码使用的损失为p和q的交叉熵 代码测试部分 1,教师网络,测试精度99.46%,已经相当好了,可训练参数858,618。 # 教师网络 inputs=Input((28,28,1)) x=Conv2D(64,3)(inputs) x=BatchNormalization(center=True,scale=False)(x) x=Activation('relu')(x) x=Conv2D(64,3,strides=2)(x) x