from typing import Any

import torch
import torch.nn as nn
import torch.utils.data

from labml import tracker, experiment
from labml.configs import option, calculate
from labml_nn.helpers.schedule import Schedule, RelativePiecewise
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.mnist import MNISTConfigs
from labml_nn.uncertainty.evidence import KLDivergenceLoss, TrackStatistics, MaximumLikelihoodLoss, \
    CrossEntropyBayesRisk, SquaredErrorBayesRisk

LeNet-based model for MNIST classification

class Model(nn.Module):
    def __init__(self, dropout: float):
        super().__init__()

First convolution layer

        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)

ReLU activation

        self.act1 = nn.ReLU()

max-pooling

        self.max_pool1 = nn.MaxPool2d(2, 2)

Second convolution layer

        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)

ReLU activation

        self.act2 = nn.ReLU()

max-pooling

        self.max_pool2 = nn.MaxPool2d(2, 2)

First fully-connected layer that maps to features

        self.fc1 = nn.Linear(50 * 4 * 4, 500)

ReLU activation

        self.act3 = nn.ReLU()

Final fully connected layer to output evidence for classes. The ReLU or Softplus activation is applied to this outside the model to get the non-negative evidence

        self.fc2 = nn.Linear(500, 10)

Dropout for the hidden layer

        self.dropout = nn.Dropout(p=dropout)

  • x is the batch of MNIST images of shape [batch_size, 1, 28, 28]

    def __call__(self, x: torch.Tensor):

Apply first convolution and max pooling. The result has shape [batch_size, 20, 12, 12]

        x = self.max_pool1(self.act1(self.conv1(x)))

Apply second convolution and max pooling. The result has shape [batch_size, 50, 4, 4]

        x = self.max_pool2(self.act2(self.conv2(x)))

Flatten the tensor to shape [batch_size, 50 * 4 * 4]

        x = x.view(x.shape[0], -1)

Apply hidden layer

        x = self.act3(self.fc1(x))

Apply dropout

        x = self.dropout(x)

Apply final layer and return

        return self.fc2(x)
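As a quick sanity check of the shape arithmetic in the comments above (28 → 24 → 12 → 8 → 4, so the flattened size is 50 * 4 * 4 = 800), here is a minimal sketch; the batch size of 32 is arbitrary and torch is already imported at the top of the file.

model = Model(dropout=0.5)
x = torch.randn(32, 1, 28, 28)  # dummy batch of MNIST-sized images
out = model(x)
assert out.shape == (32, 10)  # ten raw outputs per image; evidence is derived from these outside the model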

Configurations

We use the MNISTConfigs configurations.

class Configs(MNISTConfigs):

KL Divergence regularization loss

    kl_div_loss = KLDivergenceLoss()

KL Divergence regularization coefficient schedule

    kl_div_coef: Schedule

KL Divergence regularization coefficient schedule, given as (fraction of training, coefficient) pairs

    kl_div_coef_schedule = [(0, 0.), (0.2, 0.01), (1, 1.)]
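The coefficient starts at 0, reaches 0.01 after 20% of training, and anneals to 1.0 by the end. A minimal sketch of how such a relative schedule could be evaluated, assuming linear interpolation between breakpoints (this is only an illustrative helper, not the actual RelativePiecewise implementation):

def relative_coef(schedule, fraction):
    # Linearly interpolate the coefficient at `fraction` (0 to 1) of total training
    for (x0, y0), (x1, y1) in zip(schedule[:-1], schedule[1:]):
        if fraction <= x1:
            return y0 + (y1 - y0) * (fraction - x0) / (x1 - x0)
    return schedule[-1][1]

assert abs(relative_coef([(0, 0.), (0.2, 0.01), (1, 1.)], 0.1) - 0.005) < 1e-9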

Stats module for tracking

    stats = TrackStatistics()

Dropout

    dropout: float = 0.5

Module to convert the model output to non-negative evidence

    outputs_to_evidence: nn.Module

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
        tracker.set_histogram('u.*', True)
        tracker.set_histogram('prob.*', False)
        tracker.set_scalar('annealing_coef.*', False)
        tracker.set_scalar('kl_div_loss.*', False)

        self.state_modules = []

Training or validation step

    def step(self, batch: Any, batch_idx: BatchIndex):

Training/Evaluation mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

One-hot coded targets

        eye = torch.eye(10).to(torch.float).to(self.device)
        target = eye[target]
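Indexing the identity matrix with the integer labels is just one-hot encoding; a small illustrative equivalence (torch.nn.functional is not imported in this file, so the import below is only for the sketch):

import torch.nn.functional as F

labels = torch.tensor([3, 0, 7])
assert torch.equal(torch.eye(10)[labels], F.one_hot(labels, num_classes=10).to(torch.float))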

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Get model outputs

        outputs = self.model(data)

Get the evidence

        evidence = self.outputs_to_evidence(outputs)

Calculate loss

        loss = self.loss_func(evidence, target)

Calculate KL Divergence regularization loss

        kl_div_loss = self.kl_div_loss(evidence, target)
        tracker.add("loss.", loss)
        tracker.add("kl_div_loss.", kl_div_loss)

KL Divergence loss coefficient

        annealing_coef = min(1., self.kl_div_coef(tracker.get_global_step()))
        tracker.add("annealing_coef.", annealing_coef)

Total loss

        loss = loss + annealing_coef * kl_div_loss

Track statistics

        self.stats(evidence, target)

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Take optimizer step

            self.optimizer.step()

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Create model

@option(Configs.model)
def mnist_model(c: Configs):
    return Model(c.dropout).to(c.device)

KL Divergence Loss Coefficient Schedule

@option(Configs.kl_div_coef)
def kl_div_coef(c: Configs):
    return RelativePiecewise(c.kl_div_coef_schedule, c.epochs * len(c.train_dataset))
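Note that the global step counts samples rather than batches (tracker.add_global_step(len(data)) in the step above), so the schedule spans epochs * len(train_dataset) samples. A rough worked example with assumed numbers (the standard 60,000-image MNIST training set and 10 epochs; neither value is fixed by this file):

epochs, train_size = 10, 60_000  # illustrative values only
total = epochs * train_size  # global step counts samples, so 600,000 in total
breakpoints = [(int(x * total), coef) for x, coef in [(0, 0.), (0.2, 0.01), (1, 1.)]]
print(breakpoints)  # [(0, 0.0), (120000, 0.01), (600000, 1.0)]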
Maximum Likelihood Loss

calculate(Configs.loss_func, 'max_likelihood_loss', lambda: MaximumLikelihoodLoss())

Cross Entropy Bayes Risk

calculate(Configs.loss_func, 'cross_entropy_bayes_risk', lambda: CrossEntropyBayesRisk())

Squared Error Bayes Risk

calculate(Configs.loss_func, 'squared_error_bayes_risk', lambda: SquaredErrorBayesRisk())

ReLU to calculate evidence

calculate(Configs.outputs_to_evidence, 'relu', lambda: nn.ReLU())

Softplus to calculate evidence

calculate(Configs.outputs_to_evidence, 'softplus', lambda: nn.Softplus())
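Both choices map the raw model outputs to non-negative evidence: ReLU clips negative outputs to zero, while Softplus gives a smooth, strictly positive value. As background from the evidential deep learning formulation used by the losses above (alpha = evidence + 1, and the uncertainty is the number of classes divided by the sum of alpha), a small illustrative sketch with a three-class toy output:

raw = torch.tensor([[-1.0, 0.0, 2.0]])  # raw outputs for a toy 3-class example
evidence = nn.Softplus()(raw)  # non-negative evidence; nn.ReLU() would give [0., 0., 2.]
alpha = evidence + 1.  # Dirichlet concentration parameters
uncertainty = alpha.shape[-1] / alpha.sum(-1)  # u = K / S; values near 1 mean "don't know"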
def main():

Create experiment

    experiment.create(name='evidence_mnist')

Create configurations

    conf = Configs()

Load configurations

    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.001,
        'optimizer.weight_decay': 0.005,

Other options for the loss function are 'max_likelihood_loss' and 'cross_entropy_bayes_risk'

        'loss_func': 'squared_error_bayes_risk',

        'outputs_to_evidence': 'softplus',

        'dropout': 0.5,
    })

Start the experiment and run the training loop

    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()