Evidential Deep Learning to Quantify Classification Uncertainty

This is a PyTorch implementation of the paper Evidential Deep Learning to Quantify Classification Uncertainty.

Dempster-Shafer Theory of Evidence assigns belief masses to sets of classes (unlike assigning a probability to a single class). The masses of all subsets sum to $1$. Individual class probabilities (plausibilities) can be derived from these masses.

Assigning a mass to the set of all classes means the sample can belong to any one of the classes; i.e. it is a way of saying "I don't know".

If there are $K$ classes, we assign masses $b_k \ge 0$ to each of the classes and an overall uncertainty mass $u \ge 0$ to all classes.

$$u + \sum_{k=1}^K b_k = 1$$

Belief masses $b_k$ and $u$ can be computed from evidence $e_k \ge 0$ as $b_k = \frac{e_k}{S}$ and $u = \frac{K}{S}$, where $S = \sum_{k=1}^K (e_k + 1)$. The paper uses the term evidence for a measure of the amount of support collected from data in favor of a sample being classified into a certain class.
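As a quick illustration of these formulas, here is a minimal pure-Python sketch (no PyTorch; the function name `belief_masses` is hypothetical) that turns an evidence vector into belief masses and the uncertainty mass, and checks that they sum to one.

```python
from typing import List, Tuple


def belief_masses(evidence: List[float]) -> Tuple[List[float], float]:
    """Return the belief masses b_k = e_k / S and the uncertainty mass u = K / S."""
    if any(e < 0 for e in evidence):
        raise ValueError("evidence must be non-negative")
    n_classes = len(evidence)
    # S = sum_k (e_k + 1)
    strength = sum(e + 1.0 for e in evidence)
    beliefs = [e / strength for e in evidence]
    uncertainty = n_classes / strength
    return beliefs, uncertainty


# No evidence at all: all of the mass goes to "I don't know"
print(belief_masses([0.0, 0.0, 0.0]))  # ([0.0, 0.0, 0.0], 1.0)

# Plenty of evidence for the first class: u drops to K / S = 3 / 12
b, u = belief_masses([8.0, 1.0, 0.0])
print(u)  # 0.25
print(abs(sum(b) + u - 1.0) < 1e-12)  # True
```

Note how $u$ only shrinks as the total evidence grows; a sample with no evidence keeps the full uncertainty mass of $1$.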

This corresponds to a Dirichlet distribution with parameters $\alpha_k = e_k + 1$, and $S = \sum_{k=1}^K \alpha_k$ is known as the Dirichlet strength. The Dirichlet distribution $D(\mathbf{p} \vert \boldsymbol{\alpha})$ is a distribution over categorical distributions; i.e. you can sample class probabilities from a Dirichlet distribution. The expected probability for class $k$ is $\hat{p}_k = \frac{\alpha_k}{S}$.
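To make sampling class probabilities from a Dirichlet distribution concrete, the sketch below draws from $D(\mathbf{p} \vert \boldsymbol{\alpha})$ using the standard trick of normalizing independent Gamma($\alpha_k$, 1) draws (Python's `random.gammavariate`), then compares the sample mean with $\hat{p}_k = \alpha_k / S$. The helper names are illustrative only.

```python
import random
from typing import List


def sample_dirichlet(alpha: List[float], rng: random.Random) -> List[float]:
    """Draw one probability vector p ~ Dir(alpha) by normalising Gamma(alpha_k, 1) draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def expected_probability(alpha: List[float]) -> List[float]:
    """Closed-form Dirichlet mean: p_hat_k = alpha_k / S."""
    strength = sum(alpha)
    return [a / strength for a in alpha]


rng = random.Random(0)
alpha = [9.0, 2.0, 1.0]  # alpha_k = e_k + 1 for evidence [8, 1, 0]
samples = [sample_dirichlet(alpha, rng) for _ in range(20_000)]
# Monte Carlo mean of the sampled class probabilities
mc_mean = [sum(s[k] for s in samples) / len(samples) for k in range(len(alpha))]
print(expected_probability(alpha))  # [0.75, 0.16666666666666666, 0.08333333333333333]
print(mc_mean)  # close to the line above
```

Each sample is itself a full categorical distribution over the classes, which is what distinguishes this model from a single softmax output.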

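We get the model to output evidences $\mathbf{e} = \boldsymbol{\alpha} - 1 = f(\mathbf{x} \vert \Theta)$ for a given input $\mathbf{x}$. We use a function such as ReLU or Softplus at the final layer to ensure $f(\mathbf{x} \vert \Theta) \ge 0$.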
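A minimal sketch of such an output layer in plain Python, assuming the network's last layer produces unconstrained scores (the `logits` values below are made up for illustration). In PyTorch the same mapping would typically be `torch.nn.functional.softplus` or `torch.relu` applied to the final layer's output.

```python
import math
from typing import Callable, List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); always >= 0."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def relu(x: float) -> float:
    return max(x, 0.0)


def evidence_from_logits(logits: List[float],
                         activation: Callable[[float], float] = softplus) -> List[float]:
    """Map unconstrained final-layer outputs to non-negative evidence e = f(x | Theta)."""
    return [activation(z) for z in logits]


logits = [3.0, -1.0, -50.0]  # hypothetical raw outputs of a network's last layer
e = evidence_from_logits(logits)
alpha = [ek + 1.0 for ek in e]  # Dirichlet parameters alpha_k = e_k + 1, all >= 1
print([round(v, 4) for v in e])  # [3.0486, 0.3133, 0.0]
print(evidence_from_logits(logits, activation=relu))  # [3.0, 0.0, 0.0]
```

ReLU gives exactly zero evidence for negative scores, while Softplus keeps a small positive amount and a non-zero gradient everywhere.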

The paper proposes a few loss functions to train the model, which we have implemented below.

Here is the training code, experiment.py, to train a model on the MNIST dataset.

import torch

from labml import tracker
from labml_helpers.module import Module

Type II Maximum Likelihood Loss

The distribution $D(\mathbf{p} \vert \boldsymbol{\alpha})$ is a prior on the likelihood $Multi(\mathbf{y} \vert \mathbf{p})$, and the negative log marginal likelihood is calculated by integrating over class probabilities $\mathbf{p}$.

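If the target probabilities (one-hot targets) are $y_k$ for a given sample, the loss is

$$\mathcal{L}(\Theta) = -\log \Bigg( \int \prod_{k=1}^K p_k^{y_k} \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \Bigg) = \sum_{k=1}^K y_k \bigg( \log S - \log \alpha_k \bigg)$$

where $B(\boldsymbol{\alpha})$ is the multivariate beta function.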
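The closed form can be sanity-checked numerically. This pure-Python sketch (hypothetical helper names) estimates the marginal likelihood $\int \prod_k p_k^{y_k} D(\mathbf{p} \vert \boldsymbol{\alpha}) d\mathbf{p}$ by Monte Carlo, sampling $\mathbf{p}$ through normalized Gamma draws, and compares its negative log with $\sum_k y_k (\log S - \log \alpha_k)$.

```python
import math
import random
from typing import List


def type2_ml_loss(alpha: List[float], target: List[float]) -> float:
    """Closed form: sum_k y_k (log S - log alpha_k)."""
    strength = sum(alpha)
    return sum(y * (math.log(strength) - math.log(a)) for a, y in zip(alpha, target))


def type2_ml_loss_mc(alpha: List[float], target: List[float],
                     n_samples: int = 50_000, seed: int = 0) -> float:
    """-log of a Monte Carlo estimate of E_{p ~ Dir(alpha)}[prod_k p_k^{y_k}]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        # One Dirichlet sample via normalised Gamma(alpha_k, 1) draws
        draws = [rng.gammavariate(a, 1.0) for a in alpha]
        s = sum(draws)
        likelihood = 1.0
        for d, y in zip(draws, target):
            likelihood *= (d / s) ** y
        total += likelihood
    return -math.log(total / n_samples)


alpha = [3.0, 1.0, 1.0]   # Dirichlet parameters (evidence [2, 0, 0])
target = [1.0, 0.0, 0.0]  # one-hot label for class 0
print(type2_ml_loss(alpha, target))     # log(5 / 3), about 0.5108
print(type2_ml_loss_mc(alpha, target))  # close to the value above
```

For a one-hot target the integral reduces to $\mathbb{E}[p_j] = \alpha_j / S$ for the true class $j$, which is why only $\log S - \log \alpha_j$ survives.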

class MaximumLikelihoodLoss(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

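$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \big( \log S - \log \alpha_k \big)$

        loss = (target * (strength.log()[:, None] - alpha.log())).sum(dim=-1)

Mean loss over the batch

        return loss.mean()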

Bayes Risk with Cross Entropy Loss

Bayes risk is the expected cost of making incorrect estimates. It takes a cost function that gives the cost of making an incorrect estimate and averages it over all possible outcomes, weighted by their probability distribution.

Here the cost function is the cross-entropy loss, for one-hot coded $\mathbf{y}$,

$$\sum_{k=1}^K -y_k \log p_k$$

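We integrate this cost over all $\mathbf{p}$

$$\mathcal{L}(\Theta) = \int \Big[ \sum_{k=1}^K -y_k \log p_k \Big] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} = \sum_{k=1}^K y_k \bigg( \psi(S) - \psi(\alpha_k) \bigg)$$

where $\psi(\cdot)$ is the digamma function.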
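The identity $\mathbb{E}[-\log p_k] = \psi(S) - \psi(\alpha_k)$ is easy to check numerically. The Python standard library has no digamma, so this sketch hand-rolls one (recurrence plus asymptotic series; an assumption of this example, not part of the implementation above, which uses `torch.digamma`) and compares the closed form against a Monte Carlo estimate. All names are hypothetical.

```python
import math
import random
from typing import List


def digamma(x: float) -> float:
    """psi(x) for x > 0: shift up with psi(x) = psi(x + 1) - 1/x, then use the asymptotic series."""
    if x <= 0:
        raise ValueError("this sketch only handles x > 0")
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    series = f * (1 / 12 - f * (1 / 120 - f * (1 / 252 - f * (1 / 240 - f / 132))))
    return result + math.log(x) - 0.5 / x - series


def ce_bayes_risk(alpha: List[float], target: List[float]) -> float:
    """Closed form: sum_k y_k (psi(S) - psi(alpha_k))."""
    strength = sum(alpha)
    return sum(y * (digamma(strength) - digamma(a)) for a, y in zip(alpha, target))


def ce_bayes_risk_mc(alpha: List[float], target: List[float],
                     n_samples: int = 50_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E_{p ~ Dir(alpha)}[sum_k -y_k log p_k]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        draws = [rng.gammavariate(a, 1.0) for a in alpha]
        s = sum(draws)
        # Cross-entropy cost for this sampled p (only the y_k > 0 terms contribute)
        total += sum(-y * math.log(d / s) for d, y in zip(draws, target) if y > 0)
    return total / n_samples


alpha = [2.0, 1.0]
target = [1.0, 0.0]
print(ce_bayes_risk(alpha, target))     # psi(3) - psi(2) = 1/2
print(ce_bayes_risk_mc(alpha, target))  # close to 0.5
```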

class CrossEntropyBayesRisk(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

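$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

Losses $\mathcal{L}(\Theta) = \sum_{k=1}^K y_k \big( \psi(S) - \psi(\alpha_k) \big)$

        loss = (target * (torch.digamma(strength)[:, None] - torch.digamma(alpha))).sum(dim=-1)

Mean loss over the batch

        return loss.mean()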

Bayes Risk with Squared Error Loss

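Here the cost function is squared error,

$$\sum_{k=1}^K (y_k - p_k)^2 = \Vert \mathbf{y} - \mathbf{p} \Vert_2^2$$

We integrate this cost over all $\mathbf{p}$

$$\begin{aligned}
\mathcal{L}(\Theta) &= \int \Big[ \sum_{k=1}^K (y_k - p_k)^2 \Big] \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K p_k^{\alpha_k - 1} d\mathbf{p} \\
&= \sum_{k=1}^K \mathbb{E} \Big[ y_k^2 - 2 y_k p_k + p_k^2 \Big] \\
&= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k^2] \Big)
\end{aligned}$$

where $\mathbb{E}[p_k] = \hat{p}_k = \frac{\alpha_k}{S}$ is the expected probability when sampled from the Dirichlet distribution, and $\mathbb{E}[p_k^2] = \mathbb{E}[p_k]^2 + \text{Var}(p_k)$, where

$$\text{Var}(p_k) = \frac{\alpha_k (S - \alpha_k)}{S^2 (S + 1)} = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$$

is the variance.

This gives

$$\begin{aligned}
\mathcal{L}(\Theta) &= \sum_{k=1}^K \Big( y_k^2 - 2 y_k \mathbb{E}[p_k] + \mathbb{E}[p_k]^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( \big( y_k - \mathbb{E}[p_k] \big)^2 + \text{Var}(p_k) \Big) \\
&= \sum_{k=1}^K \Big( (y_k - \hat{p}_k)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \Big)
\end{aligned}$$

The first part of the equation, $(y_k - \hat{p}_k)^2$, is the error term and the second part is the variance.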
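A short numerical check of this derivation, as a pure-Python sketch with hypothetical names: it evaluates the closed form (error plus variance), compares it with a Monte Carlo estimate of $\mathbb{E}\big[\sum_k (y_k - p_k)^2\big]$, and exposes the Dirichlet variance formula so the two expressions for $\text{Var}(p_k)$ can be compared.

```python
import random
from typing import List


def dirichlet_variance(alpha: List[float], k: int) -> float:
    """Var(p_k) = alpha_k (S - alpha_k) / (S^2 (S + 1))."""
    strength = sum(alpha)
    return alpha[k] * (strength - alpha[k]) / (strength ** 2 * (strength + 1))


def se_bayes_risk(alpha: List[float], target: List[float]) -> float:
    """Closed form: sum_k (y_k - p_hat_k)^2 + p_hat_k (1 - p_hat_k) / (S + 1)."""
    strength = sum(alpha)
    loss = 0.0
    for a, y in zip(alpha, target):
        p_hat = a / strength
        err = (y - p_hat) ** 2                      # error term
        var = p_hat * (1 - p_hat) / (strength + 1)  # variance term
        loss += err + var
    return loss


def se_bayes_risk_mc(alpha: List[float], target: List[float],
                     n_samples: int = 50_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E_{p ~ Dir(alpha)}[sum_k (y_k - p_k)^2]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        draws = [rng.gammavariate(a, 1.0) for a in alpha]
        s = sum(draws)
        total += sum((y - d / s) ** 2 for d, y in zip(draws, target))
    return total / n_samples


alpha = [2.0, 1.0]
target = [1.0, 0.0]
print(se_bayes_risk(alpha, target))     # about 0.3333 (error 2/9 plus variance 1/9)
print(se_bayes_risk_mc(alpha, target))  # close to the value above
```

Even when $\hat{p}_k$ matches the target well, the variance term keeps penalizing low Dirichlet strength $S$, which pushes the model to collect more evidence for confident predictions.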

class SquaredErrorBayesRisk(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

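$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

$\hat{p}_k = \frac{\alpha_k}{S}$

        p = alpha / strength[:, None]

Error $(y_k - \hat{p}_k)^2$

        err = (target - p) ** 2

Variance $\text{Var}(p_k) = \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1}$

        var = p * (1 - p) / (strength[:, None] + 1)

Sum of the error and variance terms

        loss = (err + var).sum(dim=-1)

Mean loss over the batch

        return loss.mean()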

KL Divergence Regularization Loss

This tries to shrink the total evidence to zero if the sample cannot be correctly classified.

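First we calculate $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$, the Dirichlet parameters after removing the evidence for the correct class, and then the KL divergence from the uniform Dirichlet $D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle)$:

$$\begin{aligned}
&KL \Big[ D(\mathbf{p} \vert \tilde{\boldsymbol{\alpha}}) \Big\Vert D(\mathbf{p} \vert \langle 1, \dots, 1 \rangle) \Big] \\
&= \log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) + \sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]
\end{aligned}$$

where $\Gamma(\cdot)$ is the gamma function, $\psi(\cdot)$ is the digamma function and $\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$.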
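The regularizer can be evaluated straight from this formula. The pure-Python sketch below (hypothetical names; `digamma` is a small hand-written approximation, since the standard library provides `math.lgamma` but no digamma) computes $\tilde{\boldsymbol{\alpha}}$ and the KL divergence. When all the evidence belongs to the correct class, $\tilde{\boldsymbol{\alpha}} = \langle 1, \dots, 1 \rangle$ and the term is zero; misleading evidence for a wrong class makes it positive.

```python
import math
from typing import List


def digamma(x: float) -> float:
    """psi(x) for x > 0 via recurrence and the asymptotic series."""
    if x <= 0:
        raise ValueError("this sketch only handles x > 0")
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    series = f * (1 / 12 - f * (1 / 120 - f * (1 / 252 - f * (1 / 240 - f / 132))))
    return result + math.log(x) - 0.5 / x - series


def kl_to_uniform_dirichlet(alpha: List[float], target: List[float]) -> float:
    """KL[D(p | alpha_tilde) || D(p | <1, ..., 1>)] for one sample."""
    n_classes = len(alpha)
    # Remove the evidence for the correct class: alpha_tilde_k = y_k + (1 - y_k) alpha_k
    alpha_tilde = [y + (1 - y) * a for a, y in zip(alpha, target)]
    strength_tilde = sum(alpha_tilde)
    first = (math.lgamma(strength_tilde)
             - math.lgamma(n_classes)
             - sum(math.lgamma(a) for a in alpha_tilde))
    second = sum((a - 1) * (digamma(a) - digamma(strength_tilde)) for a in alpha_tilde)
    return first + second


# Evidence only for the correct class: nothing to penalise
print(kl_to_uniform_dirichlet([10.0, 1.0, 1.0], [1.0, 0.0, 0.0]))  # 0.0
# Misleading evidence e = 2 for the wrong class: alpha_tilde = [1, 3]
print(kl_to_uniform_dirichlet([5.0, 3.0], [1.0, 0.0]))  # log(3) - 2/3, about 0.4319
```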

class KLDivergenceLoss(Module):
  • evidence is $\mathbf{e} \ge 0$ with shape [batch_size, n_classes]
  • target is $\mathbf{y}$ with shape [batch_size, n_classes]
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

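$\alpha_k = e_k + 1$

        alpha = evidence + 1.

Number of classes

        n_classes = evidence.shape[-1]

Remove the evidence for the correct class, $\tilde{\alpha}_k = y_k + (1 - y_k) \alpha_k$

        alpha_tilde = target + (1 - target) * alpha

$\tilde{S} = \sum_{k=1}^K \tilde{\alpha}_k$

        strength_tilde = alpha_tilde.sum(dim=-1)

The first term

$$\log \Bigg( \frac{\Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big)}{\Gamma(K) \prod_{k=1}^K \Gamma(\tilde{\alpha}_k)} \Bigg) = \log \Gamma \Big( \sum_{k=1}^K \tilde{\alpha}_k \Big) - \log \Gamma(K) - \sum_{k=1}^K \log \Gamma(\tilde{\alpha}_k)$$

        first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                 - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                 - (torch.lgamma(alpha_tilde)).sum(dim=-1))

The second term

$$\sum_{k=1}^K (\tilde{\alpha}_k - 1) \Big[ \psi(\tilde{\alpha}_k) - \psi(\tilde{S}) \Big]$$

        second = (
                (alpha_tilde - 1) *
                (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
        ).sum(dim=-1)

Sum of the terms

        loss = first + second

Mean loss over the batch

        return loss.mean()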

Track statistics

This module computes statistics and tracks them with the labml tracker.

class TrackStatistics(Module):
    def forward(self, evidence: torch.Tensor, target: torch.Tensor):

Number of classes

        n_classes = evidence.shape[-1]

Predictions that correctly match the target (greedy sampling based on the highest probability)

        match = evidence.argmax(dim=-1).eq(target.argmax(dim=-1))

Track accuracy

        tracker.add('accuracy.', match.sum() / match.shape[0])

$\alpha_k = e_k + 1$

        alpha = evidence + 1.

$S = \sum_{k=1}^K \alpha_k$

        strength = alpha.sum(dim=-1)

$\hat{p}_k = \frac{\alpha_k}{S}$

        expected_probability = alpha / strength[:, None]

Expected probability of the selected (greedy highest probability) class

        expected_probability, _ = expected_probability.max(dim=-1)

Uncertainty mass $u = \frac{K}{S}$

        uncertainty_mass = n_classes / strength

Track $u$ for correct predictions

        tracker.add('u.succ.', uncertainty_mass.masked_select(match))

Track $u$ for incorrect predictions

        tracker.add('u.fail.', uncertainty_mass.masked_select(~match))

Track $\hat{p}_k$ for correct predictions

        tracker.add('prob.succ.', expected_probability.masked_select(match))

Track $\hat{p}_k$ for incorrect predictions

        tracker.add('prob.fail.', expected_probability.masked_select(~match))
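For readers without labml, here is a hedged, framework-free sketch of the same statistics. It assumes class-index targets instead of one-hot tensors and returns the values instead of calling `tracker.add`; the function name and return layout are hypothetical.

```python
from typing import Dict, List, Tuple


def evidential_statistics(evidence: List[List[float]],
                          target: List[int]) -> Tuple[float, Dict[str, List[float]]]:
    """Accuracy plus u and max p_hat, split into correct ("succ") and incorrect ("fail") predictions."""
    n_classes = len(evidence[0])
    stats: Dict[str, List[float]] = {"u.succ": [], "u.fail": [], "prob.succ": [], "prob.fail": []}
    n_correct = 0
    for e, t in zip(evidence, target):
        alpha = [ek + 1.0 for ek in e]          # alpha_k = e_k + 1
        strength = sum(alpha)                   # S
        # Greedy prediction: class with the most evidence (first index on ties)
        pred = max(range(n_classes), key=lambda k: e[k])
        uncertainty = n_classes / strength      # u = K / S
        p_max = max(alpha) / strength           # expected probability of the predicted class
        outcome = "succ" if pred == t else "fail"
        n_correct += int(pred == t)
        stats["u." + outcome].append(uncertainty)
        stats["prob." + outcome].append(p_max)
    return n_correct / len(target), stats


acc, stats = evidential_statistics([[8.0, 1.0, 0.0], [1.0, 0.0, 2.0], [0.0, 5.0, 0.0]], [0, 1, 1])
print(acc)             # 0.6666666666666666
print(stats["u.fail"])  # [0.5]
```

Splitting $u$ by correct and incorrect predictions makes it easy to see whether the model reports higher uncertainty on the samples it gets wrong.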