11import torch.nn as nn
12import torch.utils.data
13from labml_helpers.module import Module
14
15from labml import tracker
16from labml.configs import option
17from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
18from labml_helpers.device import DeviceConfigs
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
21from labml_nn.optimizers.configs import OptimizerConfigs
24class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):
优化器
32 optimizer: torch.optim.Adam
训练设备
34 device: torch.device = DeviceConfigs()
分类模型
37 model: Module
要训练的时代数
39 epochs: int = 10
一个纪元内在训练和验证之间切换的次数
42 inner_iterations = 10
精度函数
45 accuracy = Accuracy()
亏损函数
47 loss_func = nn.CrossEntropyLoss()
49 def init(self):
设置跟踪器配置
54 tracker.set_scalar("loss.*", True)
55 tracker.set_scalar("accuracy.*", True)
向日志模块输出添加钩子
57 hook_model_outputs(self.mode, self.model, 'model')
增加作为状态模块的精度。这个名字可能令人困惑,因为它旨在存储 RNN 的训练和验证之间的状态。这将使精度指标统计数据分开,以便进行训练和验证。
62 self.state_modules = [self.accuracy]
64 def step(self, batch: any, batch_idx: BatchIndex):
训练/评估模式
70 self.model.train(self.mode.is_train)
将数据移动到设备
73 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式下更新全局步长(处理的样本数)
76 if self.mode.is_train:
77 tracker.add_global_step(len(data))
是否捕获模型输出
80 with self.mode.update(is_log_activations=batch_idx.is_last):
获取模型输出。
82 output = self.model(data)
计算并记录损失
85 loss = self.loss_func(output, target)
86 tracker.add("loss.", loss)
计算和记录精度
89 self.accuracy(output, target)
90 self.accuracy.track()
训练模型
93 if self.mode.is_train:
计算梯度
95 loss.backward()
采取优化器步骤
97 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
99 if batch_idx.is_last:
100 tracker.add('model', self.model)
清除渐变
102 self.optimizer.zero_grad()
保存跟踪的指标
105 tracker.save()
108@option(MNISTConfigs.optimizer)
109def _optimizer(c: MNISTConfigs):
113 opt_conf = OptimizerConfigs()
114 opt_conf.parameters = c.model.parameters()
115 opt_conf.optimizer = 'Adam'
116 return opt_conf