import torch.nn as nn
import torch.utils.data

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.datasets.mnist import MNISTConfigs
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.seed import SeedConfigs
from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
from labml_nn.optimizers.configs import OptimizerConfigs


class Model(Module):
    """
    A simple two-layer convolutional network for MNIST digit classification.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(16 * 50, 500)
        self.fc2 = nn.Linear(500, 10)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.pool1(x)
        x = self.activation(self.conv2(x))
        x = self.pool2(x)
        x = self.activation(self.fc1(x.view(-1, 16 * 50)))
        return self.fc2(x)
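

# A quick sanity check of the shapes above (an illustrative sketch; the helper
# name is ours, not part of the original experiment): MNIST inputs are
# (N, 1, 28, 28); conv1 (5x5, stride 1, no padding) gives 24x24, pool1 halves
# it to 12x12, conv2 gives 8x8, and pool2 gives 4x4, so the flattened size is
# 50 * 4 * 4 = 16 * 50 = 800.
def _check_model_shapes():
    logits = Model()(torch.zeros(4, 1, 28, 28))
    assert logits.shape == (4, 10)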


class Configs(MNISTConfigs, TrainValidConfigs):
    """
    Experiment configurations: combines the MNIST dataset configurations
    with the training and validation loop configurations.
    """
    optimizer: torch.optim.Adam
    model: nn.Module
    set_seed = SeedConfigs()
    device: torch.device = DeviceConfigs()
    epochs: int = 10

    is_save_models = True
    inner_iterations = 10

    accuracy_func = Accuracy()
    loss_func = nn.CrossEntropyLoss()

    def init(self):
        # Print a moving average of the loss, computed over a queue of 20 values
        tracker.set_queue("loss.*", 20, True)
        tracker.set_scalar("accuracy.*", True)
        # Hook into the model to log activations
        hook_model_outputs(self.mode, self.model, 'model')
        # Register the accuracy metric so its state is reset between runs
        self.state_modules = [self.accuracy_func]

    def step(self, batch: any, batch_idx: BatchIndex):
        # Get the batch and move it to the configured device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Add to the global step (number of samples processed) if we are in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Run the model, and specify whether to log activations
        with self.mode.update(is_log_activations=batch_idx.is_last):
            output = self.model(data)

        # Calculate the loss
        loss = self.loss_func(output, target)
        # Calculate the accuracy
        self.accuracy_func(output, target)
        # Log the loss
        tracker.add("loss.", loss)

        # Optimize if we are in training mode
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take an optimizer step
            self.optimizer.step()
            # Log parameter and gradient L2 norms once per epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
                tracker.add('optimizer', (self.optimizer, {'model': self.model}))
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()


# Create the model and move it to the configured device
@option(Configs.model)
def model(c: Configs):
    return Model().to(c.device)


@option(Configs.optimizer)
def _optimizer(c: Configs):
    """
    Create a configurable optimizer. We can change the optimizer type
    and its hyper-parameters using the configurations.
    """
    opt_conf = OptimizerConfigs()
    opt_conf.parameters = c.model.parameters()
    return opt_conf
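

# Since the option above returns an OptimizerConfigs, the optimizer is selected
# by configuration keys rather than code (see the overrides in main() below).
# As a hedged sketch, switching this experiment from Adam to AdaBelief should
# only need a different override, assuming 'AdaBelief' is registered with
# OptimizerConfigs (as labml_nn.optimizers does for its annotated
# implementation):
#
#     experiment.configs(conf, {'inner_iterations': 10,
#                               'optimizer.optimizer': 'AdaBelief',
#                               'optimizer.learning_rate': 1.5e-4})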


def main():
    # Create the configurations
    conf = Configs()
    conf.inner_iterations = 10
    # Create the experiment
    experiment.create(name='mnist_ada_belief')
    # Override configurations
    experiment.configs(conf, {'inner_iterations': 10,
                              # Specify the optimizer
                              'optimizer.optimizer': 'Adam',
                              'optimizer.learning_rate': 1.5e-4})
    # Set the random seed
    conf.set_seed.set()
    # Register the model for checkpoint saving and loading
    experiment.add_pytorch_models(dict(model=conf.model))
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()