This trains a model on the MNIST dataset, based on the paper Evidential Deep Learning to Quantify Classification Uncertainty.
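The idea in brief: the network outputs a non-negative evidence value $e_k$ for each of the $K = 10$ classes, and the evidence parameterizes a Dirichlet distribution over class probabilities. The quantities tracked below follow from the relations in the paper:

$$\alpha_k = e_k + 1, \qquad S = \sum_{k=1}^{K} \alpha_k, \qquad \hat{p}_k = \frac{\alpha_k}{S}, \qquad u = \frac{K}{S}$$

where $\hat{p}_k$ is the expected probability of class $k$ and $u$ is the uncertainty mass.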
from typing import Any

import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_helpers.module import Module
from labml_helpers.schedule import Schedule, RelativePiecewise
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk
class Model(Module):
    def __init__(self, dropout: float):
        super().__init__()
First convolution layer
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
ReLU activation
        self.act1 = nn.ReLU()
Max-pooling
        self.max_pool1 = nn.MaxPool2d(2, 2)
Second convolution layer
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
ReLU activation
        self.act2 = nn.ReLU()
Max-pooling
        self.max_pool2 = nn.MaxPool2d(2, 2)
First fully-connected layer that maps to features
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
ReLU activation
        self.act3 = nn.ReLU()
Final fully-connected layer that outputs evidence for the classes. The ReLU or Softplus activation is applied to this output outside the model to get the non-negative evidence
        self.fc2 = nn.Linear(500, 10)
Dropout for the hidden layer
        self.dropout = nn.Dropout(p=dropout)

x is the batch of MNIST images of shape [batch_size, 1, 28, 28]
    def __call__(self, x: torch.Tensor):
Apply first convolution and max pooling. The result has shape [batch_size, 20, 12, 12]
        x = self.max_pool1(self.act1(self.conv1(x)))
Apply second convolution and max pooling. The result has shape [batch_size, 50, 4, 4]
        x = self.max_pool2(self.act2(self.conv2(x)))
Flatten the tensor to shape [batch_size, 50 * 4 * 4]
        x = x.view(x.shape[0], -1)
Apply hidden layer
        x = self.act3(self.fc1(x))
Apply dropout
        x = self.dropout(x)
Apply final layer and return
        return self.fc2(x)
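A minimal shape-check sketch (not part of the experiment code; names here are illustrative): the model maps a batch of MNIST images to one raw output per class, and the evidence activation is applied outside the model.

import torch

model = Model(dropout=0.5)
x = torch.randn(32, 1, 28, 28)  # a dummy batch of MNIST-sized images
out = model(x)                  # raw per-class outputs
assert out.shape == (32, 10)
evidence = nn.Softplus()(out)   # non-negative evidence, as configured later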
class Configs(MNISTConfigs):
KL Divergence regularization loss
    kl_div_loss = KLDivergenceLoss()
KL Divergence regularization coefficient schedule
    kl_div_coef: Schedule
Schedule points for the KL Divergence regularization coefficient, as (fraction of training, coefficient) pairs
    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
Stats module for tracking
    stats = TrackStatistics()
Dropout
    dropout: float = 0.5
Module to convert the model output to non-negative evidence
    outputs_to_evidence: Module
    def init(self):
Set tracker configurations
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)
No state modules to maintain
        self.state_modules = []
    def step(self, batch: Any, batch_idx: BatchIndex):
Training/Evaluation mode
        self.model.train(self.mode.is_train)
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
One-hot encoded targets
        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Get model outputs
        outputs = self.model(data)
Get evidence
        evidence = self.outputs_to_evidence(outputs)
Calculate loss
        loss = self.loss_func(evidence, target)
Calculate KL Divergence regularization loss
        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)
KL Divergence loss coefficient
        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)
Total loss
        loss = loss + annealing_coef * kl_div_loss
Track statistics
        self.stats(evidence, target)
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Take optimizer step
            self.optimizer.step()
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
@option(Configs.model)
def mnist_model(c: Configs):
Create the model and move it to the device
    return Model(c.dropout).to(c.device)
@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
Create a relative piecewise schedule
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
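To make the schedule concrete: with the points defined above, the coefficient is 0 at the start of training, 0.01 after 20% of the total training steps, and 1.0 at the end. A minimal stand-in, assuming linear interpolation between the points (which is how the labml_helpers piecewise schedule behaves), for illustration only:

def kl_coef_sketch(step: int, total_steps: int) -> float:
    # Illustrative re-implementation of the default schedule; not used by the experiment
    points = [(0.0, 0.0), (0.2, 0.01), (1.0, 1.0)]
    t = step / total_steps
    for (x0, y0), (x1, y1) in zip(points[:-1], points[1:]):
        if x0 <= t <= x1:
            # Linearly interpolate between the surrounding points
            return y0 + (y1 - y0) * (t - x0) / (x1 - x0)
    return points[-1][1]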
Maximum Likelihood Loss
calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())
Cross Entropy Bayes Risk
calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())
Squared Error Bayes Risk
calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())
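For reference, the squared-error Bayes risk selected in main() below is the expectation of the squared error between the one-hot target $y$ and class probabilities drawn from the Dirichlet distribution; in the paper this evaluates, per sample, to

$$\mathcal{L} = \sum_{k=1}^{K} \Big[ (y_k - \hat{p}_k)^2 + \frac{\hat{p}_k (1 - \hat{p}_k)}{S + 1} \Big], \qquad \hat{p}_k = \frac{\alpha_k}{S}$$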
ReLU to calculate evidence
calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())
Softplus to calculate evidence
calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
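A hypothetical inference-time sketch (not part of the experiment code), assuming a trained model and a batch images of shape [batch_size, 1, 28, 28]; it turns the evidence into expected class probabilities and an uncertainty mass:

model.eval()
with torch.no_grad():
    evidence = nn.Softplus()(model(images))     # non-negative evidence e_k
    alpha = evidence + 1.                       # Dirichlet parameters alpha_k = e_k + 1
    strength = alpha.sum(dim=-1, keepdim=True)  # S = sum of alpha_k
    prob = alpha / strength                     # expected probabilities alpha_k / S
    uncertainty = 10. / strength.squeeze(-1)    # u = K / S with K = 10 classes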
def main():
Create experiment
    experiment.create(name='evidence_mnist')
Create configurations
    conf = Configs()
Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,
Pick a loss function; the alternatives are left commented out
        # 'loss_func': 'max_likelihood_loss',
        # 'loss_func': 'cross_entropy_bayes_risk',
        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })
Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()