13from typing import Any
14
15import torch
16from torch import nn
17from torch.utils.data import DataLoader
18
19from labml import tracker, experiment
20from labml_helpers.metrics.accuracy import AccuracyDirect
21from labml_helpers.train_valid import SimpleTrainValidConfigs, BatchIndex
22from labml_nn.adaptive_computation.parity import ParityDataset
23from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLoss
26class Configs(SimpleTrainValidConfigs):

エポック数

33    epochs: int = 100

エポックあたりのバッチ数

35    n_batches: int = 500

バッチサイズ

37    batch_size: int = 128

モデル

40    model: ParityPonderGRU

43    loss_rec: ReconstructionLoss

45    loss_reg: RegularizationLoss

入力ベクトルの要素数。デモ用に低く設定しています。そうしないと、トレーニングに時間がかかります。パリティのタスクは簡単そうに見えますが、サンプルを見てパターンを理解するのはかなり難しいです

51    n_elems: int = 8

隠れ層 (状態) 内のユニット数

53    n_hidden: int = 64

最大ステップ数

55    max_steps: int = 20

幾何分布用

58    lambda_p: float = 0.2

正則化損失係数

60    beta: float = 0.01

標準によるグラデーションクリッピング

63    grad_norm_clip: float = 1.0

トレーニングおよび検証ローダー

66    train_loader: DataLoader
67    valid_loader: DataLoader

精度計算ツール

70    accuracy = AccuracyDirect()
72    def init(self):

インジケータを画面に印刷

74        tracker.set_scalar('loss.*', True)
75        tracker.set_scalar('loss_reg.*', True)
76        tracker.set_scalar('accuracy.*', True)
77        tracker.set_scalar('steps.*', True)

トレーニングと検証のために、エポックに合わせてそれらを計算するメトリックを設定する必要があります

80        self.state_modules = [self.accuracy]

モデルを初期化

83        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(self.device)

85        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(self.device)

87        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(self.device)

トレーニングおよび検証ローダー

90        self.train_loader = DataLoader(ParityDataset(self.batch_size * self.n_batches, self.n_elems),
91                                       batch_size=self.batch_size)
92        self.valid_loader = DataLoader(ParityDataset(self.batch_size * 32, self.n_elems),
93                                       batch_size=self.batch_size)

このメソッドは、バッチごとにトレーナーによって呼び出されます。

95    def step(self, batch: Any, batch_idx: BatchIndex):

モデルモードを設定

100        self.model.train(self.mode.is_train)

入力とラベルを取得してモデルのデバイスに移動します

103        data, target = batch[0].to(self.device), batch[1].to(self.device)

トレーニングモードでのインクリメントステップ

106        if self.mode.is_train:
107            tracker.add_global_step(len(data))

モデルを実行

110        p, y_hat, p_sampled, y_hat_sampled = self.model(data)

再構成損失の計算

113        loss_rec = self.loss_rec(p, y_hat, target.to(torch.float))
114        tracker.add("loss.", loss_rec)

正則化損失の計算

117        loss_reg = self.loss_reg(p)
118        tracker.add("loss_reg.", loss_reg)

121        loss = loss_rec + self.beta * loss_reg

予想される歩数の計算

124        steps = torch.arange(1, p.shape[0] + 1, device=p.device)
125        expected_steps = (p * steps[:, None]).sum(dim=0)
126        tracker.add("steps.", expected_steps)

通話精度指標

129        self.accuracy(y_hat_sampled > 0, target)
130
131        if self.mode.is_train:

勾配の計算

133            loss.backward()

クリップグラデーション

135            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

オプティマイザー

137            self.optimizer.step()

クリアグラデーション

139            self.optimizer.zero_grad()

141            tracker.save()

実験を実行する

144def main():
148    experiment.create(name='ponder_net')
149
150    conf = Configs()
151    experiment.configs(conf, {
152        'optimizer.optimizer': 'Adam',
153        'optimizer.learning_rate': 0.0003,
154    })
155
156    with experiment.start():
157        conf.run()

160if __name__ == '__main__':
161    main()