Auto-regressive NLP model trainer

from typing import Any, Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_helpers.datasets.text import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_helpers.device import DeviceConfigs
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs

Cross entropy loss

class CrossEntropyLoss(Module):
    """
    ### Cross entropy loss

    Thin wrapper around `nn.CrossEntropyLoss` that flattens the inputs first,
    so predictions with any number of leading dimensions can be scored
    against targets with the matching leading dimensions.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        """
        Compute the loss.

        * `outputs` - model predictions; the last dimension is treated as the class dimension
        * `targets` - target class indices; flattened to one dimension
        """
        # Collapse every leading dimension into one so that the wrapped loss
        # sees a `[N, n_classes]` prediction matrix and an `[N]` target vector.
        return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))

Trainer configurations

This has the basic configurations for NLP auto-regressive task training. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):
    """
    ## Trainer configurations

    This has the basic configurations for NLP auto-regressive task training.
    All the properties are configurable.
    """

    # Optimizer
    optimizer: torch.optim.Adam
    # Training device
    device: torch.device = DeviceConfigs()

    # Autoregressive model
    model: Module
    # Text dataset
    text: TextDataset
    # Batch size
    batch_size: int = 16
    # Length of the sequence, or context size
    seq_len: int = 512
    # Number of tokens in the vocabulary
    n_tokens: int
    # Tokenizer
    tokenizer: Callable = 'character'

    # Text prompt to start sampling (for illustration)
    prompt: str
    # The token separator when sampling (blank for character level tokenization)
    prompt_separator: str

    # Whether to periodically save models
    is_save_models = True

    # Loss function
    loss_func = CrossEntropyLoss()
    # Accuracy function
    accuracy = Accuracy()
    # Model embedding size
    d_model: int = 512
    # Gradient clipping
    grad_norm_clip: float = 1.0

    # Training data loader
    train_loader: DataLoader = 'shuffled_train_loader'
    # Validation data loader
    valid_loader: DataLoader = 'shuffled_valid_loader'

    # Data loaders shuffle with replacement
    dataloader_shuffle_with_replacement: bool = False

    # Whether to log model parameters and gradients (once per epoch).
    # These are summarized stats per layer, but it could still lead to
    # many indicators for very deep networks.
    is_log_model_params_grads: bool = False

    # Whether to log model activations (once per epoch).
    # These are summarized stats per layer, but it could still lead to
    # many indicators for very deep networks.
    is_log_model_activations: bool = False

    def init(self):
        """
        ### Initialization
        """
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        tracker.set_text("sampled", False)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # Add accuracy as a state module.
        # The name is probably confusing, since it's meant to store
        # states between training and validation for RNNs.
        # This will keep the accuracy metric stats separate for training and validation.
        self.state_modules = [self.accuracy]

    def other_metrics(self, output: torch.Tensor, target: torch.Tensor):
        """Override to calculate and log other metrics"""
        pass

    def step(self, batch: Any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        * `batch` - an indexable pair; element `0` is the input and element `1` the target
        * `batch_idx` - batch position information; only `is_last` is read here
        """
        # Set training/eval mode
        self.model.train(self.mode.is_train)

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last and self.is_log_model_activations):
            # Get model outputs.
            # It's returning a tuple for states when using RNNs.
            # This is not implemented yet.
            output, *_ = self.model(data)

        # Calculate and log loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        self.other_metrics(output, target)

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch.
            # This has to happen before the gradients are cleared below.
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training

        Greedily extends `self.prompt` by 25 tokens, stores the final text
        under the `sampled` tracker key and prints the coloured pieces.
        """
        # Starting prompt
        prompt = self.prompt
        # Collect output for printing
        log = [(prompt, Text.subtle)]
        # Sample 25 tokens
        for _ in monit.iterate('Sample', 25):
            # Tokenize the prompt; the trailing dimension added here acts as a batch of one
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
            # Get the model output
            output, *_ = self.model(data)
            # Get the model prediction (greedy).
            # Only the trailing batch dimension is squeezed: a bare `squeeze()` would
            # also drop the sequence dimension for a single-token prompt and make the
            # `output[-1]` lookup below fail.
            # NOTE(review): assumes the model output is `[seq_len, 1, n_tokens]` - confirm.
            output = output.argmax(dim=-1).squeeze(-1)
            # Text of the newly predicted token, with its separator
            piece = self.prompt_separator + self.text.itos[output[-1]]
            # Add the prediction to prompt
            prompt += piece
            # Add the prediction for logging
            log.append((piece, Text.value))

        tracker.add({'sampled': prompt})
        # Print the sampled output
        logger.log(log)
def _optimizer(c: NLPAutoRegressionConfigs):
    """
    ### Default optimizer configurations

    Builds an `OptimizerConfigs` object over the model's parameters, selects
    the `'Adam'` optimizer and passes on the model embedding size.

    NOTE(review): presumably registered as the calculator for
    `NLPAutoRegressionConfigs.optimizer`; the registration is not visible here - confirm.
    """
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

Get number of tokens

def _n_tokens(c: NLPAutoRegressionConfigs):
    """
    ### Get number of tokens

    Returns the `n_tokens` value of the configured text dataset.
    """
    return c.text.n_tokens

Basic English tokenizer

We use the character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

def basic_english():
    """
    ### Basic English tokenizer

    Returns the tokenizer that `get_tokenizer` provides for `'basic_english'`.
    """
    # Imported lazily so that the dependency is only needed when this tokenizer is selected
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')

Character level tokenizer

def character_tokenizer(x: str):
    """
    ### Character level tokenizer

    Splits the text into a list of its individual characters.
    """
    return list(x)

Character level tokenizer configuration

def character():
    """
    ### Character level tokenizer configuration

    Returns the `character_tokenizer` function itself (it is not called here).
    """
    return character_tokenizer

Tiny Shakespeare dataset

The file is downloaded from the URL if it is not already present.

def tiny_shakespeare(c: NLPAutoRegressionConfigs):
    """
    ### Tiny Shakespeare dataset

    Builds a `TextFileDataset` for `tiny_shakespeare.txt` under the lab data
    path, tokenized with the configured tokenizer.
    """
    # NOTE(review): the `url` argument was blank in this copy of the file; it is set
    # to the Tiny Shakespeare corpus in the karpathy/char-rnn repository - confirm
    # against the upstream source.
    return TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        c.tokenizer,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

Sequential training data loader

def sequential_train_loader(c: NLPAutoRegressionConfigs):
    """
    ### Sequential training data loader

    Builds a `SequentialDataLoader` over the training split of the text,
    using the configured batch size and sequence length.
    """
    return SequentialDataLoader(text=c.text.train,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Sequential validation data loader

def sequential_valid_loader(c: NLPAutoRegressionConfigs):
    """
    ### Sequential validation data loader

    Builds a `SequentialDataLoader` over the validation split of the text,
    using the configured batch size and sequence length.
    """
    return SequentialDataLoader(text=c.text.valid,
                                dataset=c.text,
                                batch_size=c.batch_size,
                                seq_len=c.seq_len)

Transpose batch

DataLoader collects the batches on the first dimension. We need to transpose it to be sequence first.

def transpose_batch(batch):
    """
    ### Transpose batch

    `DataLoader` collects the samples into a list. Each sample is a
    `(source, target)` pair of tensors; this stacks the sources and the
    targets so that the batch index becomes the second dimension
    (sequence first).

    * `batch` - a sequence of `(source, target)` tensor pairs

    Returns a `(src, tgt)` tuple of stacked tensors.
    """
    # Regroup `[(s0, t0), (s1, t1), ...]` into `[(s0, s1, ...), (t0, t1, ...)]`
    transposed_data = list(zip(*batch))
    # Stack the batch along the second dimension `dim=1`
    src = torch.stack(transposed_data[0], dim=1)
    tgt = torch.stack(transposed_data[1], dim=1)

    return src, tgt

Shuffled training data loader

def shuffled_train_loader(c: NLPAutoRegressionConfigs):
    """
    ### Shuffled training data loader

    Wraps the training split in a `SequentialUnBatchedDataset` and serves it
    through a `DataLoader` in random order. Batches are collated with
    `transpose_batch`.
    """
    dataset = SequentialUnBatchedDataset(text=c.text.train,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    # Sampling with or without replacement is controlled by the configuration
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)

Shuffled validation data loader

def shuffled_valid_loader(c: NLPAutoRegressionConfigs):
    """
    ### Shuffled validation data loader

    Wraps the validation split in a `SequentialUnBatchedDataset` and serves it
    through a `DataLoader` in random order. Batches are collated with
    `transpose_batch`.
    """
    dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                         dataset=c.text,
                                         seq_len=c.seq_len)
    # Sampling with or without replacement is controlled by the configuration
    sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

    return DataLoader(dataset,
                      batch_size=c.batch_size,
                      collate_fn=transpose_batch,
                      sampler=sampler)