これは、分類の不確実性を定量化するための論文「エビデンシャル・ディープ・ラーニング」をPyTorchで実装したものです。
Dampster-Shaferの証拠理論では、(単一のクラスに確率を割り当てるのとは異なります)ビリーフマスに一連のクラスを割り当てます。すべてのサブセットの質量の合計はです。これらの質量から個々のクラスの確率(妥当性)を導き出すことができます。
すべてのクラスの集合に質量を割り当てると、どのクラスでもかまいません。つまり、「わからない」ということです。
クラスがある場合は、各クラスに質量を割り当て、すべてのクラスに全体的な不確定性質量を割り当てます。
信念は集まるので、その場で、証拠から計算できます。論文では、特定のクラスに分類されるサンプルに有利なように、データから収集された支持量の尺度として用語エビデンスを用いています。
これはパラメータを含むディリクレ分布に対応し、ディリクレ強度として知られています。ディリクレ分布はカテゴリ分布にわたる分布です。つまり、ディリクレ分布からクラス確率をサンプリングできます。クラスの予想確率はです。
与えられた入力に対してエビデンスを出力するようにモデルを取得します。
最後のレイヤーでReLUやSoftplusなどの関数を使って取得します。この論文では、モデルをトレーニングするための損失関数をいくつか提案しています。これを以下に実装しました。
52import torch
53
54from labml import tracker
55from labml_helpers.module import Module
分布は確率の前提条件であり、負の対数限界確率は、クラス全体の確率を積分して計算されます。
ターゲット確率(ワンホットターゲット)が特定のサンプルの場合、損失は次のようになります。
58class MaximumLikelihoodLoss(Module):
evidence
形付きです [batch_size, n_classes]
target
形付きです [batch_size, n_classes]
83 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
89 alpha = evidence + 1.
91 strength = alpha.sum(dim=-1)
損失
94 loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
バッチ全体の平均損失
97 return loss.mean()
ベイズリスクとは、誤った推定を行うことによる全体的な最大コストです。これは、誤った見積もりを行った場合のコストを計算し、確率分布に基づいて考えられるすべての結果を合計するコスト関数を取ります
。ここで、コスト関数はワンホットコーディングのクロスエントロピー損失です。
この費用を全体として統合します
関数はどこですか。
100class CrossEntropyBayesRisk(Module):
evidence
形付きです [batch_size, n_classes]
target
形付きです [batch_size, n_classes]
130 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
136 alpha = evidence + 1.
138 strength = alpha.sum(dim=-1)
損失
141 loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
バッチ全体の平均損失
144 return loss.mean()
ここで、コスト関数は二乗誤差です。
この費用を全体として統合します
ディリクレ分布からサンプリングしたときの期待確率はどこで、どこは分散です。
これにより、
方程式のこの最初の部分は誤差項で、2 番目の部分は分散です。
147class SquaredErrorBayesRisk(Module):
evidence
形付きです [batch_size, n_classes]
target
形付きです [batch_size, n_classes]
193 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
199 alpha = evidence + 1.
201 strength = alpha.sum(dim=-1)
203 p = alpha / strength[:, None]
エラー
206 err = (target - p) ** 2
差異
208 var = p * (1 - p) / (strength[:, None] + 1)
それらの合計
211 loss = (err + var).sum(dim=-1)
バッチ全体の平均損失
214 return loss.mean()
これにより、サンプルが正しく分類できない場合に、エビデンスの総数をゼロに減らそうとします。
まず、正しい証拠を取り除いた後、Dirichletパラメーターを計算します。
ここで、はガンマ関数、は関数、
217class KLDivergenceLoss(Module):
evidence
形付きです [batch_size, n_classes]
target
形付きです [batch_size, n_classes]
241 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
247 alpha = evidence + 1.
クラス数
249 n_classes = evidence.shape[-1]
誤解を招かない証拠の削除
252 alpha_tilde = target + (1 - target) * alpha
254 strength_tilde = alpha_tilde.sum(dim=-1)
265 first = (torch.lgamma(alpha_tilde.sum(dim=-1))
266 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
267 - (torch.lgamma(alpha_tilde)).sum(dim=-1))
第二学期
272 second = (
273 (alpha_tilde - 1) *
274 (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
275 ).sum(dim=-1)
条件の合計
278 loss = first + second
バッチ全体の平均損失
281 return loss.mean()
284class TrackStatistics(Module):
292 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
クラス数
294 n_classes = evidence.shape[-1]
ターゲットと正確に一致する予測(最も高い確率に基づく貪欲なサンプリング)
296 match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
トラック精度
298 tracker.add('accuracy.', match.sum() / match.shape[0])
301 alpha = evidence + 1.
303 strength = alpha.sum(dim=-1)
306 expected_probability = alpha / strength[:, None]
選択した(欲張り最高確率)クラスの予想確率
308 expected_probability, _ = expected_probability.max(dim=-1)
不確実性の質量
311 uncertainty_mass = n_classes / strength
追跡して正しく予測できるようにする
314 tracker.add('u.succ.', uncertainty_mass.masked_select(match))
予測の誤りがないか追跡する
316 tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
追跡して正しく予測できるようにする
318 tracker.add('prob.succ.', expected_probability.masked_select(match))
予測の誤りがないか追跡する
320 tracker.add('prob.fail.', expected_probability.masked_select(~match))