Evidential Deep Learning to Quantify Classification Uncertainty

This is a PyTorch implementation of the paper Evidential Deep Learning to Quantify Classification Uncertainty.

Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The sum of the masses over all subsets is $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means it can be any one of the classes; i.e. saying “I don’t know”.

If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and an overall uncertainty mass $u \ge 0$ to all classes.

Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$ as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$, where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence as a measure of the amount of support collected from data in favor of a sample being classified into a certain class.

This corresponds to a Dirichlet distribution with parameters $\color{orange}{\alpha_k} = e_k + 1$, where $\color{orange}{\alpha_0} = S = \sum_{k=1}^K \color{orange}{\alpha_k}$ is known as the Dirichlet strength. The Dirichlet distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$ is a distribution over categorical distributions; i.e. you can sample class probabilities from it. The expected probability for class $k$ is $\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$.
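
As a quick illustration of these quantities (not part of the implementation below), the following sketch computes the belief masses, the uncertainty mass, and the expected probabilities from a hypothetical evidence vector:

```python
import torch

# Hypothetical evidence for a single sample with K = 3 classes
evidence = torch.tensor([4., 1., 0.])

n_classes = evidence.shape[-1]
alpha = evidence + 1.               # Dirichlet parameters alpha_k = e_k + 1
strength = alpha.sum()              # Dirichlet strength S

belief = evidence / strength        # belief masses b_k = e_k / S
uncertainty = n_classes / strength  # uncertainty mass u = K / S
expected_p = alpha / strength       # expected probabilities p_hat_k = alpha_k / S

# Belief masses and the uncertainty mass sum to 1
assert torch.isclose(belief.sum() + uncertainty, torch.tensor(1.))
```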

We get the model to output evidence for a given input $\mathbf{x}$ by using a non-negative activation such as ReLU or Softplus at the final layer, so that $f(\mathbf{x} | \Theta) \ge 0$.
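
For instance, here is a minimal sketch of such a model head; the fully connected architecture and its sizes are illustrative assumptions, not the network used in the experiment:

```python
import torch
from torch import nn


class EvidenceModel(nn.Module):
    """A hypothetical classifier that outputs non-negative evidence per class."""

    def __init__(self, n_features: int = 784, n_classes: int = 10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes),
        )
        # Softplus keeps the outputs non-negative, so they can be treated as evidence
        self.to_evidence = nn.Softplus()

    def forward(self, x: torch.Tensor):
        return self.to_evidence(self.net(x))
```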

The paper proposes a few loss functions to train the model, which we have implemented below.

Here is the training code experiment.py to train a model on the MNIST dataset.


import torch

from labml import tracker
from labml_helpers.module import Module

Type II Maximum Likelihood Loss

The distribution $D(\mathbf{p} \vert \color{orange}{\mathbf{\alpha}})$ is a prior on the likelihood $\text{Multi}(\mathbf{y} \vert \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over the class probabilities $\mathbf{p}$.

If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is

$\mathcal{L}(\Theta) = -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\color{orange}{\mathbf{\alpha}})} \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1} d\mathbf{p} \Bigg) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$

class MaximumLikelihoodLoss(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\color{orange}{\alpha_k} = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

        strength = alpha.sum(dim=-1)

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \log S - \log \color{orange}{\alpha_k} \bigg)$

        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

Mean loss over the batch

        return loss.mean()
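
For example, a minimal usage sketch with random evidence and one-hot targets (the batch size and class count are arbitrary):

```python
# Hypothetical usage: a batch of 4 samples with 10 classes
evidence = torch.rand(4, 10) * 5.                   # non-negative evidence
target = torch.eye(10)[torch.tensor([3, 1, 7, 0])]  # one-hot targets
loss = MaximumLikelihoodLoss()(evidence, target)
print(loss)  # scalar tensor, mean loss over the batch
```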

Bayes Risk with Cross Entropy Loss

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of an incorrect estimate and integrates it over all possible outcomes, weighted by their probabilities.

Here the cost function is the cross-entropy loss; for one-hot coded $\mathbf{y}$ it is

$\sum_{k=1}^K -y_k \log p_k$

We integrate this cost over all $\mathbf{p}$

$\mathcal{L}(\Theta) = \int \Big[ \sum_{k=1}^K -y_k \log p_k \Big] \frac{1}{B(\color{orange}{\mathbf{\alpha}})} \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1} d\mathbf{p} = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$

where $\psi(\cdot)$ is the digamma function.

class CrossEntropyBayesRisk(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\color{orange}{\alpha_k} = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

        strength = alpha.sum(dim=-1)

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi( \color{orange}{\alpha_k} ) \bigg)$

        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

Mean loss over the batch

        return loss.mean()
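
As a sanity check (not part of the implementation), the closed form $\psi(S) - \psi( \color{orange}{\alpha_k} )$ can be compared against a Monte Carlo estimate of the expected cross-entropy under the Dirichlet prior, using torch.distributions.Dirichlet:

```python
alpha = torch.tensor([2., 5., 1.5])
target = torch.tensor([0., 1., 0.])  # one-hot target; the correct class is 1

# Closed form: sum_k y_k (psi(S) - psi(alpha_k))
closed_form = (target * (torch.digamma(alpha.sum()) - torch.digamma(alpha))).sum()

# Monte Carlo estimate of E_{p ~ D(p|alpha)}[ -sum_k y_k log p_k ]
p = torch.distributions.Dirichlet(alpha).sample((100_000,))
monte_carlo = -(target * p.log()).sum(dim=-1).mean()

print(closed_form, monte_carlo)  # the two values should be close
```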

Bayes Risk with Squared Error Loss

Here the cost function is squared error,

$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$

We integrate this cost over all $\mathbf{p}$

$\mathcal{L}(\Theta) = \int \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big] \frac{1}{B(\color{orange}{\mathbf{\alpha}})} \prod_{k=1}^K p_k^{\color{orange}{\alpha_k} - 1} d\mathbf{p} = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)$

where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$ is the expected probability when sampled from the Dirichlet distribution and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where $\text{Var}(p_k) = \frac{\color{orange}{\alpha_k}(S - \color{orange}{\alpha_k})}{S^2 (S + 1)} = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$ is the variance.

This gives,

$\mathcal{L}(\Theta) = \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) = \sum_{k=1}^K \Big( \big(y_k - \mathbb{E}[p_k]\big)^2 + \text{Var}(p_k) \Big)$

The first part of the equation $\big(y_k -\mathbb{E}[p_k]\big)^2$ is the error term and the second part is the variance.

class SquaredErrorBayesRisk(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\color{orange}{\alpha_k} = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

        strength = alpha.sum(dim=-1)

$\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$

        p = alpha / strength[:, None]

Error $(y_k -\hat{p}_k)^2$

        err = (target - p) ** 2

Variance $\text{Var}(p_k) = \frac{\hat{p}_k(1 - \hat{p}_k)}{S + 1}$

        var = p * (1 - p) / (strength[:, None] + 1)

Sum of them

        loss = (err + var).sum(dim=-1)

Mean loss over the batch

        return loss.mean()
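
A small illustrative sketch (assuming a single correct class and evidence placed only on it) of how this loss behaves: both the error and the variance terms shrink as the evidence for the correct class grows.

```python
loss_fn = SquaredErrorBayesRisk()
target = torch.tensor([[1., 0., 0.]])  # the correct class is 0

for e in [0., 1., 10., 100.]:
    evidence = torch.tensor([[e, 0., 0.]])  # evidence only for the correct class
    print(e, loss_fn(evidence, target).item())
# The printed loss decreases monotonically towards zero
```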

KL Divergence Regularization Loss

This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \color{orange}{\alpha_k}$, the Dirichlet parameters after removing the non-misleading (correct-class) evidence. The loss is the KL divergence to the uniform Dirichlet distribution,

$\mathcal{L}(\Theta) = KL \Big[ D(\mathbf{p} \vert \mathbf{\tilde{\alpha}}) \Big\Vert D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] = \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function, and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.

class KLDivergenceLoss(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\color{orange}{\alpha_k} = e_k + 1$

        alpha = evidence + 1.

Number of classes

        n_classes = evidence.shape[-1]

Remove non-misleading evidence

        alpha_tilde = target + (1 - target) * alpha

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

        strength_tilde = alpha_tilde.sum(dim=-1)

The first term

        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

The second term

        second = (
                (alpha_tilde - 1) *
                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
        ).sum(dim=-1)

Sum of the terms

        loss = first + second

Mean loss over the batch

        return loss.mean()
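
As a check (not part of the implementation), this closed form can be compared against PyTorch's built-in KL divergence between Dirichlet distributions:

```python
evidence = torch.rand(4, 10) * 5.
target = torch.eye(10)[torch.tensor([3, 1, 7, 0])]

closed_form = KLDivergenceLoss()(evidence, target)

# Same quantity via torch.distributions: KL[ D(p | alpha_tilde) || D(p | <1, ..., 1>) ]
alpha_tilde = target + (1 - target) * (evidence + 1.)
kl = torch.distributions.kl_divergence(
    torch.distributions.Dirichlet(alpha_tilde),
    torch.distributions.Dirichlet(torch.ones_like(alpha_tilde)),
).mean()

print(closed_form, kl)  # the two values should match
```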

Track statistics

This module computes statistics and tracks them with labml tracker.

class TrackStatistics(Module):
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

Number of classes

        n_classes = evidence.shape[-1]

Predictions that match the target (greedy prediction of the highest-probability class)

        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

Track accuracy

        tracker.add('accuracy.', match.sum() / match.shape[0])

$\color{orange}{\alpha_k} = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \color{orange}{\alpha_k}$

        strength = alpha.sum(dim=-1)

$\hat{p}_k = \frac{\color{orange}{\alpha_k}}{S}$

        expected_probability = alpha / strength[:, None]

Expected probability of the selected (greedy, highest probability) class

        expected_probability, _ = expected_probability.max(dim=-1)

Uncertainty mass $u = \frac{K}{S}$

        uncertainty_mass = n_classes / strength

Track $u$ for correct predictions

        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

Track $u$ for incorrect predictions

        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

Track $\hat{p}_k$ for correct predictions

        tracker.add('prob.succ.', expected_probability.masked_select(match))

Track $\hat{p}_k$ for incorrect predictions

        tracker.add('prob.fail.', expected_probability.masked_select(~match))
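
The paper combines a Bayes risk loss with the KL divergence regularization, gradually increasing the KL coefficient during training so the regularization does not dominate early on. Here is a minimal sketch of such a training step; the model, optimizer, and annealing schedule are placeholder assumptions, not the exact experiment configuration:

```python
import torch
from torch import nn


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer,
               x: torch.Tensor, target: torch.Tensor,
               epoch: int, kl_anneal_epochs: int = 10):
    """One training step: Bayes risk loss plus annealed KL divergence regularization."""
    loss_fn = SquaredErrorBayesRisk()
    kl_fn = KLDivergenceLoss()

    evidence = model(x)  # non-negative evidence per class, shape [batch_size, n_classes]

    # Anneal the KL coefficient from 0 up to 1 over the first `kl_anneal_epochs` epochs
    kl_coef = min(1., epoch / kl_anneal_epochs)
    loss = loss_fn(evidence, target) + kl_coef * kl_fn(evidence, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```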