This trains a model based on the paper Evidential Deep Learning to Quantify Classification Uncertainty on the MNIST dataset.
from typing import Any

import torch
import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk
A LeNet-style convolutional model for MNIST classification.
class Model(nn.Module):
    def __init__(self, dropout: float):
        super().__init__()
First convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
ReLU activation
        self.act1 = nn.ReLU()
Max pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
Second convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
ReLU activation
        self.act2 = nn.ReLU()
Max pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
First fully-connected layer that maps to 500 features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
ReLU activation
        self.act3 = nn.ReLU()
Final fully-connected layer that outputs the evidence for the 10 classes. The ReLU or Softplus activation is applied outside the model to make the evidence non-negative.
        self.fc2 = nn.Linear(500, 10)
Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)
x is the batch of MNIST images of shape [batch_size, 1, 28, 28]
    def __call__(self, x: torch.Tensor):
Apply the first convolution and max pooling. The result has shape [batch_size, 20, 12, 12]
        x = self.max_pool1(self.act1(self.conv1(x)))
Apply the second convolution and max pooling. The result has shape [batch_size, 50, 4, 4]
        x = self.max_pool2(self.act2(self.conv2(x)))
Flatten the tensor to shape [batch_size, 50 * 4 * 4]
        x = x.view(x.shape[0], -1)
Apply the hidden layer
        x = self.act3(self.fc1(x))
Apply dropout
        x = self.dropout(x)
Apply the final layer and return the result
        return self.fc2(x)
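For intuition, here is a minimal, illustrative sketch (the helper dirichlet_summary is hypothetical, not part of the experiment) of how raw model outputs become non-negative evidence, Dirichlet parameters, class probabilities and an uncertainty value, following the paper; in the experiment this is handled by outputs_to_evidence and TrackStatistics below.

import torch
import torch.nn.functional as F

def dirichlet_summary(outputs: torch.Tensor):
    # Non-negative evidence e_k; Softplus is one of the two options registered below (ReLU is the other)
    evidence = F.softplus(outputs)
    # Dirichlet parameters alpha_k = e_k + 1 and total strength S = sum_k alpha_k
    alpha = evidence + 1.
    strength = alpha.sum(dim=-1, keepdim=True)
    # Expected class probabilities p_k = alpha_k / S
    prob = alpha / strength
    # Uncertainty mass u = K / S, where K is the number of classes
    uncertainty = alpha.shape[-1] / strength
    return prob, uncertainty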
Experiment configurations, extending the base MNISTConfigs.
class Configs(MNISTConfigs):
KL Divergence regularization loss
    kl_div_loss = KLDivergenceLoss()
KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
Schedule points for the KL Divergence coefficient, as (fraction of training, value) pairs
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
Stats module for tracking
    stats = TrackStatistics()
Dropout probability
    dropout: float = 0.5
Module to convert the model output to non-negative evidence
    outputs_to_evidence: nn.Module
    def init(self):
Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)
This experiment does not use any state modules
        self.state_modules = []
Training or validation step
    def step(self, batch: Any, batch_idx: BatchIndex):
Set training/evaluation mode
        self.model.train(self.mode.is_train)
Move the data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
One-hot encoded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
Update the global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Get model outputs
        outputs = self.model(data)
Get the evidence
        evidence = self.outputs_to_evidence(outputs)
Calculate the loss
        loss = self.loss_func(evidence, target)
Calculate the KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)
KL Divergence loss coefficient, annealed over training
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)
Total loss
        loss = loss + annealing_coef * kl_div_loss
Track statistics
        self.stats(evidence, target)
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Take an optimizer step
            self.optimizer.step()
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
Create the model
@option(Configs.model)
def mnist_model(c: Configs):
    return Model(c.dropout).to(c.device)
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
Create a relative piecewise schedule over the total number of training samples (epochs * dataset size)
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
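To illustrate what this schedule does, here is a hedged sketch assuming RelativePiecewise linearly interpolates between (fraction of training, value) points; the hypothetical relative_piecewise function below is only illustrative, and the real implementation in labml_nn.helpers.schedule may differ in details.

def relative_piecewise(points, total_steps, step):
    # points are (fraction_of_training, value) pairs, e.g. [(0, 0.), (0.2, 0.01), (1, 1.)]
    x = min(step / total_steps, 1.)
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if x <= x1:
            # Linear interpolation between the two surrounding points
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    # Past the last point, stay at the final value
    return points[-1][1]

Under this assumption, the default schedule above keeps the coefficient near zero early on, reaches 0.01 at 20% of training and 1.0 at the end.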
Maximum likelihood loss
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
Cross entropy Bayes risk loss
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
Squared error Bayes risk loss
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
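For reference, a minimal sketch of the squared error Bayes risk (the loss selected in main below) in terms of the Dirichlet parameters, following the paper; the function below is illustrative only, and SquaredErrorBayesRisk in labml_nn.uncertainty.evidence is the authoritative implementation.

import torch

def squared_error_bayes_risk_sketch(evidence: torch.Tensor, target: torch.Tensor):
    # alpha_k = e_k + 1, S = sum_k alpha_k, p_k = alpha_k / S
    alpha = evidence + 1.
    strength = alpha.sum(dim=-1, keepdim=True)
    p = alpha / strength
    # Expected squared error between the one-hot target and p, plus the variance of p under the Dirichlet
    err = (target - p) ** 2
    var = p * (1. - p) / (strength + 1.)
    return (err + var).sum(dim=-1).mean()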
ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
def main():
Create the experiment
    experiment.create(name='evidence_mnist')
Create configurations
    conf = Configs()
Load configurations; the commented-out keys are the alternative loss functions
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
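Since the losses and evidence activations are registered as string options with calculate, switching them only requires different config keys. A usage sketch (the function name main_with_relu_evidence is hypothetical) reusing the same API as main:

def main_with_relu_evidence():
    # Same experiment, but with ReLU evidence and the cross entropy Bayes risk loss
    experiment.create(name='evidence_mnist')
    conf = Configs()
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'loss_func': 'cross_entropy_bayes_risk',
        'outputs_to_evidence': 'relu',
        'dropout': 0.5,
    })
    with experiment.start():
        conf.run()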