MNIST 实验

11import torch.nn as nn
12import torch.utils.data
13from labml_helpers.module import Module
14
15from labml import tracker
16from labml.configs import option
17from labml_helpers.datasets.mnist import MNISTConfigs as MNISTDatasetConfigs
18from labml_helpers.device import DeviceConfigs
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.train_valid import TrainValidConfigs, BatchIndex, hook_model_outputs
21from labml_nn.optimizers.configs import OptimizerConfigs

训练器配置

24class MNISTConfigs(MNISTDatasetConfigs, TrainValidConfigs):

优化器

32    optimizer: torch.optim.Adam

训练设备

34    device: torch.device = DeviceConfigs()

分类模型

37    model: Module

要训练的时代数

39    epochs: int = 10

一个纪元内在训练和验证之间切换的次数

42    inner_iterations = 10

精度函数

45    accuracy = Accuracy()

亏损函数

47    loss_func = nn.CrossEntropyLoss()

初始化

49    def init(self):

设置跟踪器配置

54        tracker.set_scalar("loss.*", True)
55        tracker.set_scalar("accuracy.*", True)

向日志模块输出添加钩子

57        hook_model_outputs(self.mode, self.model, 'model')

增加作为状态模块的精度。这个名字可能令人困惑,因为它旨在存储 RNN 的训练和验证之间的状态。这将使精度指标统计数据分开,以便进行训练和验证。

62        self.state_modules = [self.accuracy]

培训或验证步骤

64    def step(self, batch: any, batch_idx: BatchIndex):

训练/评估模式

70        self.model.train(self.mode.is_train)

将数据移动到设备

73        data, target = batch[0].to(self.device), batch[1].to(self.device)

在训练模式下更新全局步长(处理的样本数)

76        if self.mode.is_train:
77            tracker.add_global_step(len(data))

是否捕获模型输出

80        with self.mode.update(is_log_activations=batch_idx.is_last):

获取模型输出。

82            output = self.model(data)

计算并记录损失

85        loss = self.loss_func(output, target)
86        tracker.add("loss.", loss)

计算和记录精度

89        self.accuracy(output, target)
90        self.accuracy.track()

训练模型

93        if self.mode.is_train:

计算梯度

95            loss.backward()

采取优化器步骤

97            self.optimizer.step()

记录每个纪元最后一批的模型参数和梯度

99            if batch_idx.is_last:
100                tracker.add('model', self.model)

清除渐变

102            self.optimizer.zero_grad()

保存跟踪的指标

105        tracker.save()

默认优化器配置

108@option(MNISTConfigs.optimizer)
109def _optimizer(c: MNISTConfigs):
113    opt_conf = OptimizerConfigs()
114    opt_conf.parameters = c.model.parameters()
115    opt_conf.optimizer = 'Adam'
116    return opt_conf