This is a PyTorch implementation of the paper Evidential Deep Learning to Quantify Classification Uncertainty.

Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The sum of the masses of all subsets is $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying "I don't know".

If there are $K$ classes, we assign masses $b_{k}≥0$ to each of the classes and an overall uncertainty mass $u≥0$ to all classes.

$u + \sum_{k=1}^K b_k = 1$

Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$, as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$ where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence as a measure of the amount of support collected from data in favor of a sample being classified into a certain class.
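As a rough illustration of the relations $b_k = e_k / S$, $u = K / S$ and $S = \sum_k (e_k + 1)$, here is a minimal plain-Python sketch. The function name `belief_masses` and the example evidence values are hypothetical and are not part of the implementation described here.

```python
def belief_masses(evidence):
    """Return (beliefs, uncertainty) for a list of non-negative evidence values."""
    n_classes = len(evidence)
    # S = sum_k (e_k + 1)
    strength = sum(e + 1.0 for e in evidence)
    # b_k = e_k / S
    beliefs = [e / strength for e in evidence]
    # u = K / S
    uncertainty = n_classes / strength
    return beliefs, uncertainty


# No evidence at all: the whole mass goes to "I don't know"
print(belief_masses([0.0, 0.0, 0.0]))
# Strong evidence for the first class: the uncertainty mass shrinks
print(belief_masses([27.0, 0.0, 0.0]))
```

In both cases the belief masses and the uncertainty mass add up to $1$, which is the constraint $u + \sum_k b_k = 1$.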

This corresponds to a Dirichlet distribution with parameters $\alpha_k = e_k + 1$, and $\alpha_0 = S = \sum_{k=1}^K \alpha_k$ is known as the Dirichlet strength. The Dirichlet distribution $D(p \vert \alpha)$ is a distribution over categorical distributions; i.e. you can sample class probabilities from a Dirichlet distribution. The expected probability for class $k$ is $\hat{p}_k = \frac{\alpha_k}{S}$.
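To make "sampling class probabilities from a Dirichlet distribution" concrete, the sketch below draws samples using only the standard library and compares their average with the expected probability $\alpha_k / S$. It relies on the standard construction of a Dirichlet sample from normalised Gamma draws; the helper `sample_dirichlet`, the evidence values, the seed, and the sample count are all illustrative assumptions.

```python
import random


def sample_dirichlet(alpha, rng):
    """Draw one categorical distribution p ~ D(p | alpha)."""
    # Normalising independent Gamma(alpha_k, 1) draws gives a Dirichlet sample
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


evidence = [4.0, 1.0, 0.0]
alpha = [e + 1.0 for e in evidence]         # alpha_k = e_k + 1
strength = sum(alpha)                       # Dirichlet strength S
expected_p = [a / strength for a in alpha]  # expected probability alpha_k / S

rng = random.Random(0)  # fixed seed so the run is repeatable
n_samples = 20_000
samples = [sample_dirichlet(alpha, rng) for _ in range(n_samples)]
empirical_mean = [sum(s[k] for s in samples) / n_samples
                  for k in range(len(alpha))]

print("expected :", expected_p)
print("empirical:", empirical_mean)
```

Each sample is itself a valid probability vector, and the empirical mean should land close to $\alpha_k / S$ up to Monte Carlo noise.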

We get the model to output evidences $e=α−1=f(x∣Θ)$ for a given input $x$. We use a function such as ReLU or a Softplus at the final layer to get $f(x∣Θ)≥0$.

The paper proposes a few loss functions to train the model, which we have implemented below.

Here is the training code `experiment.py` to train a model on the MNIST dataset.

```
import torch

from labml import tracker
from labml_helpers.module import Module
```

The distribution $D(p∣α)$ is a prior on the likelihood $Multi(y∣p)$, and the negative log marginal likelihood is calculated by integrating over class probabilities $p$.

If target probabilities (one-hot targets) are $y_{k}$ for a given sample the loss is,

$L(\Theta) = -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\alpha)} \prod_{k=1}^K p_k^{\alpha_k - 1} dp \Bigg) = \sum_{k=1}^K y_k (\log S - \log \alpha_k)$

`class MaximumLikelihoodLoss(Module):`

`evidence` is $e \ge 0$ with shape `[batch_size, n_classes]`

`target` is $y$ with shape `[batch_size, n_classes]`

`def forward(self, evidence: torch.Tensor, target: torch.Tensor):`

$α_{k}=e_{k}+1$

`alpha = evidence + 1.`

$S = \sum_{k=1}^K \alpha_k$

`strength = alpha.sum(dim=-1)`

Losses $L(\Theta) = \sum_{k=1}^K y_k (\log S - \log \alpha_k)$

`loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)`

Mean loss over the batch

`return loss.mean()`
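For a one-hot target, the loss $\sum_k y_k (\log S - \log \alpha_k)$ reduces to $-\log(\alpha_c / S)$ for the true class $c$, i.e. the negative log of the expected probability of the correct class. The plain-Python sketch below works through three single-sample cases; the function name and the evidence values are hypothetical and chosen only for illustration.

```python
import math


def max_likelihood_loss(evidence, target):
    """sum_k y_k (log S - log alpha_k) for a single sample given as plain lists."""
    alpha = [e + 1.0 for e in evidence]  # alpha_k = e_k + 1
    strength = sum(alpha)                # S
    return sum(y * (math.log(strength) - math.log(a))
               for y, a in zip(target, alpha))


target = [1.0, 0.0, 0.0]
confident = max_likelihood_loss([27.0, 0.0, 0.0], target)   # -log(28 / 30)
no_evidence = max_likelihood_loss([0.0, 0.0, 0.0], target)  # -log(1 / 3)
wrong = max_likelihood_loss([0.0, 27.0, 0.0], target)       # -log(1 / 30)

print(confident, no_evidence, wrong)
```

The loss is smallest when the evidence supports the correct class and largest when strong evidence points at a wrong class.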

Bayes risk is the overall expected cost of making incorrect estimates. It takes a cost function that gives the cost of making an incorrect estimate and sums it over all possible outcomes, weighted by the probability distribution.

Here the cost function is the cross-entropy loss for one-hot coded $y$, $\sum_{k=1}^K -y_k \log p_k$

We integrate this cost over all $p$

$L(\Theta) = \int \Big[ \sum_{k=1}^K -y_k \log p_k \Big] \frac{1}{B(\alpha)} \prod_{k=1}^K p_k^{\alpha_k - 1} dp = \sum_{k=1}^K y_k (\psi(S) - \psi(\alpha_k))$

where $\psi(\cdot)$ is the digamma function.

`class CrossEntropyBayesRisk(Module):`

`evidence` is $e \ge 0$ with shape `[batch_size, n_classes]`

`target` is $y$ with shape `[batch_size, n_classes]`

`def forward(self, evidence: torch.Tensor, target: torch.Tensor):`

$α_{k}=e_{k}+1$

`alpha = evidence + 1.`

$S = \sum_{k=1}^K \alpha_k$

`strength = alpha.sum(dim=-1)`

Losses $L(\Theta) = \sum_{k=1}^K y_k (\psi(S) - \psi(\alpha_k))$

`loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)`

Mean loss over the batch

`return loss.mean()`
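The cross-entropy Bayes risk $\sum_k y_k (\psi(S) - \psi(\alpha_k))$ can be compared numerically with the log-based loss $\sum_k y_k (\log S - \log \alpha_k)$. The sketch below is standard-library only, so it approximates the digamma function with a central difference of `math.lgamma`; that approximation, the function names, and the evidence values are assumptions made for illustration (the code above uses `torch.digamma`).

```python
import math


def digamma(x, h=1e-5):
    """Approximate psi(x) by a central difference of log-gamma (illustrative only)."""
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2.0 * h)


def cross_entropy_bayes_risk(evidence, target):
    """sum_k y_k (psi(S) - psi(alpha_k)) for a single sample."""
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    return sum(y * (digamma(strength) - digamma(a))
               for y, a in zip(target, alpha))


def max_likelihood_loss(evidence, target):
    """sum_k y_k (log S - log alpha_k), shown for comparison."""
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    return sum(y * (math.log(strength) - math.log(a))
               for y, a in zip(target, alpha))


target = [1.0, 0.0, 0.0]
for evidence in ([0.0, 0.0, 0.0], [5.0, 1.0, 0.0], [27.0, 0.0, 0.0]):
    risk = cross_entropy_bayes_risk(evidence, target)
    ml = max_likelihood_loss(evidence, target)
    print(evidence, round(risk, 4), round(ml, 4))
```

The Bayes risk averages $-\log p_k$ over the Dirichlet, while the log-based loss takes $-\log$ of the average $p_k$, so by Jensen's inequality the first is never smaller than the second.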

Here the cost function is squared error, $\sum_{k=1}^K (y_k - p_k)^2 = \Vert y - p \Vert_2^2$

We integrate this cost over all $p$

$L(\Theta) = \int \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big] \frac{1}{B(\alpha)} \prod_{k=1}^K p_k^{\alpha_k - 1} dp = \sum_{k=1}^K E \big[ y_k^2 - 2 y_k p_k + p_k^2 \big] = \sum_{k=1}^K \big( y_k^2 - 2 y_k E[p_k] + E[p_k^2] \big)$

where $E[p_k] = \hat{p}_k = \frac{\alpha_k}{S}$ is the expected probability when sampled from the Dirichlet distribution and $E[p_k^2] = E[p_k]^2 + Var(p_k)$, where $Var(p_k) = \frac{\alpha_k (S - \alpha_k)}{S^2 (S + 1)} = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$ is the variance.

This gives,

$L(\Theta) = \sum_{k=1}^K \big( y_k^2 - 2 y_k E[p_k] + E[p_k^2] \big) = \sum_{k=1}^K \big( y_k^2 - 2 y_k E[p_k] + E[p_k]^2 + Var(p_k) \big) = \sum_{k=1}^K \big( (y_k - E[p_k])^2 + Var(p_k) \big) = \sum_{k=1}^K \Big( (y_k - \hat{p}_k)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \Big)$

The first part of the equation, $(y_k - E[p_k])^2$, is the error term and the second part is the variance.

`class SquaredErrorBayesRisk(Module):`

`evidence` is $e \ge 0$ with shape `[batch_size, n_classes]`

`target` is $y$ with shape `[batch_size, n_classes]`

`def forward(self, evidence: torch.Tensor, target: torch.Tensor):`

$α_{k}=e_{k}+1$

`alpha = evidence + 1.`

$S = \sum_{k=1}^K \alpha_k$

`strength = alpha.sum(dim=-1)`

$\hat{p}_k = \frac{\alpha_k}{S}$

`p = alpha / strength[:, None]`

Error $(y_k - \hat{p}_k)^2$

`err = (target - p) ** 2`

Variance $Var(p_k) = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$

`var = p * (1 - p) / (strength[:, None] + 1)`

Sum of them

`loss = (err + var).sum(dim=-1)`

Mean loss over the batch

`return loss.mean()`
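The error-plus-variance form can be checked against the expanded form $\sum_k (y_k^2 - 2 y_k E[p_k] + E[p_k^2])$ using the Dirichlet second moment $E[p_k^2] = \frac{\alpha_k (\alpha_k + 1)}{S (S + 1)}$. The sketch below does this in plain Python for single samples; both function names and the evidence values are hypothetical.

```python
def squared_error_bayes_risk(evidence, target):
    """sum_k ((y_k - p_hat_k)^2 + p_hat_k (1 - p_hat_k) / (S + 1)) for one sample."""
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    p = [a / strength for a in alpha]
    err = [(y - pk) ** 2 for y, pk in zip(target, p)]
    var = [pk * (1.0 - pk) / (strength + 1.0) for pk in p]
    return sum(err) + sum(var)


def expanded_form(evidence, target):
    """The same quantity via sum_k (y_k^2 - 2 y_k E[p_k] + E[p_k^2])."""
    alpha = [e + 1.0 for e in evidence]
    strength = sum(alpha)
    total = 0.0
    for y, a in zip(target, alpha):
        mean = a / strength
        # Second moment of a Dirichlet marginal
        second_moment = a * (a + 1.0) / (strength * (strength + 1.0))
        total += y * y - 2.0 * y * mean + second_moment
    return total


target = [1.0, 0.0, 0.0]
for evidence in ([0.0, 0.0, 0.0], [5.0, 1.0, 0.0], [0.0, 27.0, 0.0]):
    print(evidence,
          squared_error_bayes_risk(evidence, target),
          expanded_form(evidence, target))
```

The two columns should agree up to floating-point rounding, which is the algebraic identity used in the derivation.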

This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

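First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$, the Dirichlet parameters after removing the correct evidence.

$KL \Big[ D(p \vert \tilde{\alpha}) \Big\Vert D(p \vert \langle 1, \dots, 1 \rangle) \Big] = \log \Bigg( \frac{\Gamma \big( \sum_{k=1}^K \tilde{\alpha}_k \big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

`class KLDivergenceLoss(Module):`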

`evidence` is $e \ge 0$ with shape `[batch_size, n_classes]`

`target` is $y$ with shape `[batch_size, n_classes]`

`def forward(self, evidence: torch.Tensor, target: torch.Tensor):`

$α_{k}=e_{k}+1$

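`alpha = evidence + 1.`

Number of classes

`n_classes = evidence.shape[-1]`

Remove non-misleading evidence $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$

`alpha_tilde = target + (1 - target) * alpha`

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

`strength_tilde = alpha_tilde.sum(dim=-1)`

The first term

$\log \Bigg( \frac{\Gamma \big( \sum_{k=1}^K \tilde{\alpha}_k \big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) = \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)$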

```
first = (torch.lgamma(alpha_tilde.sum(dim=-1))
         - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
         - (torch.lgamma(alpha_tilde)).sum(dim=-1))
```

The second term $\sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$

```
second = (
    (alpha_tilde - 1) *
    (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
).sum(dim=-1)
```

Sum of the terms

`loss = first + second`

Mean loss over the batch

`return loss.mean()`
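To see how this regularizer behaves, the sketch below evaluates the two terms for single samples in plain Python. It is a hedged illustration rather than the module above: the digamma function is approximated by a central difference of `math.lgamma`, and the function names and evidence values are hypothetical.

```python
import math


def digamma(x, h=1e-5):
    """Approximate psi(x) by a central difference of log-gamma (illustrative only)."""
    return (math.lgamma(x + h) - math.lgamma(x - h)) / (2.0 * h)


def kl_to_uniform_dirichlet(evidence, target):
    """KL[D(p | alpha_tilde) || D(p | <1, ..., 1>)] for a single sample."""
    n_classes = len(evidence)
    alpha = [e + 1.0 for e in evidence]
    # Drop the evidence of the correct class: alpha_tilde_k = y_k + (1 - y_k) alpha_k
    alpha_tilde = [y + (1.0 - y) * a for y, a in zip(target, alpha)]
    strength_tilde = sum(alpha_tilde)
    # log Gamma(S_tilde) - log Gamma(K) - sum_k log Gamma(alpha_tilde_k)
    first = (math.lgamma(strength_tilde)
             - math.lgamma(float(n_classes))
             - sum(math.lgamma(a) for a in alpha_tilde))
    # sum_k (alpha_tilde_k - 1) [psi(alpha_tilde_k) - psi(S_tilde)]
    second = sum((a - 1.0) * (digamma(a) - digamma(strength_tilde))
                 for a in alpha_tilde)
    return first + second


target = [1.0, 0.0, 0.0]
only_correct = kl_to_uniform_dirichlet([27.0, 0.0, 0.0], target)
some_misleading = kl_to_uniform_dirichlet([27.0, 5.0, 0.0], target)
more_misleading = kl_to_uniform_dirichlet([27.0, 20.0, 0.0], target)

print(only_correct, some_misleading, more_misleading)
```

Evidence for the correct class is not penalized, because $\tilde{\alpha}$ becomes all ones and the divergence is zero, while evidence for wrong classes increases the penalty.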

`class TrackStatistics(Module):`

`def forward(self, evidence: torch.Tensor, target: torch.Tensor):`

Number of classes

`n_classes = evidence.shape[-1]`

Predictions that correctly match the target (greedy sampling based on highest probability)

`match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))`

Track accuracy

`tracker.add('accuracy.', match.sum() / match.shape[0])`

$α_{k}=e_{k}+1$

`alpha = evidence + 1.`

$S = \sum_{k=1}^K \alpha_k$

`strength = alpha.sum(dim=-1)`

$\hat{p}_k = \frac{\alpha_k}{S}$

`expected_probability = alpha / strength[:, None]`

Expected probability of the selected (greedy highest probability) class

`expected_probability, _ = expected_probability.max(dim=-1)`

Uncertainty mass $u = \frac{K}{S}$

`uncertainty_mass = n_classes / strength`

Track $u$ for correct predictions

`tracker.add('u.succ.', uncertainty_mass.masked_select(match))`

Track $u$ for incorrect predictions

`tracker.add('u.fail.', uncertainty_mass.masked_select(~match))`

Track $\hat{p}_k$ for correct predictions

`tracker.add('prob.succ.', expected_probability.masked_select(match))`

Track $\hat{p}_k$ for incorrect predictions

`tracker.add('prob.fail.', expected_probability.masked_select(~match))`