import torch
import torch.nn as nn
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional layers, each followed by 2x2 max-pooling
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
        # Two fully connected layers that map the flattened features to class logits
        self.fc1 = nn.Linear(16 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.pool1(x)
        x = self.activation(self.conv2(x))
        x = self.pool2(x)
        # Flatten the pooled feature maps before the fully connected layers
        x = self.activation(self.fc1(x.view(-1, 16 * 50)))
        return self.fc2(x)
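
The 16 * 50 flattened size follows from the shapes: a 28x28 MNIST image becomes 24x24 after the first 5x5 convolution, 12x12 after pooling, 8x8 after the second convolution, and 4x4 after the second pooling, so each image yields 4 * 4 * 50 = 800 = 16 * 50 features. The sanity check below is only an illustrative sketch (the helper name and dummy batch are not part of the original experiment).

def _check_feature_shapes():
    # Illustrative only: verify the shape arithmetic above with a dummy batch of 4 images
    model = Model()
    x = torch.zeros(4, 1, 28, 28)
    h = model.pool1(model.activation(model.conv1(x)))
    assert h.shape == (4, 20, 12, 12)   # 28 -> 24 (5x5 conv) -> 12 (2x2 pool)
    h = model.pool2(model.activation(model.conv2(h)))
    assert h.shape == (4, 50, 4, 4)     # 12 -> 8 (5x5 conv) -> 4 (2x2 pool)
    assert model(x).shape == (4, 10)    # 4 * 4 * 50 = 800 features -> 10 class logits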


class Configs(MNISTConfigs, TrainValidConfigs):
    optimizer: torch.optim.Adam
    model: nn.Module
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Report losses as a moving average over the last 20 values; accuracies as plain scalars
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        # Register the accuracy metric as a state module so its statistics are kept
        # separately for training and validation
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
Move the batch data and targets to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Increment the global step (number of samples processed) if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Run the model
        output = self.model(data)
Calculate the loss
        loss = self.loss_func(output, target)
Calculate the accuracy
        self.accuracy_func(output, target)
Log the loss
        tracker.add("loss.", loss)
Optimize if we are in training mode
        if self.mode.is_train:
Calculate the gradients
            loss.backward()
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()


Create the model and move it to the configured device.
@option(Configs.model)
def model(c: Configs):
    return Model().to(c.device)


Create a configurable optimizer. We can change the optimizer type and its hyperparameters through the configurations; a sketch of such an override follows the function below.
@option(Configs.optimizer)
def _optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
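
As a sketch of that mechanism, the helper below would switch the same experiment to a different optimizer purely through configuration keys. It is illustrative only: it assumes OptimizerConfigs registers an 'AdaBelief' option (which the experiment name suggests), and the learning rate value is arbitrary.

def run_with_adabelief():
    # Hypothetical variant of main() below, not part of the original experiment
    conf = Configs()
    experiment.create(name='mnist_ada_belief')
    # Same keys as in main(), but selecting AdaBelief instead of Adam
    experiment.configs(conf, {'inner_iterations': 10,
                              'optimizer.optimizer': 'AdaBelief',  # assumed option name
                              'optimizer.learning_rate': 1.5e-4})
    experiment.add_pytorch_models(dict(model=conf.model))
    with experiment.start():
        conf.run()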


def main():
    # Create the experiment configurations
    conf = Configs()
    # Create the experiment
    experiment.create(name='mnist_ada_belief')
    # Override configurations
    experiment.configs(conf, {'inner_iterations': 10,
                              # Specify the optimizer
                              'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1.5e-4})
    # Register the model for saving and loading
    experiment.add_pytorch_models(dict(model=conf.model))
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()