Transformer XL Experiment

This is an annotated PyTorch experiment to train a Transformer XL model.

from typing import List

import torch
import torch.nn as nn
from labml.logger import Text

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

Autoregressive model

class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)

Masks

        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):

Length of the memory

        m_len = len(mem[0]) if mem else 0

Create a subsequent mask for tokens

        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)

Create an all ones (full visibility) mask for memory

        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

Concatenate the masks if there is memory

        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

Use the subsequent mask otherwise

        else:
            mask = self.mask_x[:len(x), :len(x)]

Token embeddings

        x = self.src_embed(x)

Run it through the transformer

        res, mem = self.transformer(x, mem, mask)

Generate logits of the next token

        res = self.generator(res)

        return res, mem
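To make the mask layout concrete, here is a small standalone sketch in plain PyTorch (independent of the labml_nn helpers, and assuming the [query, key, batch] mask layout used above) for a segment of 3 tokens with 2 memory slots:

import torch

seq_len, m_len = 3, 2

# Causal (subsequent) mask over the current segment: query position i sees keys 0..i.
mask_x = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)

# Memory positions are fully visible to every query position.
mask_mem = mask_x.new_ones(seq_len, m_len, 1)

# Keys are ordered [memory, current tokens], so concatenate along the key dimension.
mask = torch.cat((mask_mem, mask_x), dim=1)  # shape [seq_len, m_len + seq_len, 1]

print(mask[:, :, 0].int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)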

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4
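With these defaults (d_model = 128 and heads = 4), each attention head works with d_model / heads = 32 features, assuming the usual even split of the embedding across heads.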

Dropout probability

    dropout: float = 0.0

Number of features in FFN hidden layer

    d_ff: int = 256

Number of transformer layers

    n_layers: int = 6

Number of memories to keep

    mem_len: int = 128

State module to maintain memories when switching between training and validation

    memory = SimpleStateModule()

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

This will keep the accuracy metric stats and memories separate for training and validation.

        self.state_modules = [self.accuracy, self.memory]
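To make the role of the state module concrete, here is a rough, hypothetical sketch of what such a module needs to do; this is not the labml_helpers implementation, just an illustration of holding one saved state per mode so training and validation memories never mix.

class ToyStateModule:
    """Hypothetical stand-in for SimpleStateModule: one slot of data per mode."""

    def __init__(self):
        self._saved = {'train': None, 'valid': None}
        self._mode = 'train'
        self._data = None

    def switch(self, mode: str):
        # Stash the current mode's data and restore the data saved for `mode`.
        self._saved[self._mode] = self._data
        self._mode = mode
        self._data = self._saved[mode]

    def get(self):
        return self._data

    def set(self, data):
        self._data = data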

Concatenate the new memories with the old ones, and truncate so that at most mem_len memories are kept.

    def merge_memory(self, old_mem, new_mem):

If it's configured not to use memory

        if self.mem_len == 0:
            return []

Concatenate with old memory

        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem

Truncate old memories

        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]

        return mem
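As a quick sanity check of the truncation, the same merge logic can be exercised standalone with toy per-layer memories of shape [length, batch, d_model]; this is just a sketch, not part of the experiment.

import torch


def merge_memory(old_mem, new_mem, mem_len):
    # Mirrors Configs.merge_memory above, written as a free function.
    if mem_len == 0:
        return []
    if old_mem:
        mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
    else:
        mem = new_mem
    if len(mem[0]) > mem_len:
        mem = [m[-mem_len:] for m in mem]
    return mem


# Two layers, batch size 1, d_model 4; two segments of 3 steps each, keep at most 4.
old = [torch.zeros(3, 1, 4) for _ in range(2)]
new = [torch.ones(3, 1, 4) for _ in range(2)]
mem = merge_memory(old, new, mem_len=4)
print(mem[0].shape)  # torch.Size([4, 1, 4]) -- the two oldest steps were dropped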

Training/validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get memories

            mem = self.memory.get()

Run the model

            output, new_mem = self.model(data, mem)

Merge memory

            mem = self.merge_memory(mem, new_mem)

Update memories

            self.memory.set(mem)

Calculate and log cross entropy loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Memory

        mem = []

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)

Move to device

            data = data.to(self.device)

Get the model output

            output, new_mem = self.model(data, mem)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze(1)

Add the prediction to prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Only feed the last character to the model in the next iteration; the rest will go in as memory.

            prompt = prompt[-1:]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Update memory

            mem = self.merge_memory(mem, new_mem)

Print the sampled output

        logger.log(log)

Initialize the autoregressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)

Run the experiment

def main():

Create experiment

    experiment.create(name="transformer_xl", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
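Note how short the training segments are compared to the memory here: each step feeds only seq_len = 2 new tokens, but attention also reaches the mem_len = 32 remembered states, so once the memory is full the last position in a segment attends over up to mem_len + seq_len = 34 tokens.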

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training and validation (TrainValidConfigs.run)

        conf.run()

if __name__ == '__main__':
    main()