在神经网络中提炼知识

这是论文《在神经网络中提炼知识》的 PyT orch 实现/教程。

这是一种使用经过训练的大型网络中的知识来训练小型网络的方法;即从大型网络中提炼知识。

直接在数据和标签上训练时,具有正则化或模型集合(使用 dropout)的大型模型比小型模型的概化效果更好。但是,在大型模型的帮助下,可以训练小模型以更好地进行概括。较小的模型在生产中更好:速度更快、计算更少、内存更少。

经过训练的模型的输出概率比标签提供的信息更多,因为它也会为错误的类分配非零概率。这些概率告诉我们,样本有可能属于某些类别。例如,在对数字进行分类时,当给定数字 7 的图像时,广义模型会给出7的高概率,给2的概率很小但不是零,而给其他数字分配几乎为零的概率。蒸馏利用这些信息来更好地训练小型模型。

软目标

概率通常是使用 softmax 运算计算的,

其中,是类的概率是对数。

我们训练小型模型以最大限度地减少其输出概率分布和大型网络的输出概率分布(软目标)之间的交叉熵或 KL 差异。

这里的问题之一是,大型网络分配给错误类别的概率通常很小,不会导致损失。所以他们通过施加温度来软化概率

其中,的值越高,产生的概率越低。

训练

论文建议在训练小型模型时添加第二个损失项来预测实际标签。我们将综合损失计算为两个损失项的加权总和:软目标和实际标签。

用于蒸馏的数据集称为传输集,该论文建议使用相同的训练数据。

我们的实验

我们在 CIFAR-10 数据集上训练。我们训练了一个大型模型,该模型参数带有 dropout,它在验证集上的准确率为 85%。带有参数的小型模型的准确度为80%。

然后,我们使用大型模型的蒸馏法训练小型模型,其精度为82%;精度提高了2%。

72import torch
73import torch.nn.functional
74from torch import nn
75
76from labml import experiment, tracker
77from labml.configs import option
78from labml_helpers.train_valid import BatchIndex
79from labml_nn.distillation.large import LargeModel
80from labml_nn.distillation.small import SmallModel
81from labml_nn.experiments.cifar10 import CIFAR10Configs

配置

从此扩展定义CIFAR10Configs 了所有与数据集相关的配置、优化器和训练循环。

84class Configs(CIFAR10Configs):

小模型

92    model: SmallModel

大型模型

94    large: LargeModel

软目标的 KL 分散损失

96    kl_div_loss = nn.KLDivLoss(log_target=True)

真实标签丢失的交叉熵损失

98    loss_func = nn.CrossEntropyLoss()

温度,

100    temperature: float = 5.

软目标损失的权重。

软目标产生的梯度会被缩放。为了弥补这一点,本文建议将软目标的损失缩小一倍

106    soft_targets_weight: float = 100.

真实标签交叉熵损失的权重

108    label_loss_weight: float = 0.5

培训/验证步骤

我们定义了一个定制的训练/验证步骤,包括蒸馏

110    def step(self, batch: any, batch_idx: BatchIndex):

小模型的训练/评估模式

118        self.model.train(self.mode.is_train)

评估模式中的大型模型

120        self.large.eval()

将数据移动到设备

123        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式下更新全局步长(处理的样本数)

126        if self.mode.is_train:
127            tracker.add_global_step(len(data))

从大型模型中获取输出 logit

130        with torch.no_grad():
131            large_logits = self.large(data)

从小型模型中获取输出 logits

134        output = self.model(data)

软目标

138        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)

小模型的温度调整概率

141        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

计算软目标损失

144        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)

计算真实标签丢失

146        label_loss = self.loss_func(output, target)

两次亏损的加权总和

148        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

记录损失

150        tracker.add({"loss.kl_div.": soft_targets_loss,
151                     "loss.nll": label_loss,
152                     "loss.": loss})

计算和记录精度

155        self.accuracy(output, target)
156        self.accuracy.track()

训练模型

159        if self.mode.is_train:

计算梯度

161            loss.backward()

采取优化器步骤

163            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

165            if batch_idx.is_last:
166                tracker.add('model', self.model)

清除渐变

168            self.optimizer.zero_grad()

保存跟踪的指标

171        tracker.save()

创建大型模型

174@option(Configs.large)
175def _large_model(c: Configs):
179    return LargeModel().to(c.device)

创建小模型

182@option(Configs.model)
183def _small_student_model(c: Configs):
187    return SmallModel().to(c.device)
190def get_saved_model(run_uuid: str, checkpoint: int):
195    from labml_nn.distillation.large import Configs as LargeConfigs

在评估模式下(无录音)

198    experiment.evaluate()

初始化大型模型训练实验的配置

200    conf = LargeConfigs()

加载保存的配置

202    experiment.configs(conf, experiment.load_configs(run_uuid))

设置用于保存/加载的模型

204    experiment.add_pytorch_models({'model': conf.model})

设置要加载的运行和检查点

206    experiment.load(run_uuid, checkpoint)

开始实验-这将加载模型,并准备所有内容

208    experiment.start()

返回模型

211    return conf.model

使用蒸馏训练小型模型

214def main(run_uuid: str, checkpoint: int):

加载已保存的模型

219    large_model = get_saved_model(run_uuid, checkpoint)

创建实验

221    experiment.create(name='distillation', comment='cifar10')

创建配置

223    conf = Configs()

设置加载的大型模型

225    conf.large = large_model

装载配置

227    experiment.configs(conf, {
228        'optimizer.optimizer': 'Adam',
229        'optimizer.learning_rate': 2.5e-4,
230        'model': '_small_student_model',
231    })

设置保存/加载的模型

233    experiment.add_pytorch_models({'model': conf.model})

从头开始实验

235    experiment.load(None, None)

开始实验并运行训练循环

237    with experiment.start():
238        conf.run()

242if __name__ == '__main__':
243    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)