分類の不確実性を定量化するエビデンシャルディープラーニング

これは、分類の不確実性を定量化するための論文「エビデンシャル・ディープ・ラーニング」をPyTorchで実装したものです

Dampster-Shaferの証拠理論では、(単一のクラスに確率を割り当てるのとは異なります)ビリーフマスに一連のクラスを割り当てます。すべてのサブセットの質量の合計はです。これらの質量から個々のクラスの確率(妥当性)を導き出すことができます。

すべてのクラスの集合に質量を割り当てると、どのクラスでもかまいません。つまり、「わからない」ということです。

クラスがある場合は、各クラスに質量を割り当て、すべてのクラスに全体的な不確定性質量を割り当てます。

信念は集まるので、その場で証拠から計算できます。論文では、特定のクラスに分類されるサンプルに有利なように、データから収集された支持量の尺度として用語エビデンスを用いています。

これはパラメータを含むディリクレ分布に対応しディリクレ強度として知られています。ディリクレ分布はカテゴリ分布にわたる分布です。つまり、ディリクレ分布からクラス確率をサンプリングできます。クラスの予想確率はです

与えられた入力に対してエビデンスを出力するようにモデルを取得します。

最後のレイヤーでReLUやSoftplusなどの関数を使って取得します

この論文では、モデルをトレーニングするための損失関数をいくつか提案しています。これを以下に実装しました。

これは、experiment.py MNISTデータセットでモデルをトレーニングするためのトレーニングコードです

52import torch
53
54from labml import tracker
55from labml_helpers.module import Module

タイプ II 最可能性損失

分布は確率の前提条件であり、負の対数限界確率は、クラス全体の確率を積分して計算されます。

ターゲット確率(ワンホットターゲット)が特定のサンプルの場合、損失は次のようになります。

58class MaximumLikelihoodLoss(Module):
  • evidence 形付きです [batch_size, n_classes]
  • target 形付きです [batch_size, n_classes]
83    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

89        alpha = evidence + 1.

91        strength = alpha.sum(dim=-1)

損失

94        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

バッチ全体の平均損失

97        return loss.mean()

クロスエントロピー損失を伴うベイズリスク

ベイズリスクとは、誤った推定を行うことによる全体的な最大コストです。これは、誤った見積もりを行った場合のコストを計算し、確率分布に基づいて考えられるすべての結果を合計するコスト関数を取ります

ここで、コスト関数はワンホットコーディングのクロスエントロピー損失です。

この費用を全体として統合します

関数はどこですか。

100class CrossEntropyBayesRisk(Module):
  • evidence 形付きです [batch_size, n_classes]
  • target 形付きです [batch_size, n_classes]
130    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

136        alpha = evidence + 1.

138        strength = alpha.sum(dim=-1)

損失

141        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

バッチ全体の平均損失

144        return loss.mean()

二乗誤差損失によるベイズリスク

ここで、コスト関数は二乗誤差です。

この費用を全体として統合します

ディリクレ分布からサンプリングしたときの期待確率はどこで、どこは分散です。

これにより、

方程式のこの最初の部分は誤差項で、2 番目の部分は分散です。

147class SquaredErrorBayesRisk(Module):
  • evidence 形付きです [batch_size, n_classes]
  • target 形付きです [batch_size, n_classes]
193    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

199        alpha = evidence + 1.

201        strength = alpha.sum(dim=-1)

203        p = alpha / strength[:, None]

エラー

206        err = (target - p) ** 2

差異

208        var = p * (1 - p) / (strength[:, None] + 1)

それらの合計

211        loss = (err + var).sum(dim=-1)

バッチ全体の平均損失

214        return loss.mean()

KL ダイバージェンス正則化損失

これにより、サンプルが正しく分類できない場合に、エビデンスの総数をゼロに減らそうとします。

まず、正しい証拠を取り除いた後、Dirichletパラメーターを計算します。

ここで、はガンマ関数、は関数、

217class KLDivergenceLoss(Module):
  • evidence 形付きです [batch_size, n_classes]
  • target 形付きです [batch_size, n_classes]
241    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

247        alpha = evidence + 1.

クラス数

249        n_classes = evidence.shape[-1]

誤解を招かない証拠の削除

252        alpha_tilde = target + (1 - target) * alpha

254        strength_tilde = alpha_tilde.sum(dim=-1)

第一学期

265        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
266                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
267                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

第二学期

272        second = (
273                (alpha_tilde - 1) *
274                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
275        ).sum(dim=-1)

条件の合計

278        loss = first + second

バッチ全体の平均損失

281        return loss.mean()

トラック統計

このモジュールは統計を計算し、tracker labmlで追跡します

284class TrackStatistics(Module):
292    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

クラス数

294        n_classes = evidence.shape[-1]

ターゲットと正確に一致する予測(最も高い確率に基づく貪欲なサンプリング)

296        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

トラック精度

298        tracker.add('accuracy.', match.sum() / match.shape[0])

301        alpha = evidence + 1.

303        strength = alpha.sum(dim=-1)

306        expected_probability = alpha / strength[:, None]

選択した(欲張り最高確率)クラスの予想確率

308        expected_probability, _ = expected_probability.max(dim=-1)

不確実性の質量

311        uncertainty_mass = n_classes / strength

追跡して正しく予測できるようにする

314        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

予測の誤りがないか追跡する

316        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

追跡して正しく予測できるようにする

318        tracker.add('prob.succ.', expected_probability.masked_select(match))

予測の誤りがないか追跡する

320        tracker.add('prob.fail.', expected_probability.masked_select(~match))