ニューラルネットワークでの知識の抽出

これは、論文「ニューラルネットワークにおける知識の抽出のPyTorch実装/チュートリアルです

これは、トレーニング済みの大規模なネットワークの知識を使用して小規模ネットワークをトレーニングする方法です。つまり、大規模なネットワークから知識を抽出する方法です。

データやラベルで直接トレーニングした場合、正則化を行った大規模なモデルや (ドロップアウトを使用した) モデルのアンサンブルは、小さなモデルよりも一般化が容易です。ただし、小さいモデルでも、大きなモデルの助けを借りてより一般化しやすいようにトレーニングできます。本番環境では、モデルが小さいほど速く、処理能力が少なく、メモリも少なくて済みます。

トレーニング済みモデルの出力確率は、誤ったクラスにもゼロ以外の確率を割り当てるため、ラベルよりも多くの情報を提供します。これらの確率から、サンプルが特定のクラスに属している可能性があることがわかります。たとえば、数字を分類する際、7 桁の画像が与えられた場合、一般化モデルでは 7 には高い確率、2 には小さいながらもゼロではない確率が与えられ、他の数字にはほぼゼロの確率を割り当てます。蒸留では、この情報を利用して小型モデルの学習効果を高めます

ソフトターゲット

確率は通常、ソフトマックス演算で計算されます。

ここではクラスの確率で、はロジットです。

出力確率分布と大規模ネットワークの出力確率分布 (ソフトターゲット) の間のクロスエントロピーまたは KL ダイバージェンスを最小化するようにスモールモデルをトレーニングします。

ここでの問題の 1 つは、大規模なネットワークによって誤ったクラスに割り当てられる確率がたいてい非常に小さく、損失の原因にならないことです。そこで、温度を当てることで確率を弱めます

値が大きいほど、確率は低くなります。

トレーニング

論文では、小型モデルをトレーニングする際に実際のラベルを予測するために 2 つ目の損失項を追加することを提案しています。複合損失は、ソフトターゲットと実際のラベルという2つの損失項の加重合計として計算されます。

蒸留用のデータセットはトランスファーセットと呼ばれ、論文では同じトレーニングデータを使用することを提案しています。

私たちの実験

CIFAR-10データセットでトレーニングします。ドロップアウトのあるパラメータを持つ大規模なモデルをトレーニングすると、検証セットの精度が 85% になります。パラメータを含む小さなモデルでは、80%の精度が得られます。

次に、大きいモデルから蒸留して小さなモデルにトレーニングを行うと、精度は 82%、精度は 2% 向上します。

72import torch
73import torch.nn.functional
74from torch import nn
75
76from labml import experiment, tracker
77from labml.configs import option
78from labml_helpers.train_valid import BatchIndex
79from labml_nn.distillation.large import LargeModel
80from labml_nn.distillation.small import SmallModel
81from labml_nn.experiments.cifar10 import CIFAR10Configs

コンフィギュレーション

これを拡張してCIFAR10Configs 、データセットに関連するすべての構成、オプティマイザー、およびトレーニングループを定義します。

84class Configs(CIFAR10Configs):

スモールモデル

92    model: SmallModel

ラージモデル

94    large: LargeModel

ソフトターゲットのKLダイバージェンスロス

96    kl_div_loss = nn.KLDivLoss(log_target=True)

クロスエントロピー損失による真のラベルロス

98    loss_func = nn.CrossEntropyLoss()

温度、

100    temperature: float = 5.

ソフトターゲットの重量が減ります。

ソフトターゲットによって生成されるグラデーションは、次の方法でスケーリングされます。これを補うために、この論文ではソフトターゲットの損失を次の倍にスケーリングすることを提案しています

106    soft_targets_weight: float = 100.

真のラベルクロスエントロピー損失を実現する重量

108    label_loss_weight: float = 0.5

トレーニング/検証ステップ

蒸留を含むカスタムトレーニング/検証ステップを定義します

110    def step(self, batch: any, batch_idx: BatchIndex):

小型モデルのトレーニング/評価モード

118        self.model.train(self.mode.is_train)

評価モードの大型モデル

120        self.large.eval()

データをデバイスに移動

123        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモード時にグローバルステップ (処理されたサンプル数) を更新

126        if self.mode.is_train:
127            tracker.add_global_step(len(data))

ラージモデルから出力ロジットを取得

130        with torch.no_grad():
131            large_logits = self.large(data)

小さいモデルから出力ロジットを取得

134        output = self.model(data)

ソフトターゲット

138        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)

小型モデルの温度調整済み確率

141        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

ソフトターゲットの損失の計算

144        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)

実際のラベルロスを計算

146        label_loss = self.loss_func(output, target)

2 つの損失の加重和

148        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

損失を記録する

150        tracker.add({"loss.kl_div.": soft_targets_loss,
151                     "loss.nll": label_loss,
152                     "loss.": loss})

精度の計算と記録

155        self.accuracy(output, target)
156        self.accuracy.track()

モデルのトレーニング

159        if self.mode.is_train:

勾配の計算

161            loss.backward()

最適化の一歩を踏み出す

163            self.optimizer.step()

各エポックの最後のバッチでモデルパラメータと勾配を記録します

165            if batch_idx.is_last:
166                tracker.add('model', self.model)

グラデーションをクリア

168            self.optimizer.zero_grad()

追跡したメトリクスを保存する

171        tracker.save()

大きいモデルを作成

174@option(Configs.large)
175def _large_model(c: Configs):
179    return LargeModel().to(c.device)

小型モデルを作成

182@option(Configs.model)
183def _small_student_model(c: Configs):
187    return SmallModel().to(c.device)
190def get_saved_model(run_uuid: str, checkpoint: int):
195    from labml_nn.distillation.large import Configs as LargeConfigs

評価モード (記録なし)

198    experiment.evaluate()

ラージモデルトレーニング実験のコンフィグを初期化

200    conf = LargeConfigs()

保存した設定を読み込む

202    experiment.configs(conf, experiment.load_configs(run_uuid))

保存/読み込み用のモデルを設定

204    experiment.add_pytorch_models({'model': conf.model})

どのランとチェックポイントをロードするかを設定

206    experiment.load(run_uuid, checkpoint)

実験を開始します。これでモデルが読み込まれ、すべての準備が整います

208    experiment.start()

モデルを返却する

211    return conf.model

蒸留による小型モデルのトレーニング

214def main(run_uuid: str, checkpoint: int):

保存したモデルをロード

219    large_model = get_saved_model(run_uuid, checkpoint)

実験を作成

221    experiment.create(name='distillation', comment='cifar10')

構成の作成

223    conf = Configs()

読み込んだラージモデルを設定する

225    conf.large = large_model

構成をロード

227    experiment.configs(conf, {
228        'optimizer.optimizer': 'Adam',
229        'optimizer.learning_rate': 2.5e-4,
230        'model': '_small_student_model',
231    })

保存/読み込み用のモデルを設定

233    experiment.add_pytorch_models({'model': conf.model})

実験をゼロから始める

235    experiment.load(None, None)

実験を開始し、トレーニングループを実行します

237    with experiment.start():
238        conf.run()

242if __name__ == '__main__':
243    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)