这是 P yTorch 对《量化分类不确定性的证据深度学习》论文的实现。
Dampster-Shafer 证据理论为信仰群体分配了一组类别(与为单个类别分配概率不同)。所有子集的质量总和为。个别类别的概率(合理性)可以从这些质量中推导出来。
为所有类别的集合分配质量意味着它可以是任何一个类别;即说 “我不知道”。
如果有等级,我们为每个等级分配质量,为所有类别分配总体不确定性质量。
信仰众多,可以根据证据计算得出,无论在何处。Paper 使用术语证据来衡量从数据中收集的支持量,支持将样本归类为特定类别。
这对应于带有参数的狄利克雷分布,被称为狄利克雷强度。狄利克雷分布是类别分布上的分布;也就是说,你可以从狄利克雷分布中抽取类概率。上课的预期概率是。
我们让模型输出给定输入的证据。我们在最后一层使用诸如 ReLU 或 Softplus 之类的函数来获取。
本文提出了一些用于训练模型的损失函数,我们在下面实现了这些函数。
这是在 MNIST 数据集上训练模型的训练代码experiment.py
。
52import torch
53
54from labml import tracker
55from labml_helpers.module import Module
58class MaximumLikelihoodLoss(Module):
83 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
89 alpha = evidence + 1.
91 strength = alpha.sum(dim=-1)
亏损
94 loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
整个批次的平均损失
97 return loss.mean()
贝叶斯风险是做出错误估算的总体最大成本。它采用一个成本函数,该函数给出了做出错误估计的成本,并根据概率分布将其与所有可能的结果相加。
这里的代价函数是交叉熵损失,用于一次热编码
我们整合了这个成本
函数在哪里。
100class CrossEntropyBayesRisk(Module):
130 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
136 alpha = evidence + 1.
138 strength = alpha.sum(dim=-1)
亏损
141 loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
整个批次的平均损失
144 return loss.mean()
147class SquaredErrorBayesRisk(Module):
193 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
199 alpha = evidence + 1.
201 strength = alpha.sum(dim=-1)
203 p = alpha / strength[:, None]
错误
206 err = (target - p) ** 2
方差
208 var = p * (1 - p) / (strength[:, None] + 1)
它们的总和
211 loss = (err + var).sum(dim=-1)
整个批次的平均损失
214 return loss.mean()
217class KLDivergenceLoss(Module):
241 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
247 alpha = evidence + 1.
班级数
249 n_classes = evidence.shape[-1]
移除非误导性证据
252 alpha_tilde = target + (1 - target) * alpha
254 strength_tilde = alpha_tilde.sum(dim=-1)
265 first = (torch.lgamma(alpha_tilde.sum(dim=-1))
266 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
267 - (torch.lgamma(alpha_tilde)).sum(dim=-1))
第二学期
272 second = (
273 (alpha_tilde - 1) *
274 (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
275 ).sum(dim=-1)
条款总和
278 loss = first + second
整个批次的平均损失
281 return loss.mean()
292 def forward(self, evidence: torch.Tensor, target: torch.Tensor):
班级数
294 n_classes = evidence.shape[-1]
与目标正确匹配的预测(基于最高概率的贪婪抽样)
296 match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
轨道精度
298 tracker.add('accuracy.', match.sum() / match.shape[0])
301 alpha = evidence + 1.
303 strength = alpha.sum(dim=-1)
306 expected_probability = alpha / strength[:, None]
所选(贪婪的最高概率)类别的预期概率
308 expected_probability, _ = expected_probability.max(dim=-1)
不确定性质量
311 uncertainty_mass = n_classes / strength
追踪正确的预测
314 tracker.add('u.succ.', uncertainty_mass.masked_select(match))
追踪错误的预测
316 tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
追踪正确的预测
318 tracker.add('prob.succ.', expected_probability.masked_select(match))
追踪错误的预测
320 tracker.add('prob.fail.', expected_probability.masked_select(~match))