import torch
import torch.nn as nn
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.optimizers.configs import OptimizerConfigs


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
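        # With 28x28 MNIST inputs, each 5x5 convolution (no padding) reduces each
        # spatial dimension by 4 and each 2x2 max-pool halves it:
        # 28 -> 24 -> 12 -> 8 -> 4, giving 50 feature maps of 4x4 = 16 values,
        # hence the 16 * 50 input features below.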
        self.fc1 = nn.Linear(16 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.pool1(x)
        x = self.activation(self.conv2(x))
        x = self.pool2(x)
        x = self.activation(self.fc1(x.view(-1, 16 * 50)))
        return self.fc2(x)
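# A quick sanity check (a sketch, not part of the original experiment): with a
# standard MNIST batch of shape [N, 1, 28, 28] the model returns logits of
# shape [N, 10], e.g.
#
#     model = Model()
#     logits = model(torch.zeros(4, 1, 28, 28))
#     assert logits.shape == (4, 10)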


class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
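        # Tracker setup (a short note on the calls below): `set_queue` keeps a
        # moving window of the last 20 values of any 'loss.*' indicator, and
        # `set_scalar` treats 'accuracy.*' indicators as plain scalars; the
        # final `True` arguments mark them to be printed to the console.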
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        # Hook into the model to log the outputs (activations) of its modules
        hook_model_outputs(self.mode, self.model, 'model')
        # Register the accuracy metric as a state module so its state is handled by the training loop
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Get the batch
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Add global step if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model and specify whether to log the activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)

        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate the gradients
            loss.backward()

            # Take optimizer step
            self.optimizer.step()
            # Log the parameter and gradient L2 norms once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save logs
        tracker.save()
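

# `@option` registers the function below it as the calculator for a config
# item; in labml, these calculators run lazily, when the config value is
# needed (a brief note on the API as it is used here).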
@option(Configs.model)
def model(c: Configs):
    return Model().to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create a configurable optimizer.

    We can change the optimizer type and hyper-parameters using configurations.
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
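

# Because the optimizer is built through `OptimizerConfigs`, its type and
# hyper-parameters can be overridden from `experiment.configs` (as `main()`
# does below with 'optimizer.optimizer' and 'optimizer.learning_rate'); for an
# AdaBelief run one would pass a different 'optimizer.optimizer' name, assuming
# it is registered in `OptimizerConfigs`.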
def main():
    conf = Configs()
    conf.inner_iterations = 10
    experiment.create(name='mnist_ada_belief')
    experiment.configs(conf, {
        'inner_iterations': 10,
        # Specify the optimizer
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 1.5e-4,
    })
    conf.set_seed.set()
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()