Evidential Deep Learning to Quantify Classification Uncertainty

This is a PyTorch implementation of the paper Evidential Deep Learning to Quantify Classification Uncertainty.

Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The masses of all subsets sum to $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means the sample can be any one of the classes, i.e. saying "I don't know".

If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and an overall uncertainty mass $u \ge 0$ to all classes, such that

$$u + \sum_{k=1}^K b_k = 1$$

Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$ as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$, where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence for a measure of the amount of support collected from data in favor of a sample being classified into a certain class.
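As a quick numerical check of these definitions, the sketch below computes $b_k$ and $u$ from a batch of made-up evidence vectors and confirms that $u + \sum_k b_k = 1$. The helper name `belief_masses` and the evidence values are hypothetical, chosen only for illustration.

```python
import torch


def belief_masses(evidence: torch.Tensor):
    """Hypothetical helper: belief masses b_k and uncertainty mass u.

    `evidence` has shape [batch_size, n_classes] and is assumed non-negative.
    """
    n_classes = evidence.shape[-1]
    # Dirichlet strength S = sum_k (e_k + 1), kept as [batch_size, 1] for broadcasting
    strength = (evidence + 1.).sum(dim=-1, keepdim=True)
    b = evidence / strength                 # b_k = e_k / S
    u = n_classes / strength.squeeze(-1)    # u = K / S
    return b, u


# Illustrative evidence for two samples over K = 3 classes
evidence = torch.tensor([[0., 0., 0.],      # no evidence at all
                         [10., 1., 1.]])    # strong evidence for class 0
b, u = belief_masses(evidence)
print(u)                    # uncertainty mass: 1.0 for the first sample, 0.2 for the second
print(b.sum(dim=-1) + u)    # every row sums to 1
```

A sample with no evidence puts all of its mass on $u$ (the "I don't know" case), while strong evidence for one class pushes $u$ toward zero.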

This corresponds to a Dirichlet distribution with parameters $\alpha_k = e_k + 1$, and $\alpha_0 = S = \sum_{k=1}^K \alpha_k$ is known as the Dirichlet strength. The Dirichlet distribution $D(\mathbf{p} \vert \boldsymbol{\alpha})$ is a distribution over categorical distributions; i.e. you can sample class probabilities from a Dirichlet distribution. The expected probability for class $k$ is $\hat{p}_k = \frac{\alpha_k}{S}$.
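To make the "distribution over categorical distributions" idea concrete, here is a minimal sketch using PyTorch's `torch.distributions.Dirichlet`: every sample is a probability vector over the $K$ classes, and the sample mean approaches $\hat{p}_k = \alpha_k / S$. The evidence values are made up for illustration.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

evidence = torch.tensor([10., 1., 1.])   # illustrative evidence, K = 3
alpha = evidence + 1.                    # alpha_k = e_k + 1
strength = alpha.sum()                   # S = sum_k alpha_k
expected_p = alpha / strength            # p_hat_k = alpha_k / S

# Each draw is a full categorical distribution over the K classes
dirichlet = Dirichlet(alpha)
samples = dirichlet.sample((100_000,))   # shape [100000, 3]

print(samples.mean(dim=0))   # close to expected_p
print(expected_p)            # tensor([0.7333, 0.1333, 0.1333])
print(dirichlet.mean)        # PyTorch's closed-form mean, also alpha / S
```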

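We get the model to output evidence

$$\mathbf{e} = \boldsymbol{\alpha} - 1 = f(\mathbf{x} \vert \Theta)$$

for a given input $\mathbf{x}$. We use a function such as ReLU or Softplus at the final layer to get $f(\mathbf{x} \vert \Theta) \ge 0$.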
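A minimal sketch of such an output layer, assuming a plain linear layer over some feature vector. The class name `EvidentialHead`, its sizes, and the `activation` switch are hypothetical and not taken from experiment.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EvidentialHead(nn.Module):
    """Hypothetical final layer: linear projection, then a non-negative activation."""

    def __init__(self, in_features: int, n_classes: int, activation: str = 'softplus'):
        super().__init__()
        self.linear = nn.Linear(in_features, n_classes)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.linear(x)
        if self.activation == 'relu':
            return F.relu(logits)      # e_k >= 0, exactly zero for negative logits
        return F.softplus(logits)      # e_k >= 0, smooth with no flat zero region


head = EvidentialHead(in_features=16, n_classes=10)
evidence = head(torch.randn(4, 16))    # shape [4, 10]
alpha = evidence + 1.                  # Dirichlet parameters, alpha_k >= 1
```

One practical difference between the two choices: ReLU has zero gradient for negative pre-activations, so an output clamped at zero evidence receives no gradient, whereas Softplus keeps a small gradient everywhere.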

The paper proposes a few loss functions to train the model, which we have implemented below.

Here is the training code experiment.py to train a model on the MNIST dataset.


import torch

from labml import tracker
from labml_helpers.module import Module

Type II Maximum Likelihood Loss

The distribution $D(\mathbf{p} \vert \boldsymbol{\alpha})$ is a prior on the likelihood $\text{Multi}(\mathbf{y} \vert \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over class probabilities $\mathbf{p}$.

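If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is

$$\begin{aligned}
\mathcal{L}(\Theta) &= -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \Bigg) \\
&= \sum_{k=1}^K y_k \big( \log S - \log \alpha_k \big)
\end{aligned}$$

where $B(\cdot)$ is the multivariate beta function.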

class MaximumLikelihoodLoss(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

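$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \big( \log S - \log \alpha_k \big)$

        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

Mean loss over the batch

        return loss.mean()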
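The closed form can be sanity-checked by Monte Carlo. With a one-hot target for class $c$, the integral in the loss is $\mathbb{E}[p_c] = \alpha_c / S$, so sampling from $D(\mathbf{p} \vert \boldsymbol{\alpha})$ and averaging $p_c$ should reproduce it. The sketch re-implements the forward pass as a plain function `type2_ml_loss` (a hypothetical name) so it runs without the labml packages; the evidence values are illustrative.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)


def type2_ml_loss(evidence: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone copy of MaximumLikelihoodLoss.forward."""
    alpha = evidence + 1.                    # alpha_k = e_k + 1
    strength = alpha.sum(dim=-1)             # S = sum_k alpha_k
    loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)
    return loss.mean()


evidence = torch.tensor([[10., 1., 1.]])    # illustrative evidence, K = 3
target = torch.tensor([[1., 0., 0.]])       # one-hot target, class c = 0

closed_form = type2_ml_loss(evidence, target)

# With a one-hot target the marginal likelihood is E[p_c] under D(p | alpha),
# so -log of the sample mean of p_c estimates the same loss
samples = Dirichlet(evidence[0] + 1.).sample((200_000,))
monte_carlo = -samples[:, 0].mean().log()

print(closed_form.item(), monte_carlo.item())   # both close to log(15 / 11) ≈ 0.310
```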

Bayes Risk with Cross Entropy Loss

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of making an incorrect estimate and averages it over all possible outcomes, weighted by their probability under the distribution.

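Here the cost function is the cross-entropy loss; for one-hot coded $\mathbf{y}$,

$$\sum_{k=1}^K -y_k \log p_k$$

We integrate this cost over all $\mathbf{p}$:

$$\begin{aligned}
\mathcal{L}(\Theta) &= \int \Big[ \sum_{k=1}^K -y_k \log p_k \Big] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \\
&= \sum_{k=1}^K y_k \big( \psi(S) - \psi(\alpha_k) \big)
\end{aligned}$$

where $\psi(\cdot)$ is the digamma function.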

class CrossEntropyBayesRisk(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \big( \psi(S) - \psi(\alpha_k) \big)$

        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

Mean loss over the batch

        return loss.mean()
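A similar Monte Carlo sketch checks the digamma form: averaging $-\log p_c$ over samples $\mathbf{p} \sim D(\mathbf{p} \vert \boldsymbol{\alpha})$ should match $\psi(S) - \psi(\alpha_c)$. The Dirichlet parameters below are made up for illustration.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

alpha = torch.tensor([11., 2., 2.])   # illustrative Dirichlet parameters, S = 15
target_class = 0                      # one-hot y with y_0 = 1

# Closed form from the derivation: psi(S) - psi(alpha_c)
closed_form = torch.digamma(alpha.sum()) - torch.digamma(alpha[target_class])

# Monte Carlo estimate of the Bayes risk E[-log p_c] under D(p | alpha)
samples = Dirichlet(alpha).sample((200_000,))
monte_carlo = -samples[:, target_class].log().mean()

print(closed_form.item(), monte_carlo.item())
```

For these parameters the risk is $\psi(15) - \psi(11) = \frac{1}{11} + \frac{1}{12} + \frac{1}{13} + \frac{1}{14} \approx 0.3226$, slightly above the type II maximum likelihood loss $\log(15/11) \approx 0.3102$ for the same $\boldsymbol{\alpha}$. By Jensen's inequality, $\mathbb{E}[-\log p_c] \ge -\log \mathbb{E}[p_c]$, so this ordering holds in general.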

Bayes Risk with Squared Error Loss

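Here the cost function is the squared error,

$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

We integrate this cost over all $\mathbf{p}$:

$$\begin{aligned}
\mathcal{L}(\Theta) &= \int \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \\
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 - 2 y_k p_k + p_k^2 \Big] \\
&= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
\end{aligned}$$

Here $\mathbb{E}[p_k] = \hat{p}_k = \frac{\alpha_k}{S}$ is the expected probability when sampled from the Dirichlet distribution, and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where

$$\text{Var}(p_k) = \frac{\alpha_k (S - \alpha_k)}{S^2 (S + 1)} = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$$

is the variance.

This gives

$$\begin{aligned}
\mathcal{L}(\Theta) &= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( \big( y_k - \mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( (y_k - \hat{p}_k)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \Big)
\end{aligned}$$

The first part of the equation, $(y_k - \hat{p}_k)^2$, is the error term and the second part is the variance.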

class SquaredErrorBayesRisk(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

$\hat{p}_k = \frac{\alpha_k}{S}$

        p = alpha / strength[:, None]

Error $(y_k - \hat{p}_k)^2$

        err = (target - p) ** 2

Variance $\text{Var}(p_k) = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$

        var = p * (1 - p) / (strength[:, None] + 1)

Sum of them

        loss = (err + var).sum(dim=-1)

Mean loss over the batch

        return loss.mean()
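The error-plus-variance decomposition can be verified the same way: the Monte Carlo average of $\Vert \mathbf{y} - \mathbf{p} \Vert_2^2$ over Dirichlet samples should match the closed form, and the empirical per-class variance of the samples should match $\hat{p}_k (1 - \hat{p}_k) / (S + 1)$. The parameters are illustrative.

```python
import torch
from torch.distributions import Dirichlet

torch.manual_seed(0)

alpha = torch.tensor([11., 2., 2.])   # illustrative Dirichlet parameters
y = torch.tensor([1., 0., 0.])        # one-hot target
strength = alpha.sum()                # S
p_hat = alpha / strength              # expected probabilities

# Closed form: sum_k (y_k - p_hat_k)^2 + p_hat_k (1 - p_hat_k) / (S + 1)
err = (y - p_hat) ** 2
var = p_hat * (1 - p_hat) / (strength + 1)
closed_form = (err + var).sum()

# Monte Carlo estimate of E[ ||y - p||^2 ] with p ~ D(p | alpha)
samples = Dirichlet(alpha).sample((200_000,))
monte_carlo = ((y - samples) ** 2).sum(dim=-1).mean()

print(closed_form.item(), monte_carlo.item())
# Per-class variance: closed form vs. empirical variance of the samples
print(var, samples.var(dim=0))
```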

KL Divergence Regularization Loss

This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

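First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$, the Dirichlet parameters after removing the evidence for the correct class. The regularizer is the KL divergence between $D(\mathbf{p} \vert \tilde{\boldsymbol{\alpha}})$ and the uniform Dirichlet $D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle)$:

$$\begin{aligned}
&KL \Big[ D(\mathbf{p} \vert \tilde{\boldsymbol{\alpha}}) \Big\Vert D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] \\
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
\end{aligned}$$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.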

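class KLDivergenceLoss(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

$\alpha_k = e_k + 1$

        alpha = evidence + 1.

Number of classes

        n_classes = evidence.shape[-1]

Remove non-misleading evidence $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$

        alpha_tilde = target + (1 - target) * alpha

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

        strength_tilde = alpha_tilde.sum(dim=-1)

The first term

$$\begin{aligned}
&\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) \\
&= \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)
\end{aligned}$$

        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

The second term $\sum_{k=1}^K (\tilde{\alpha}_k - 1) \big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \big]$

        second = (
                (alpha_tilde - 1) *
                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
        ).sum(dim=-1)

Sum of the terms

        loss = first + second

Mean loss over the batch

        return loss.mean()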
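PyTorch registers a closed-form KL divergence between two Dirichlet distributions, so `torch.distributions.kl_divergence` can serve as an independent reference for the two terms above. The sketch re-implements the forward pass as a standalone function `kl_regularizer` (a hypothetical name, with no labml dependency) and compares both on made-up evidence.

```python
import torch
from torch.distributions import Dirichlet, kl_divergence


def kl_regularizer(evidence: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone copy of KLDivergenceLoss.forward."""
    alpha = evidence + 1.
    n_classes = evidence.shape[-1]
    alpha_tilde = target + (1 - target) * alpha          # drop correct-class evidence
    strength_tilde = alpha_tilde.sum(dim=-1)
    first = (torch.lgamma(strength_tilde)
             - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
             - torch.lgamma(alpha_tilde).sum(dim=-1))
    second = ((alpha_tilde - 1) *
              (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])).sum(dim=-1)
    return (first + second).mean()


evidence = torch.tensor([[10., 1., 1.],    # mostly correct evidence
                         [0., 8., 0.]])    # confident but wrong (misleading evidence)
target = torch.tensor([[1., 0., 0.],
                       [1., 0., 0.]])

alpha_tilde = target + (1 - target) * (evidence + 1.)
print(alpha_tilde)   # the correct class is reset to alpha = 1

# Reference: PyTorch's Dirichlet-to-Dirichlet KL against the uniform D(p | <1, ..., 1>)
per_sample = kl_divergence(Dirichlet(alpha_tilde), Dirichlet(torch.ones_like(alpha_tilde)))
print(per_sample)    # the misleading second sample is penalized more
print(kl_regularizer(evidence, target), per_sample.mean())
```

A sample whose only evidence is for the correct class ends up with $\tilde{\boldsymbol{\alpha}} = \langle 1, \dots, 1 \rangle$ and therefore zero penalty; only misleading evidence is penalized.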

Track statistics

This module computes statistics and tracks them with the labml tracker.

class TrackStatistics(Module):
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

Number of classes

        n_classes = evidence.shape[-1]

Predictions that match the target (greedy sampling based on the highest probability)

        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

Track accuracy

        tracker.add('accuracy.', match.sum() / match.shape[0])

$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

$\hat{p}_k = \frac{\alpha_k}{S}$

        expected_probability = alpha / strength[:, None]

Expected probability of the selected (greedy highest probability) class

        expected_probability, _ = expected_probability.max(dim=-1)

Uncertainty mass $u = \frac{K}{S}$

        uncertainty_mass = n_classes / strength

Track $u$ for correct predictions

        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

Track $u$ for incorrect predictions

        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

Track $\hat{p}_k$ for correct predictions

        tracker.add('prob.succ.', expected_probability.masked_select(match))

Track $\hat{p}_k$ for incorrect predictions

        tracker.add('prob.fail.', expected_probability.masked_select(~match))
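For use outside a labml experiment, the same statistics can be returned instead of logged. This is a sketch under that assumption: the function name `evidential_statistics` and the dict keys (mirroring the tracker names) are hypothetical, and boolean indexing replaces `masked_select`, which selects the same elements for 1-D tensors.

```python
import torch


def evidential_statistics(evidence: torch.Tensor, target: torch.Tensor) -> dict:
    """Hypothetical variant of TrackStatistics.forward that returns tensors."""
    n_classes = evidence.shape[-1]
    # Greedy prediction: class with the highest evidence
    match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))
    alpha = evidence + 1.
    strength = alpha.sum(dim=-1)
    # Expected probability of the predicted class, max_k alpha_k / S
    expected_probability, _ = (alpha / strength[:, None]).max(dim=-1)
    uncertainty_mass = n_classes / strength               # u = K / S
    return {
        'accuracy': match.float().mean(),
        'u.succ': uncertainty_mass[match],
        'u.fail': uncertainty_mass[~match],
        'prob.succ': expected_probability[match],
        'prob.fail': expected_probability[~match],
    }


evidence = torch.tensor([[10., 1., 1.],    # correct and confident
                         [0., 8., 0.],     # wrong and confident
                         [0., 0., 1.]])    # wrong with little evidence
target = torch.tensor([[1., 0., 0.],
                       [1., 0., 0.],
                       [0., 1., 0.]])
stats = evidential_statistics(evidence, target)
print(stats)
```

In this toy batch the correct prediction has the lowest uncertainty mass ($u = 0.2$), but the confidently wrong second sample still has a fairly low $u = 3/11$, which is exactly the kind of case the KL regularizer is meant to discourage.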