# Evidential Deep Learning to Quantify Classification Uncertainty

This is a PyTorch implementation of the paper [Evidential Deep Learning to Quantify Classification Uncertainty](https://arxiv.org/abs/1806.01768).

Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The sum of the masses over all subsets is $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means the sample could be any one of the classes; i.e. saying “I don’t know”.

If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and an overall uncertainty mass $u \ge 0$ to all classes.

Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$ as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$, where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence as a measure of the amount of support collected from data in favor of a sample being classified into a certain class.

This corresponds to a Dirichlet distribution with parameters $\color{orange}{\alpha_k} = e_k + 1$, and $\color{orange}{\alpha_0} = S = \sum_{k=1}^K \color{orange}{\alpha_k}$ is known as the Dirichlet strength. A Dirichlet distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$ is a distribution over categorical distributions; i.e. you can sample class probabilities from a Dirichlet distribution. The expected probability for class $k$ is $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$.
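As a quick numerical illustration of these definitions (using a made-up evidence vector, not values from the paper), the belief masses, uncertainty mass, and expected probabilities can be computed like this:

```python
import torch

# Hypothetical evidence for one sample with K = 3 classes
evidence = torch.tensor([4.0, 1.0, 0.0])   # e_k >= 0
k = evidence.shape[-1]

alpha = evidence + 1.                       # alpha_k = e_k + 1
strength = alpha.sum()                      # S = sum_k alpha_k = sum_k (e_k + 1)

belief = evidence / strength                # b_k = e_k / S
uncertainty = k / strength                  # u = K / S
expected_probability = alpha / strength     # p_hat_k = alpha_k / S

# Belief masses and the uncertainty mass sum to one
assert torch.isclose(belief.sum() + uncertainty, torch.tensor(1.))
```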

We get the model to output evidence for a given input $\mathbf{x}$. We use a function such as ReLU or Softplus at the final layer to get $f(\mathbf{x} | \Theta) \ge 0$.
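For instance, a minimal sketch of such a model (assuming a flattened MNIST input and a hypothetical hidden size; this is not the exact architecture used in the experiment) could end with a Softplus so the outputs are non-negative evidence:

```python
import torch
from torch import nn


class EvidenceModel(nn.Module):
    """A small MLP whose outputs are non-negative evidence values e_k."""

    def __init__(self, n_classes: int = 10, n_hidden: int = 500):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_classes),
            # Softplus keeps the final outputs >= 0 so they can be treated as evidence
            nn.Softplus(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```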

The paper proposes a few loss functions to train the model, which we have implemented below.

Here is the training code `experiment.py` to train a model on the MNIST dataset.

```python
import torch

from labml import tracker
from labml_helpers.module import Module
```

## Type II Maximum Likelihood Loss

The distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$ is a prior on the likelihood $\text{Multi}(\mathbf{y} \vert \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over the class probabilities $\mathbf{p}$.

If target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is,

$$\mathcal{L}(\Theta) = -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\color{orange}{\mathbf{\alpha}})} \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1} d\mathbf{p} \Bigg) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$$

```python
class MaximumLikelihoodLoss(Module):
```

- `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
- `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\color{orange}{\alpha_k} = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

```python
        strength = alpha.sum(dim=-1)
```

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$

```python
        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
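A quick usage sketch with made-up evidence values (not from the repository): the loss is small when the evidence concentrates on the target class and large when it concentrates on a wrong class.

```python
loss_func = MaximumLikelihoodLoss()
target = torch.tensor([[1.0, 0.0, 0.0]])

evidence_correct = torch.tensor([[10.0, 0.0, 0.0]])  # evidence on the target class
evidence_wrong = torch.tensor([[0.0, 10.0, 0.0]])    # evidence on a wrong class

print(loss_func(evidence_correct, target))  # ~0.17
print(loss_func(evidence_wrong, target))    # ~2.56
```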

## Bayes Risk with Cross Entropy Loss

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of making an incorrect estimate and integrates it over all possible outcomes, weighted by their probabilities under the distribution.

Here the cost function is cross-entropy loss; for one-hot coded $\mathbf{y}$,

$$\sum_{k=1}^K -y_k \log p_k$$

We integrate this cost over all $\mathbf{p}$,

$$\mathcal{L}(\Theta) = \int \Big[ \sum_{k=1}^K -y_k \log p_k \Big] \frac{1}{B(\color{orange}{\mathbf{\alpha}})} \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1} d\mathbf{p} = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$$

where $\psi(\cdot)$ is the digamma function.

```python
class CrossEntropyBayesRisk(Module):
```

- `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
- `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\color{orange}{\alpha_k} = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

```python
        strength = alpha.sum(dim=-1)
```

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$

```python
        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
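As a sanity check (not part of the original code), the closed form above can be compared with a Monte Carlo estimate of the expected cross-entropy under the Dirichlet, using made-up evidence values:

```python
evidence = torch.tensor([[2.0, 0.5, 0.1]])
target = torch.tensor([[1.0, 0.0, 0.0]])

# Closed form: sum_k y_k (psi(S) - psi(alpha_k))
analytic = CrossEntropyBayesRisk()(evidence, target)

# Monte Carlo estimate of E_{p ~ Dir(alpha)}[ -sum_k y_k log p_k ]
alpha = evidence + 1.
samples = torch.distributions.Dirichlet(alpha).sample((100_000,))
monte_carlo = -(target * samples.log()).sum(dim=-1).mean()

print(analytic, monte_carlo)  # the two values should be close
```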

## Bayes Risk with Squared Error Loss

Here the cost function is squared error,

$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

We integrate this cost over all $\mathbf{p}$,

$$\mathcal{L}(\Theta) = \int \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big] \frac{1}{B(\color{orange}{\mathbf{\alpha}})} \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1} d\mathbf{p} = \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 - 2 y_k p_k + p_k^2 \Big] = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)$$

where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$ is the expected probability when sampled from the Dirichlet distribution and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where $\text{Var}(p_k) = \frac{\color{orange}{\alpha_k}(S - \color{orange}{\alpha_k})}{S^2 (S + 1)} = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$ is the variance.

This gives,

$$\mathcal{L}(\Theta) = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) = \sum_{k=1}^K \Big( \big(y_k - \mathbb{E}[p_k]\big)^2 + \text{Var}(p_k) \Big) = \sum_{k=1}^K \Big( \big(y_k - \hat{p}_k\big)^2 + \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1} \Big)$$

The first part of the equation $\big(y_k -\mathbb{E}[p_k]\big)^2$ is the error term and the second part is the variance.

```python
class SquaredErrorBayesRisk(Module):
```

- `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
- `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\color{orange}{\alpha_k} = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

```python
        strength = alpha.sum(dim=-1)
```

$\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$

```python
        p = alpha / strength[:, None]
```

Error $(y_k -\hat{p}_k)^2$

```python
        err = (target - p) ** 2
```

Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$

```python
        var = p * (1 - p) / (strength[:, None] + 1)
```

Sum of them

```python
        loss = (err + var).sum(dim=-1)
```

Mean loss over the batch

```python
        return loss.mean()
```
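A small hedged illustration with made-up numbers: as the evidence for the correct class grows, both the error term and the variance term shrink, so the loss decreases.

```python
loss_func = SquaredErrorBayesRisk()
target = torch.tensor([[1.0, 0.0, 0.0]])

for e in [0.0, 1.0, 10.0, 100.0]:
    evidence = torch.tensor([[e, 0.0, 0.0]])
    # Loss decreases monotonically as evidence for the target class increases
    print(e, loss_func(evidence, target).item())
```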

## KL Divergence Regularization Loss

This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$, the Dirichlet parameters after removing the correct evidence.

Then we take the KL divergence between this and the uniform Dirichlet distribution $D(\mathbf{p} \vert <1, 1, \dots, 1>)$,

$$\mathcal{L}(\Theta) = D_{KL} \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big\Vert D(\mathbf{p} \vert <1, 1, \dots, 1>) \Big] = \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function, and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.

```python
class KLDivergenceLoss(Module):
```

- `evidence` is $\mathbf{e} \ge 0$ with shape `[batch_size, n_classes]`
- `target` is $\mathbf{y}$ with shape `[batch_size, n_classes]`

```python
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

$\color{orange}{\alpha_k} = e_k + 1$

```python
        alpha = evidence + 1.
```

Number of classes

```python
        n_classes = evidence.shape[-1]
```

Remove the correct evidence, $\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$

```python
        alpha_tilde = target + (1 - target) * alpha
```

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

```python
        strength_tilde = alpha_tilde.sum(dim=-1)
```

The first term

```python
        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))
```

The second term

```python
        second = (
                (alpha_tilde - 1) *
                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
        ).sum(dim=-1)
```

Sum of the terms

```python
        loss = first + second
```

Mean loss over the batch

```python
        return loss.mean()
```
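The KL term is not used on its own; the paper adds it to one of the losses above with an annealing coefficient $\lambda_t = \min(1, t/10)$, where $t$ is the epoch index, so the regularization strength grows gradually. A minimal sketch of how the two could be combined (the function name and the choice of `SquaredErrorBayesRisk` as the main loss are assumptions for illustration):

```python
def combined_loss(evidence: torch.Tensor, target: torch.Tensor, epoch: int) -> torch.Tensor:
    # Main loss; any of the three losses above could be used here
    loss = SquaredErrorBayesRisk()(evidence, target)
    # Annealing coefficient lambda_t = min(1, t / 10)
    annealing = min(1.0, epoch / 10)
    # Push misleading evidence towards the uniform Dirichlet
    return loss + annealing * KLDivergenceLoss()(evidence, target)
```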

### Track statistics

This module computes statistics and tracks them with labml tracker.

```python
class TrackStatistics(Module):
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):
```

Number of classes

```python
        n_classes = evidence.shape[-1]
```

Predictions that match the target (greedy selection based on highest probability)

```python
        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
```

Track accuracy

```python
        tracker.add('accuracy.', match.sum() / match.shape[0])
```

$\color{orange}{\alpha_k} = e_k + 1$

```python
        alpha = evidence + 1.
```

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

```python
        strength = alpha.sum(dim=-1)
```

$\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$

```python
        expected_probability = alpha / strength[:, None]
```

Expected probability of the selected (greedy highest probability) class

```python
        expected_probability, _ = expected_probability.max(dim=-1)
```

Uncertainty mass $u = \frac{K}{S}$

```python
        uncertainty_mass = n_classes / strength
```

Track $u$ for correct predictions

```python
        tracker.add('u.succ.', uncertainty_mass.masked_select(match))
```

Track $u$ for incorrect predictions

```python
        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))
```

Track $\hat{p}_k$ for correct predictions

```python
        tracker.add('prob.succ.', expected_probability.masked_select(match))
```

Track $\hat{p}_k$ for incorrect predictions

```python
        tracker.add('prob.fail.', expected_probability.masked_select(~match))
```
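Finally, a hedged sketch (with hypothetical names; see `experiment.py` for the actual training code) of how these pieces could fit together in a single training step:

```python
def train_step(model, optimizer, data, labels, n_classes: int, epoch: int):
    # One-hot encode the integer class labels
    target = torch.eye(n_classes, device=data.device)[labels]

    # The model outputs non-negative evidence
    evidence = model(data)

    # Main loss plus annealed KL divergence regularization
    loss = SquaredErrorBayesRisk()(evidence, target)
    loss = loss + min(1.0, epoch / 10) * KLDivergenceLoss()(evidence, target)

    # Track accuracy, uncertainty mass and expected probabilities
    TrackStatistics()(evidence, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss
```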