This is an annotated PyTorch experiment that trains a Transformer XL model on character-level Tiny Shakespeare.
from typing import List

import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
The autoregressive model wraps the Transformer XL with a token embedding layer and a final linear layer that produces logits.
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer XL
        self.transformer = transformer
Final layer to generate logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)
Masks (created lazily in forward)
        self.mask_x = None
        self.mask_mem = None
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
Length of the memory
        m_len = len(mem[0]) if mem else 0
Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
Create an all-ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
Generate logits of the next token
        res = self.generator(res)
        return res, mem
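As a quick illustration of the mask layout, here is a standalone sketch (not part of the module, and assuming the lower-triangular convention used by subsequent_mask): with three current tokens and two memory steps, the combined mask lets every token attend to all of the memory plus itself and the tokens before it.
# Illustrative sketch only: build the combined mask for seq_len = 3 and m_len = 2.
import torch

seq_len, m_len = 3, 2
# Causal mask over current tokens: [seq_len, seq_len, 1]
mask_x = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)
# Memory is fully visible to every current token: [seq_len, m_len, 1]
mask_mem = torch.ones(seq_len, m_len, 1)
# Combined mask: [seq_len, m_len + seq_len, 1]
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask[..., 0])
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])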
Configurations, extending NLPAutoRegressionConfigs
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of past time steps to keep as memory
    mem_len: int = 128
State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
Concatenate memories and remove old memories to keep a maximum of mem_len memories.
    def merge_memory(self, old_mem, new_mem):
If it's configured not to use memory
        if self.mem_len == 0:
            return []
Concatenate with old memory
        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem
Truncate old memories
        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]
        return mem
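Here is a tiny standalone check of the merge-and-truncate behaviour. The shapes are assumptions for illustration: one memory tensor per layer, each of shape [length, batch_size, d_model].
# Illustrative sketch only, not part of the experiment.
import torch

mem_len = 4
old = [torch.zeros(3, 1, 8)]   # one layer, 3 remembered steps
new = [torch.ones(2, 1, 8)]    # 2 freshly computed steps
merged = [torch.cat((m, x), dim=0) for m, x in zip(old, new)]   # length 5
merged = [m[-mem_len:] for m in merged]                         # keep only the latest 4 steps
assert merged[0].shape == (4, 1, 8)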
A training or validation step
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get memories
        mem = self.memory.get()
Run the model
        output, new_mem = self.model(data, mem)
Merge memory
        mem = self.merge_memory(mem, new_mem)
Update memories
        self.memory.set(mem)
Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
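A brief shape sketch for one step (the batch layout is an assumption based on the sequence-first convention used elsewhere in this file): data and target are [seq_len, batch_size], the model output is [seq_len, batch_size, n_vocab], and the global step advances by the number of tokens in the batch.
# Illustrative numbers only; seq_len and batch_size match the overrides in main() below.
seq_len, batch_size = 2, 32
tokens_per_batch = seq_len * batch_size   # what tracker.add_global_step receives per training step
assert tokens_per_batch == 64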
Sampling function to generate sample text
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Memory
        mem = []
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
Move to device
            data = data.to(self.device)
Get the model output
            output, new_mem = self.model(data, mem)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Only feed the last character to the model in the next iteration; the rest will go in as memory
            prompt = prompt[-1:]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Update memory
            mem = self.merge_memory(mem, new_mem)
Print the sampled output
        logger.log(log)
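The feeding pattern above is worth spelling out: only on the first iteration is the whole prompt fed; afterwards a single character goes in per step, and the earlier ones are carried in the memory. A hedged standalone sketch of that pattern follows (fake_model is hypothetical, not part of labml).
# Illustrative sketch only.
def fake_model(tokens, mem):
    # Pretend model: appends one memory entry per input token and always predicts token 0.
    return [0] * len(tokens), mem + [None] * len(tokens)

prompt, mem = [5, 8, 2], []        # an initial prompt of three token ids
for _ in range(4):
    preds, mem = fake_model(prompt, mem)
    prompt = [preds[-1]]           # feed only the newest prediction next time
print(len(mem))                    # 3 + 1 + 1 + 1 = 6 remembered steps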
Initialize the autoregressive model
@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)
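For quick experimentation outside the configs framework, the same model can be built directly. The dimensions below are tiny illustrative values, not the experiment defaults.
# Illustrative sketch only; AutoregressiveModel is the class defined above.
import torch
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer, RelativeMultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward

d_model, heads, d_ff, n_layers, n_vocab = 32, 2, 64, 2, 65
model = AutoregressiveModel(n_vocab, d_model, TransformerXL(
    TransformerXLLayer(d_model=d_model,
                       self_attn=RelativeMultiHeadAttention(heads, d_model, 0.0),
                       feed_forward=FeedForward(d_model, d_ff, 0.0),
                       dropout_prob=0.0), n_layers))
x = torch.randint(0, n_vocab, (4, 1))   # [seq_len, batch_size]
logits, mem = model(x, [])              # no memory for the first segment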
def main():
Create experiment
    experiment.create(name="transformer_xl", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training loop: TrainValidConfigs.run
        conf.run()


if __name__ == '__main__':
    main()
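A closing arithmetic note on the overrides above (an observation about the mask construction in forward, not something stated in the code): at each layer a position can directly attend to at most seq_len + mem_len steps, and stacking layers lets information propagate from even further back through the cached memories.
# Quick check with the values configured in main().
seq_len, mem_len = 2, 32
max_direct_context = seq_len + mem_len   # 34 characters visible to a single attention layer
assert max_direct_context == 34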