用证据深度学习来量化分类不确定性

这是 P yTorch 对《量化分类不确定性的证据深度学习》论文的实现。

Dampster-Shafer 证据理论为信仰群体分配了一组类别(与为单个类别分配概率不同)。所有子集的质量总和为。个别类别的概率(合理性)可以从这些质量中推导出来。

为所有类别的集合分配质量意味着它可以是任何一个类别;即说 “我不知道”。

如果有等级,我们为每个等级分配质量,为所有类别分配总体不确定性质量。

信仰众多,可以根据证据计算得出,无论在何处。Paper 使用术语证据来衡量从数据中收集的支持量,支持将样本归类为特定类别。

这对应于带有参数的狄利克雷分布,被称为狄利克雷强度。狄利克雷分布是类别分布上的分布;也就是说,你可以从狄利克雷分布中抽取类概率。上课的预期概率

我们让模型输出给定输入的证据。我们在最后一层使用诸如 ReLUSoftplus 之类的函数来获取

本文提出了一些用于训练模型的损失函数,我们在下面实现了这些函数。

这是在 MNIST 数据集上训练模型的训练代码experiment.py

52import torch
53
54from labml import tracker
55from labml_helpers.module import Module

类型 II 最大似然损失

分布是似然的先验,负对数边际似然是通过积分类概率来计算的

如果目标概率(一热目标)是针对给定样本的,则损失为,

58class MaximumLikelihoodLoss(Module):
  • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 83    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    89        alpha = evidence + 1.

    91        strength = alpha.sum(dim=-1)

    亏损

    94        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

    整个批次的平均损失

    97        return loss.mean()

    交叉熵损失的贝叶斯风险

    贝叶斯风险是做出错误估算的总体最大成本。它采用一个成本函数,该函数给出了做出错误估计的成本,并根据概率分布将其与所有可能的结果相加。

    这里的代价函数是交叉熵损失,用于一次热编码

    我们整合了这个成本

    函数在哪里。

    100class CrossEntropyBayesRisk(Module):
    • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 130    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    136        alpha = evidence + 1.

    138        strength = alpha.sum(dim=-1)

    亏损

    141        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

    整个批次的平均损失

    144        return loss.mean()

    误差损失平方时的贝叶斯风险

    这里的成本函数是平方误差,

    我们整合了这个成本

    从狄利克雷分布采样时的预期概率在哪里,方差哪里。

    这给了,

    方程的第一部分是误差项,第二部分是方差。

    147class SquaredErrorBayesRisk(Module):
    • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 193    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    199        alpha = evidence + 1.

    201        strength = alpha.sum(dim=-1)

    203        p = alpha / strength[:, None]

    错误

    206        err = (target - p) ** 2

    方差

    208        var = p * (1 - p) / (strength[:, None] + 1)

    它们的总和

    211        loss = (err + var).sum(dim=-1)

    整个批次的平均损失

    214        return loss.mean()

    KL 背离正则化损失

    如果样本无法正确分类,这会试图将总证据缩小为零。

    首先,我们在移除正确的证据后计算狄利克雷参数。

    其中是 gamma 函数,函数和

    217class KLDivergenceLoss(Module):
    • evidence有形状的[batch_size, n_classes]
  • target有形状的[batch_size, n_classes]
  • 241    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    247        alpha = evidence + 1.

    班级数

    249        n_classes = evidence.shape[-1]

    移除非误导性证据

    252        alpha_tilde = target + (1 - target) * alpha

    254        strength_tilde = alpha_tilde.sum(dim=-1)

    第一学期

    265        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
    266                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
    267                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

    第二学期

    272        second = (
    273                (alpha_tilde - 1) *
    274                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
    275        ).sum(dim=-1)

    条款总和

    278        loss = first + second

    整个批次的平均损失

    281        return loss.mean()

    追踪统计数据

    该模块计算统计数据并使用 labml 对其进行跟踪tracker

    284class TrackStatistics(Module):
    292    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

    班级数

    294        n_classes = evidence.shape[-1]

    与目标正确匹配的预测(基于最高概率的贪婪抽样)

    296        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

    轨道精度

    298        tracker.add('accuracy.', match.sum() / match.shape[0])

    301        alpha = evidence + 1.

    303        strength = alpha.sum(dim=-1)

    306        expected_probability = alpha / strength[:, None]

    所选(贪婪的最高概率)类别的预期概率

    308        expected_probability, _ = expected_probability.max(dim=-1)

    不确定性质量

    311        uncertainty_mass = n_classes / strength

    追踪正确的预测

    314        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

    追踪错误的预测

    316        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

    追踪正确的预测

    318        tracker.add('prob.succ.', expected_probability.masked_select(match))

    追踪错误的预测

    320        tracker.add('prob.fail.', expected_probability.masked_select(~match))