from typing import Any

import torch
from torch import nn
from torch.utils.data import DataLoader

from labml import tracker, experiment
from labml_helpers.metrics.accuracy import AccuracyDirect
from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
from labml_nn.adaptive_computation.parity import ParityDataset
from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss
Configurations with a simple training and validation loop
class Configs(SimpleTrainValidConfigs):
Number of epochs
    epochs: int = 100
Number of batches per epoch
    n_batches: int = 500
Batch size
    batch_size: int = 128
Model
    model: ParityPonderGRU
Reconstruction loss
    loss_rec: ReconstructionLoss
Regularization loss
    loss_reg: RegularizationLoss
The number of elements in the input vector. We keep this low for the demonstration; otherwise training takes a long time. Although the parity task looks simple, it is quite hard to figure out the pattern just by looking at samples.
    n_elems: int = 8
Number of units in the hidden layer (state)
    n_hidden: int = 64
Maximum number of pondering steps
    max_steps: int = 20
lambda_p for the geometric prior distribution
    lambda_p: float = 0.2
Regularization loss coefficient
    beta: float = 0.01
Gradient clipping by norm
    grad_norm_clip: float = 1.0
Training and validation data loaders
    train_loader: DataLoader
    valid_loader: DataLoader
Accuracy calculator
    accuracy = AccuracyDirect()
    def init(self):
Print the indicators to the screen
        tracker.set_scalar('loss.*', True)
        tracker.set_scalar('loss_reg.*', True)
        tracker.set_scalar('accuracy.*', True)
        tracker.set_scalar('steps.*', True)
We need to register the accuracy metric as a state module so it is computed over the whole epoch for both training and validation
        self.state_modules = [self.accuracy]
Initialize the model
        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)
Reconstruction loss
        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)
Regularization loss
        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)
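RegularizationLoss is implemented elsewhere in the ponder_net module; in the PonderNet formulation the regularizer is a KL divergence that pulls the halting distribution towards a geometric prior with parameter lambda_p. A minimal sketch of that prior, assuming the truncated geometric form p_G(n) = lambda_p * (1 - lambda_p)^(n - 1):

import torch

# Sketch only (not the RegularizationLoss implementation): the truncated geometric
# prior over pondering steps that the regularizer pulls the halting distribution towards.
lambda_p, max_steps = 0.2, 20
n = torch.arange(1, max_steps + 1)
p_g = lambda_p * (1 - lambda_p) ** (n - 1)
print(p_g.sum())        # close to 1, since max_steps is much larger than 1 / lambda_p
print((n * p_g).sum())  # expected number of steps under the prior, roughly 1 / lambda_p = 5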
Training and validation data loaders
        self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
                                       batch_size=self.batch_size)
        self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
                                       batch_size=self.batch_size)
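ParityDataset is defined in labml_nn.adaptive_computation.parity and is not reproduced here. One plausible sketch of a parity sample in this style (an assumption for illustration, not the actual ParityDataset code): a vector with a random number of non-zero entries in {-1, +1}, labelled by the parity of the number of +1 entries.

import torch

def make_parity_sample(n_elems: int = 8):
    # Hypothetical sketch, not the actual ParityDataset code: a vector with a random
    # number of non-zero entries in {-1, +1}, labelled by the parity of the +1 count.
    x = torch.zeros(n_elems)
    n_non_zero = torch.randint(1, n_elems + 1, (1,)).item()
    x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1.
    x = x[torch.randperm(n_elems)]
    y = (x == 1.).sum() % 2
    return x, y.to(torch.float)

print(make_parity_sample())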
This method is called by the trainer for each batch
    def step(self, batch: Any, batch_idx: BatchIndex):
Set the model mode (train/eval)
        self.model.train(self.mode.is_train)
Get the input and labels, and move them to the model's device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Increment the global step counter in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Run the model
        p, y_hat, p_sampled, y_hat_sampled = self.model(data)
Calculate the reconstruction loss
        loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
        tracker.add("loss.", loss_rec)
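ReconstructionLoss is also defined in the ponder_net module; conceptually it is the expectation of the per-step prediction loss under the halting distribution, i.e. the sum over n of p_n * BCE(y_hat_n, y). A hand-rolled sketch of that expectation, assuming p and y_hat have shape [steps, batch] (which matches the expected-steps computation further down) and the target is [batch]:

import torch
from torch import nn

def expected_reconstruction_loss(p: torch.Tensor, y_hat: torch.Tensor, y: torch.Tensor):
    # Sketch only: the expected binary cross-entropy under the halting distribution,
    # sum_n p[n] * BCE(y_hat[n], y), averaged over the batch.
    bce = nn.BCEWithLogitsLoss(reduction='none')
    total = p.new_tensor(0.)
    for n in range(p.shape[0]):
        total = total + (p[n] * bce(y_hat[n], y)).mean()
    return total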
Calculate the regularization loss
        loss_reg = self.loss_reg(p)
        tracker.add("loss_reg.", loss_reg)
Total loss is the reconstruction loss plus the weighted regularization loss
        loss = loss_rec + self.beta * loss_reg
Calculate the expected number of pondering steps
        steps = torch.arange(1, p.shape[0] + 1, device=p.device)
        expected_steps = (p * steps[:, None]).sum(dim=0)
        tracker.add("steps.", expected_steps)
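For intuition, a toy check of the expected-steps computation above, assuming p has shape [steps, batch] and each column is a halting distribution that sums to 1:

import torch

# A batch of two halting distributions over 3 steps: the first halts at step 1
# with certainty, the second spreads probability over steps 1-3.
p = torch.tensor([[1.0, 0.2],
                  [0.0, 0.3],
                  [0.0, 0.5]])
steps = torch.arange(1, p.shape[0] + 1)
print((p * steps[:, None]).sum(dim=0))  # tensor([1.0000, 2.3000])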
Call the accuracy metric
        self.accuracy(y_hat_sampled > 0, target)
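Since the reconstruction loss above uses BCEWithLogitsLoss, y_hat_sampled presumably holds raw logits, so thresholding at 0 is equivalent to thresholding the sigmoid output at 0.5. A quick sanity check of that equivalence:

import torch

# Thresholding logits at 0 gives the same predictions as thresholding probabilities at 0.5.
logits = torch.tensor([-1.2, 0.3, 2.0])
assert torch.equal(logits > 0, torch.sigmoid(logits) > 0.5)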
        if self.mode.is_train:
Compute gradients
            loss.backward()
Clip gradients by norm
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take an optimizer step
            self.optimizer.step()
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked indicators
        tracker.save()
Run the experiment
def main():
Create the experiment
    experiment.create(name='ponder_net')
Create the configurations
    conf = Configs()
Set the configurations (hyper-parameters)
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.0003,
    })
Start the experiment and run the training loop
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main()