MNIST Experiment

from typing import Any

import torch
import torch.nn as nn
import torch.utils.data
from labml_helpers.module import Module

from labml import tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.optimizers.configs import OptimizerConfigs

Trainer configurations. This combines the MNIST dataset configurations with the generic training and validation loop configurations.

class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Classification model

    model: Module

Number of epochs to train for

    epochs: int = 10

Number of times to switch between training and validation within an epoch

    inner_iterations = 10

Accuracy function

    accuracy = Accuracy()

Loss function

    loss_func = nn.CrossEntropyLoss()
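
As an aside, nn.CrossEntropyLoss combines log-softmax and negative log-likelihood: it expects raw (unnormalized) logits of shape [batch, classes] and integer class targets. A standalone illustration with made-up tensors, not part of this class:

logits = torch.randn(4, 10)           # unnormalized outputs for 4 samples, 10 classes
targets = torch.tensor([3, 1, 0, 7])  # ground-truth class indices
loss = nn.CrossEntropyLoss()(logits, targets)  # scalar loss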

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')
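
Conceptually, hook_model_outputs registers forward hooks on the model's submodules so their activations can be tracked. A rough sketch of the idea in plain PyTorch (illustrative only; this is not the labml_helpers implementation):

def log_outputs(model: nn.Module, prefix: str = 'model'):
    def make_hook(name: str):
        # Forward hooks receive (module, inputs, output) after each forward pass
        def hook(module, inputs, output):
            tracker.add(f"{prefix}.{name}", output)
        return hook

    for name, module in model.named_children():
        module.register_forward_hook(make_hook(name))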

Add accuracy as a state module. The name is a bit misleading: state modules were designed to carry state (such as RNN hidden state) across training and validation, but here the mechanism simply keeps the accuracy metric's statistics separate for training and validation.

        self.state_modules = [self.accuracy]

Training or validation step

    def step(self, batch: Any, batch_idx: BatchIndex):

Set the model to training/evaluation mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get model outputs

            output = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()
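
The Accuracy module accumulates statistics across batches; for a single batch of logits the computation reduces to taking the arg-max class and comparing it with the targets. A minimal equivalent, for illustration:

predictions = output.argmax(dim=-1)                    # predicted class per sample
batch_accuracy = (predictions == target).float().mean()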

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
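
MNISTConfigs leaves the model unspecified; each experiment registers its own. A minimal sketch of what such a registration could look like, assuming a hypothetical Net classifier (both the architecture and the option below are illustrative, not part of this file):

class Net(Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),             # [batch, 1, 28, 28] -> [batch, 784]
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),       # 10 digit classes
        )

    def forward(self, x: torch.Tensor):
        return self.layers(x)


@option(MNISTConfigs.model)
def _model(c: MNISTConfigs):
    return Net().to(c.device)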

Default optimizer configurations. This creates a configurable OptimizerConfigs over the model parameters, defaulting to Adam.

@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf
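
Putting it together, a typical experiment creates the configurations, overrides options, and runs the training loop. A minimal sketch, assuming the standard labml experiment API and a registered model option such as the one above; the experiment name and learning rate are illustrative:

from labml import experiment


def main():
    conf = MNISTConfigs()
    experiment.create(name='mnist_example')
    # 'optimizer.learning_rate' is an OptimizerConfigs option; the value is arbitrary
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()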