これは、論文「ニューラルネットワークにおける知識の抽出」のPyTorch実装/チュートリアルです。
これは、トレーニング済みの大規模なネットワークの知識を使用して小規模ネットワークをトレーニングする方法です。つまり、大規模なネットワークから知識を抽出する方法です。
データやラベルで直接トレーニングした場合、正則化を行った大規模なモデルや (ドロップアウトを使用した) モデルのアンサンブルは、小さなモデルよりも一般化が容易です。ただし、小さいモデルでも、大きなモデルの助けを借りてより一般化しやすいようにトレーニングできます。本番環境では、モデルが小さいほど速く、処理能力が少なく、メモリも少なくて済みます。
トレーニング済みモデルの出力確率は、誤ったクラスにもゼロ以外の確率を割り当てるため、ラベルよりも多くの情報を提供します。これらの確率から、サンプルが特定のクラスに属している可能性があることがわかります。たとえば、数字を分類する際、7 桁の画像が与えられた場合、一般化モデルでは 7 には高い確率、2 には小さいながらもゼロではない確率が与えられ、他の数字にはほぼゼロの確率を割り当てます。蒸留では、この情報を利用して小型モデルの学習効果を高めます
。確率は通常、ソフトマックス演算で計算されます。
ここで、はクラスの確率で、はロジットです。
出力確率分布と大規模ネットワークの出力確率分布 (ソフトターゲット) の間のクロスエントロピーまたは KL ダイバージェンスを最小化するようにスモールモデルをトレーニングします。
ここでの問題の 1 つは、大規模なネットワークによって誤ったクラスに割り当てられる確率がたいてい非常に小さく、損失の原因にならないことです。そこで、温度を当てることで確率を弱めます
値が大きいほど、確率は低くなります。
論文では、小型モデルをトレーニングする際に実際のラベルを予測するために 2 つ目の損失項を追加することを提案しています。複合損失は、ソフトターゲットと実際のラベルという2つの損失項の加重合計として計算されます。
蒸留用のデータセットはトランスファーセットと呼ばれ、論文では同じトレーニングデータを使用することを提案しています。
CIFAR-10データセットでトレーニングします。ドロップアウトのあるパラメータを持つ大規模なモデルをトレーニングすると、検証セットの精度が 85% になります。パラメータを含む小さなモデルでは、80%の精度が得られます。
次に、大きいモデルから蒸留して小さなモデルにトレーニングを行うと、精度は 82%、精度は 2% 向上します。
72import torch
73import torch.nn.functional
74from torch import nn
75
76from labml import experiment, tracker
77from labml.configs import option
78from labml_helpers.train_valid import BatchIndex
79from labml_nn.distillation.large import LargeModel
80from labml_nn.distillation.small import SmallModel
81from labml_nn.experiments.cifar10 import CIFAR10Configs
84class Configs(CIFAR10Configs):
スモールモデル
92 model: SmallModel
ラージモデル
94 large: LargeModel
ソフトターゲットのKLダイバージェンスロス
96 kl_div_loss = nn.KLDivLoss(log_target=True)
クロスエントロピー損失による真のラベルロス
98 loss_func = nn.CrossEntropyLoss()
温度、
100 temperature: float = 5.
ソフトターゲットの重量が減ります。
ソフトターゲットによって生成されるグラデーションは、次の方法でスケーリングされます。これを補うために、この論文ではソフトターゲットの損失を次の倍にスケーリングすることを提案しています
。106 soft_targets_weight: float = 100.
真のラベルクロスエントロピー損失を実現する重量
108 label_loss_weight: float = 0.5
110 def step(self, batch: any, batch_idx: BatchIndex):
小型モデルのトレーニング/評価モード
118 self.model.train(self.mode.is_train)
評価モードの大型モデル
120 self.large.eval()
データをデバイスに移動
123 data, target = batch[0].to(self.device), batch[1].to(self.device)
トレーニングモード時にグローバルステップ (処理されたサンプル数) を更新
126 if self.mode.is_train:
127 tracker.add_global_step(len(data))
ラージモデルから出力ロジットを取得
130 with torch.no_grad():
131 large_logits = self.large(data)
小さいモデルから出力ロジットを取得
134 output = self.model(data)
ソフトターゲット
138 soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)
小型モデルの温度調整済み確率
141 soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)
ソフトターゲットの損失の計算
144 soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)
実際のラベルロスを計算
146 label_loss = self.loss_func(output, target)
2 つの損失の加重和
148 loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss
損失を記録する
150 tracker.add({"loss.kl_div.": soft_targets_loss,
151 "loss.nll": label_loss,
152 "loss.": loss})
精度の計算と記録
155 self.accuracy(output, target)
156 self.accuracy.track()
モデルのトレーニング
159 if self.mode.is_train:
勾配の計算
161 loss.backward()
最適化の一歩を踏み出す
163 self.optimizer.step()
各エポックの最後のバッチでモデルパラメータと勾配を記録します
165 if batch_idx.is_last:
166 tracker.add('model', self.model)
グラデーションをクリア
168 self.optimizer.zero_grad()
追跡したメトリクスを保存する
171 tracker.save()
174@option(Configs.large)
175def _large_model(c: Configs):
179 return LargeModel().to(c.device)
182@option(Configs.model)
183def _small_student_model(c: Configs):
187 return SmallModel().to(c.device)
190def get_saved_model(run_uuid: str, checkpoint: int):
195 from labml_nn.distillation.large import Configs as LargeConfigs
評価モード (記録なし)
198 experiment.evaluate()
ラージモデルトレーニング実験のコンフィグを初期化
200 conf = LargeConfigs()
保存した設定を読み込む
202 experiment.configs(conf, experiment.load_configs(run_uuid))
保存/読み込み用のモデルを設定
204 experiment.add_pytorch_models({'model': conf.model})
どのランとチェックポイントをロードするかを設定
206 experiment.load(run_uuid, checkpoint)
実験を開始します。これでモデルが読み込まれ、すべての準備が整います
208 experiment.start()
モデルを返却する
211 return conf.model
蒸留による小型モデルのトレーニング
214def main(run_uuid: str, checkpoint: int):
保存したモデルをロード
219 large_model = get_saved_model(run_uuid, checkpoint)
実験を作成
221 experiment.create(name='distillation', comment='cifar10')
構成の作成
223 conf = Configs()
読み込んだラージモデルを設定する
225 conf.large = large_model
構成をロード
227 experiment.configs(conf, {
228 'optimizer.optimizer': 'Adam',
229 'optimizer.learning_rate': 2.5e-4,
230 'model': '_small_student_model',
231 })
保存/読み込み用のモデルを設定
233 experiment.add_pytorch_models({'model': conf.model})
実験をゼロから始める
235 experiment.load(None, None)
実験を開始し、トレーニングループを実行します
237 with experiment.start():
238 conf.run()
242if __name__ == '__main__':
243 main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)