这是论文《在神经网络中提炼知识》的 PyT orch 实现/教程。
这是一种使用经过训练的大型网络中的知识来训练小型网络的方法;即从大型网络中提炼知识。
直接在数据和标签上训练时,具有正则化或模型集合(使用 dropout)的大型模型比小型模型的概化效果更好。但是,在大型模型的帮助下,可以训练小模型以更好地进行概括。较小的模型在生产中更好:速度更快、计算更少、内存更少。
经过训练的模型的输出概率比标签提供的信息更多,因为它也会为错误的类分配非零概率。这些概率告诉我们,样本有可能属于某些类别。例如,在对数字进行分类时,当给定数字 7 的图像时,广义模型会给出7的高概率,给2的概率很小但不是零,而给其他数字分配几乎为零的概率。蒸馏利用这些信息来更好地训练小型模型。
概率通常是使用 softmax 运算计算的,
其中,是类的概率,是对数。
我们训练小型模型以最大限度地减少其输出概率分布和大型网络的输出概率分布(软目标)之间的交叉熵或 KL 差异。
这里的问题之一是,大型网络分配给错误类别的概率通常很小,不会导致损失。所以他们通过施加温度来软化概率,
其中,的值越高,产生的概率越低。
论文建议在训练小型模型时添加第二个损失项来预测实际标签。我们将综合损失计算为两个损失项的加权总和:软目标和实际标签。
用于蒸馏的数据集称为传输集,该论文建议使用相同的训练数据。
我们在 CIFAR-10 数据集上训练。我们训练了一个大型模型,该模型的参数带有 dropout,它在验证集上的准确率为 85%。带有参数的小型模型的准确度为80%。
然后,我们使用大型模型的蒸馏法训练小型模型,其精度为82%;精度提高了2%。
72import torch
73import torch.nn.functional
74from torch import nn
75
76from labml import experiment, tracker
77from labml.configs import option
78from labml_helpers.train_valid import BatchIndex
79from labml_nn.distillation.large import LargeModel
80from labml_nn.distillation.small import SmallModel
81from labml_nn.experiments.cifar10 import CIFAR10Configs
84class Configs(CIFAR10Configs):
小模型
92 model: SmallModel
大型模型
94 large: LargeModel
软目标的 KL 分散损失
96 kl_div_loss = nn.KLDivLoss(log_target=True)
真实标签丢失的交叉熵损失
98 loss_func = nn.CrossEntropyLoss()
温度,
100 temperature: float = 5.
106 soft_targets_weight: float = 100.
真实标签交叉熵损失的权重
108 label_loss_weight: float = 0.5
110 def step(self, batch: any, batch_idx: BatchIndex):
小模型的训练/评估模式
118 self.model.train(self.mode.is_train)
评估模式中的大型模型
120 self.large.eval()
将数据移动到设备
123 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式下更新全局步长(处理的样本数)
126 if self.mode.is_train:
127 tracker.add_global_step(len(data))
从大型模型中获取输出 logit、
130 with torch.no_grad():
131 large_logits = self.large(data)
从小型模型中获取输出 logits、
134 output = self.model(data)
软目标
138 soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)
小模型的温度调整概率
141 soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)
计算软目标损失
144 soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)
计算真实标签丢失
146 label_loss = self.loss_func(output, target)
两次亏损的加权总和
148 loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss
记录损失
150 tracker.add({"loss.kl_div.": soft_targets_loss,
151 "loss.nll": label_loss,
152 "loss.": loss})
计算和记录精度
155 self.accuracy(output, target)
156 self.accuracy.track()
训练模型
159 if self.mode.is_train:
计算梯度
161 loss.backward()
采取优化器步骤
163 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
165 if batch_idx.is_last:
166 tracker.add('model', self.model)
清除渐变
168 self.optimizer.zero_grad()
保存跟踪的指标
171 tracker.save()
174@option(Configs.large)
175def _large_model(c: Configs):
179 return LargeModel().to(c.device)
182@option(Configs.model)
183def _small_student_model(c: Configs):
187 return SmallModel().to(c.device)
195 from labml_nn.distillation.large import Configs as LargeConfigs
在评估模式下(无录音)
198 experiment.evaluate()
初始化大型模型训练实验的配置
200 conf = LargeConfigs()
加载保存的配置
202 experiment.configs(conf, experiment.load_configs(run_uuid))
设置用于保存/加载的模型
204 experiment.add_pytorch_models({'model': conf.model})
设置要加载的运行和检查点
206 experiment.load(run_uuid, checkpoint)
开始实验-这将加载模型,并准备所有内容
208 experiment.start()
返回模型
211 return conf.model
使用蒸馏训练小型模型
214def main(run_uuid: str, checkpoint: int):
加载已保存的模型
219 large_model = get_saved_model(run_uuid, checkpoint)
创建实验
221 experiment.create(name='distillation', comment='cifar10')
创建配置
223 conf = Configs()
设置加载的大型模型
225 conf.large = large_model
装载配置
227 experiment.configs(conf, {
228 'optimizer.optimizer': 'Adam',
229 'optimizer.learning_rate': 2.5e-4,
230 'model': '_small_student_model',
231 })
设置保存/加载的模型
233 experiment.add_pytorch_models({'model': conf.model})
从头开始实验
235 experiment.load(None, None)
开始实验并运行训练循环
237 with experiment.start():
238 conf.run()
242if __name__ == '__main__':
243 main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)