import torch
import torch.nn as nn
import torch.utils.data

from labml import tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.optimizers.configs import OptimizerConfigs
Trainer configurations for MNIST experiments. This combines the MNIST dataset configurations with the generic training and validation loop configurations.
class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
Optimizer
    optimizer: torch.optim.Adam
Training device
    device: torch.device = DeviceConfigs()
Classification model; this is supplied by the experiment that uses these configurations (see the usage sketch at the end of this section)
    model: Module
Number of epochs to train for
    epochs: int = 10
Number of times to switch between training and validation within an epoch; with the default of 10, each epoch alternates ten times between a chunk of training and a chunk of validation
    inner_iterations = 10
Accuracy function
    accuracy = Accuracy()
Loss function
    loss_func = nn.CrossEntropyLoss()
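nn.CrossEntropyLoss expects raw logits (it applies log-softmax internally) and integer class labels. A quick illustrative shape check, with example sizes that are not part of this file:

import torch
import torch.nn as nn

loss_func = nn.CrossEntropyLoss()
logits = torch.randn(32, 10)           # [batch_size, n_classes] raw scores, no softmax applied
targets = torch.randint(0, 10, (32,))  # [batch_size] integer class labels
loss = loss_func(logits, targets)      # scalar loss, averaged over the batch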
Initialization; the trainer calls this once before the training and validation loops start
    def init(self):
Set tracker configurations; with the trailing-dot indicator names used in step below, the scalars matched here end up as loss.train, loss.valid, accuracy.train and accuracy.valid
        tracker.set_scalar("loss.*", True)
        tracker.set_scalar("accuracy.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
Add accuracy as a state module. The name is probably confusing, since state modules are meant to store state between training and validation for RNNs. Here it keeps the accuracy metric's statistics separate for training and validation.
        self.state_modules = [self.accuracy]
Training or validation step; this is called for every batch in both the training and validation loops
    def step(self, batch: any, batch_idx: BatchIndex):
Set the model to training/evaluation mode
        self.model.train(self.mode.is_train)
Move the data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update the global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Capture model outputs on the last batch, where the hook added in init logs the activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get model outputs
            output = self.model(data)
Calculate and log the loss; the trailing dot lets the tracker append the current mode, producing loss.train and loss.valid
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
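For intuition, the accuracy computation above boils down to an argmax comparison between logits and labels. A minimal hand-rolled equivalent for a single batch (a sketch for illustration only; the labml_helpers Accuracy metric additionally accumulates counts across batches, separately per mode):

import torch

def batch_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    # output: [batch_size, n_classes] logits; target: [batch_size] class indices
    pred = output.argmax(dim=-1)  # predicted class per sample
    return pred.eq(target).sum().item() / target.shape[0]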
Default optimizer configurations; this sets up Adam through OptimizerConfigs, so options such as optimizer.learning_rate can be overridden from the experiment configurations
@option(MNISTConfigs.optimizer)
def _optimizer(c: MNISTConfigs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    opt_conf.optimizer = 'Adam'
    return opt_conf
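To use these configurations, an experiment typically subclasses MNISTConfigs, supplies the model, and runs training through labml's experiment API. A minimal sketch; the two-layer network and the experiment name below are placeholders, not part of this file:

import torch.nn as nn

from labml import experiment
from labml.configs import option


class Configs(MNISTConfigs):
    pass


@option(Configs.model)
def _model(c: Configs):
    # Placeholder classifier: flatten 28x28 images and apply two linear layers
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    return model.to(c.device)


def main():
    conf = Configs()
    experiment.create(name='mnist_example')
    # Override optimizer options exposed by OptimizerConfigs
    experiment.configs(conf, {'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1e-3})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()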