これにより、エビデンシャルディープラーニングに基づくモデルをトレーニングして、MNISTデータセットの分類の不確実性を定量化します。
14from typing import Any
15
16import torch.nn as nn
17import torch.utils.data
18
19from labml import tracker, experiment
20from labml.configs import option, calculate
21from labml_helpers.module import Module
22from labml_helpers.schedule import Schedule, RelativePiecewise
23from labml_helpers.train_valid import BatchIndex
24from labml_nn.experiments.mnist import MNISTConfigs
25from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
26 CrossEntropyBayesRisk, SquaredErrorBayesRisk
29class Model(Module):
34 def __init__(self, dropout: float):
35 super().__init__()
最初の畳み込み層
37 self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
ReLU アクティベーション
39 self.act1 = nn.ReLU()
マックスプーリング
41 self.max_pool1 = nn.MaxPool2d(2, 2)
2 番目の畳み込み層
43 self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
ReLU アクティベーション
45 self.act2 = nn.ReLU()
マックスプーリング
47 self.max_pool2 = nn.MaxPool2d(2, 2)
フィーチャにマッピングされる最初の完全接続レイヤー
49 self.fc1 = nn.Linear(50 * 4 * 4, 500)
ReLU アクティベーション
51 self.act3 = nn.ReLU()
クラスのエビデンスを出力するための最後の完全接続レイヤー。これにReLUまたはSoftplusアクティベーションをモデル外で適用すると、非陰性エビデンスが得られます
。55 self.fc2 = nn.Linear(500, 10)
隠しレイヤーのドロップアウト
57 self.dropout = nn.Dropout(p=dropout)
x
形状のMNIST画像のバッチです [batch_size, 1, 28, 28]
59 def __call__(self, x: torch.Tensor):
最初のコンボリューションとマックスプーリングを適用します。結果には形があります [batch_size, 20, 12, 12]
65 x = self.max_pool1(self.act1(self.conv1(x)))
2 回目のコンボリューションと最大プーリングを適用します。結果には形があります [batch_size, 50, 4, 4]
68 x = self.max_pool2(self.act2(self.conv2(x)))
テンソルを平らにして形を整える [batch_size, 50 * 4 * 4]
70 x = x.view(x.shape[0], -1)
隠しレイヤーを適用
72 x = self.act3(self.fc1(x))
ドロップアウトを適用
74 x = self.dropout(x)
最終レイヤーを適用して戻る
76 return self.fc2(x)
79class Configs(MNISTConfigs):
87 kl_div_loss = KLDivergenceLoss()
KL ダイバージェンス正則化係数スケジュール
89 kl_div_coef: Schedule
KL ダイバージェンス正則化係数スケジュール
91 kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
93 stats = TrackStatistics()
ドロップアウト
95 dropout: float = 0.5
モデル出力をゼロ以外のエビデンスに変換するモジュール
97 outputs_to_evidence: Module
99 def init(self):
トラッカー構成を設定
104 tracker.set_scalar("loss.*", True)
105 tracker.set_scalar("accuracy.*", True)
106 tracker.set_histogram('u.*', True)
107 tracker.set_histogram('prob.*', False)
108 tracker.set_scalar('annealing_coef.*', False)
109 tracker.set_scalar('kl_div_loss.*', False)
112 self.state_modules = []
114 def step(self, batch: Any, batch_idx: BatchIndex):
トレーニング/評価モード
120 self.model.train(self.mode.is_train)
データをデバイスに移動
123 data, target = batch[0].to(self.device), batch[1].to(self.device)
ワンホットコーディングターゲット
126 eye = torch.eye(10).to(torch.float).to(self.device)
127 target = eye[target]
トレーニングモード時にグローバルステップ (処理されたサンプル数) を更新
130 if self.mode.is_train:
131 tracker.add_global_step(len(data))
モデル出力を取得
134 outputs = self.model(data)
証拠を取得
136 evidence = self.outputs_to_evidence(outputs)
損失の計算
139 loss = self.loss_func(evidence, target)
KL ダイバージェンス正則化損失の計算
141 kl_div_loss = self.kl_div_loss(evidence, target)
142 tracker.add("loss.", loss)
143 tracker.add("kl_div_loss.", kl_div_loss)
KL 発散損失係数
146 annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
147 tracker.add("annealing_coef.", annealing_coef)
総損失
150 loss = loss + annealing_coef * kl_div_loss
トラック統計
153 self.stats(evidence, target)
モデルのトレーニング
156 if self.mode.is_train:
勾配の計算
158 loss.backward()
最適化の一歩を踏み出す
160 self.optimizer.step()
グラデーションをクリア
162 self.optimizer.zero_grad()
追跡したメトリクスを保存する
165 tracker.save()
168@option(Configs.model)
169def mnist_model(c: Configs):
173 return Model(c.dropout).to(c.device)
176@option(Configs.kl_div_coef)
177def kl_div_coef(c: Configs):
183 return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
189calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
191calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
エビデンスの計算には ReLU
194calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
証拠計算用ソフトプラス
196calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
199def main():
実験を作成
201 experiment.create(name='evidence_mnist')
構成の作成
203 conf = Configs()
構成をロード
205 experiment.configs(conf, {
206 'optimizer.optimizer': 'Adam',
207 'optimizer.learning_rate': 0.001,
208 'optimizer.weight_decay': 0.005,
'loss_func': 'max_likelihood_loss', 'loss_func': 'cross_entropy_bayes_risk',
212 'loss_func': 'squared_error_bayes_risk',
213
214 'outputs_to_evidence': 'softplus',
215
216 'dropout': 0.5,
217 })
実験を開始し、トレーニングループを実行します
219 with experiment.start():
220 conf.run()
224if __name__ == '__main__':
225 main()