Transformer XL Experiment

This is an annotated PyTorch experiment that trains a Transformer XL model.

from typing import List

import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer

Autoregressive model

class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)

Masks

        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):

Length of the memory

        m_len = len(mem[0]) if mem else 0

Create a subsequent mask for tokens

        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)

Create an all-ones (full visibility) mask for the memory

        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

Concatenate the masks if there is memory

        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

Use the subsequent mask otherwise

        else:
            mask = self.mask_x[:len(x), :len(x)]

Token embeddings

        x = self.src_embed(x)

Run it through the transformer

        res, mem = self.transformer(x, mem, mask)

Generate logits of the next token

        res = self.generator(res)

        return res, mem
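To make the mask shapes concrete, here is a minimal standalone sketch of the same construction in plain PyTorch. It assumes subsequent_mask returns a lower-triangular (causal) mask of shape [seq_len, seq_len, 1]; the tril-based mask below is an illustrative stand-in for that helper, not the library function itself.

import torch

seq_len, mem_len = 4, 3

# Causal mask: position i can attend to positions j <= i; shape [seq_len, seq_len, 1]
mask_x = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)

# Memory is fully visible to every position; shape [seq_len, mem_len, 1]
mask_mem = torch.ones(seq_len, mem_len, 1).bool()

# Concatenate along the key dimension, as in forward above; shape [seq_len, mem_len + seq_len, 1]
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask.shape)  # torch.Size([4, 7, 1])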

Configurations

The default configurations can and will be overridden when we start the experiment

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4

Dropout probability

    dropout: float = 0.0

Number of features in FFN hidden layer

    d_ff: int = 256

Number of transformer layers

    n_layers: int = 6

Number of memories to keep

    mem_len: int = 128

State module to maintain memories when switching between training and validation

    memory = SimpleStateModule()

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

This will keep the accuracy metric stats and memories separate for training and validation.

        self.state_modules = [self.accuracy, self.memory]

Concatenate memories and remove old memories to keep a maximum of mem_len memories.

    def merge_memory(self, old_mem, new_mem):

If it's configured not to use memory

        if self.mem_len == 0:
            return []

Concatenate with old memory

        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem

Truncate old memories

        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]

        return mem
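As a rough illustration of the merge-and-truncate behaviour, here is a minimal sketch with dummy tensors. The shapes are made up for the example; the real memories are lists with one tensor per layer, each of shape [len, batch_size, d_model].

import torch

mem_len, batch_size, d_model = 4, 2, 8

# Pretend we already hold 3 steps of memory for a single layer,
# and the model has just produced 2 more steps.
old_mem = [torch.zeros(3, batch_size, d_model)]
new_mem = [torch.zeros(2, batch_size, d_model)]

# Concatenate along the time dimension, then keep only the newest mem_len steps
mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
mem = [m[-mem_len:] for m in mem]
print(mem[0].shape)  # torch.Size([4, 2, 8])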

Training/validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get memories

        mem = self.memory.get()

Run the model

        output, new_mem = self.model(data, mem)

Merge memory

        mem = self.merge_memory(mem, new_mem)

Update memories

        self.memory.set(mem)

Calculate and log cross entropy loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Start with an empty memory

        mem = []

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)

Move to device

            data = data.to(self.device)

Get the model output

            output, new_mem = self.model(data, mem)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze(1)

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Only feed the last character to the model in the next iteration; the rest go in as memories

            prompt = prompt[-1:]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Update memory

            mem = self.merge_memory(mem, new_mem)

Print the sampled output

        logger.log(log)

Initialize the autoregressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)
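As a quick sanity check of the shapes involved, a model built this way takes a [seq_len, batch_size] tensor of token indices and a (possibly empty) list of memories, and returns logits of shape [seq_len, batch_size, n_vocab] together with the new memories as a list of tensors (one per layer in the labml implementation). Below is a minimal sketch with illustrative hyper-parameter values; the experiment itself gets all of these from the configurations.

import torch
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.xl import RelativeMultiHeadAttention, TransformerXL, TransformerXLLayer

# Illustrative values only; n_vocab, d_model, etc. normally come from Configs
d_model, heads, d_ff, n_layers, n_vocab = 128, 4, 256, 6, 65
model = AutoregressiveModel(n_vocab, d_model, TransformerXL(
    TransformerXLLayer(d_model=d_model,
                       self_attn=RelativeMultiHeadAttention(heads, d_model, 0.0),
                       feed_forward=FeedForward(d_model, d_ff, 0.0),
                       dropout_prob=0.0), n_layers))

x = torch.randint(0, n_vocab, (2, 32))  # [seq_len, batch_size] of token indices
logits, new_mem = model(x, [])          # logits: [2, 32, 65]; new_mem: list of per-layer tensors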

Run the experiment

def main():

Create experiment

    experiment.create(name="transformer_xl", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
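With these values, each training step processes character sequences of length 2, but because of the memory each position can attend over up to mem_len + seq_len = 32 + 2 = 34 characters, and every step covers seq_len × batch_size = 2 × 32 = 64 tokens, which is the amount added to the global step counter in the step function above.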

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run the training and validation loop (TrainValidConfigs.run)

        conf.run()

if __name__ == '__main__':
    main()