13from typing import Any
14
15import torch
16from torch import nn
17from torch.utils.data import DataLoader
18
19from labml import tracker, experiment
20from labml_nn.helpers.metrics import AccuracyDirect
21from labml_nn.helpers.trainer import SimpleTrainValidConfigs, BatchIndex
22from labml_nn.adaptive_computation.parity import ParityDataset
23from labml_nn.adaptive_computation.ponder_net import ParityPonderGRU, ReconstructionLoss, RegularizationLossConfigurations with a simple training loop
class Configs(SimpleTrainValidConfigs):
    """Configurations with a simple training loop for PonderNet on the parity task."""

    # Number of epochs
    epochs: int = 100
    # Number of batches per epoch
    n_batches: int = 500
    # Batch size
    batch_size: int = 128

    # Model
    model: ParityPonderGRU
    # Reconstruction loss
    loss_rec: ReconstructionLoss
    # Regularization loss
    loss_reg: RegularizationLoss

    # The number of elements in the input vector. It is kept low for demonstration;
    # otherwise training takes a long time. Although the parity task looks simple,
    # discovering the pattern from samples alone is quite hard.
    n_elems: int = 8
    # Number of units in the hidden layer (state)
    n_hidden: int = 64
    # Maximum number of pondering steps
    max_steps: int = 20
    # lambda_p (λ_p), the parameter of the geometric distribution passed to
    # the regularization loss
    lambda_p: float = 0.2
    # Regularization loss coefficient (beta)
    beta: float = 0.01
    # Maximum gradient norm for clipping
    grad_norm_clip: float = 1.0

    # Training and validation loaders
    train_loader: DataLoader
    valid_loader: DataLoader

    # Accuracy calculator
    accuracy = AccuracyDirect()

    def init(self):
        """Set up tracker indicators, metrics, the model, the losses and the data loaders."""
        # Print these indicators to the screen
        for pattern in ('loss.*', 'loss_reg.*', 'accuracy.*', 'steps.*'):
            tracker.set_scalar(pattern, True)

        # Register the accuracy metric as a state module so it is computed
        # per epoch, separately for training and validation
        self.state_modules = [self.accuracy]

        device = self.device
        # PonderNet GRU model
        self.model = ParityPonderGRU(self.n_elems, self.n_hidden, self.max_steps).to(device)
        # Reconstruction loss built on element-wise (unreduced) BCE-with-logits
        self.loss_rec = ReconstructionLoss(nn.BCEWithLogitsLoss(reduction='none')).to(device)
        # Regularization loss parameterized by lambda_p and the step limit
        self.loss_reg = RegularizationLoss(self.lambda_p, self.max_steps).to(device)

        def make_loader(n_samples: int) -> DataLoader:
            # Parity dataset of `n_samples` vectors, batched with the configured size
            return DataLoader(ParityDataset(n_samples, self.n_elems),
                              batch_size=self.batch_size)

        # Training set covers `n_batches` full batches; validation uses 32 batches
        self.train_loader = make_loader(self.batch_size * self.n_batches)
        self.valid_loader = make_loader(self.batch_size * 32)

    def step(self, batch: Any, batch_idx: BatchIndex):
        """Process one batch; called by the trainer in both training and validation mode.

        Computes the reconstruction and regularization losses, logs them together
        with the expected number of pondering steps, updates the accuracy metric
        and, in training mode only, takes an optimizer step.
        """
        is_train = self.mode.is_train
        # Put the model in training or evaluation mode
        self.model.train(is_train)

        # Move inputs and labels to the model's device
        inputs = batch[0].to(self.device)
        labels = batch[1].to(self.device)

        # Advance the global step counter by the number of samples when training
        if is_train:
            tracker.add_global_step(len(inputs))

        # Forward pass; the third output (sampled halting probabilities) is unused here
        p, y_hat, _, y_hat_sampled = self.model(inputs)

        # Reconstruction loss
        rec_loss = self.loss_rec(p, y_hat, labels.to(torch.float))
        tracker.add("loss.", rec_loss)

        # Regularization loss
        reg_loss = self.loss_reg(p)
        tracker.add("loss_reg.", reg_loss)

        # Total loss: L = L_Rec + beta * L_Reg
        total_loss = rec_loss + self.beta * reg_loss

        # Expected number of steps: sum over n of n * p_n (step index along dim 0)
        step_numbers = torch.arange(1, p.shape[0] + 1, device=p.device)
        tracker.add("steps.", (p * step_numbers.unsqueeze(-1)).sum(dim=0))

        # Positive logits are predicted as parity 1
        self.accuracy(y_hat_sampled > 0, labels)

        # Parameter updates happen only in training mode
        if not is_train:
            return

        total_loss.backward()
        # Clip gradients by norm before the optimizer step
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
        self.optimizer.step()
        self.optimizer.zero_grad()
        tracker.save()
def main():
    """Create the `ponder_net` experiment, configure an Adam optimizer and run training."""
    experiment.create(name='ponder_net')

    conf = Configs()
    # Optimizer overrides applied on top of the defaults in `Configs`
    overrides = {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 0.0003,
    }
    experiment.configs(conf, overrides)

    # Start the experiment and run the training/validation loop inside it
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()