This is an annotated PyTorch experiment to train a Transformer XL model.
from typing import List

import torch
import torch.nn as nn
from labml.logger import Text

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
The autoregressive model wraps the Transformer XL with a token embedding layer and a final linear layer that produces logits over the vocabulary.
class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
Masks
        self.mask_x = None
        self.mask_mem = None
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
Length of the memory
        m_len = len(mem[0]) if mem else 0
Create a subsequent (causal) mask for the tokens; a standalone sketch of the mask shapes follows this class.
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
Create an all-ones (full visibility) mask for the memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
Otherwise, use only the subsequent mask
        else:
            mask = self.mask_x[:len(x), :len(x)]
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
Generate logits of the next token
        res = self.generator(res)
        return res, mem
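To make the mask handling above concrete, here is a small standalone sketch in plain PyTorch (it does not use the labml_nn helpers, and the seq_len and m_len values are made up for illustration). It builds the same kind of causal mask of shape [seq_len, seq_len, 1], a full-visibility memory mask of shape [seq_len, m_len, 1], and concatenates them along the key dimension as forward does.
import torch

seq_len, m_len = 3, 2

# Causal (subsequent) mask: position i can attend to positions <= i.
# Shape [seq_len, seq_len, 1], matching self.mask_x above.
mask_x = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)

# Memory is fully visible to every query position: shape [seq_len, m_len, 1]
mask_mem = mask_x.new_ones(seq_len, m_len, 1)

# Concatenate along the key dimension: shape [seq_len, m_len + seq_len, 1]
mask = torch.cat((mask_mem, mask_x), dim=1)

print(mask[:, :, 0].int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)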
The configurations extend the autoregressive NLP experiment configurations; the defaults below can be overridden when the experiment is started.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in the FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of past time steps to keep as memory
    mem_len: int = 128
State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
Concatenate the new memories with the old ones and drop the oldest entries so that at most mem_len time steps of memory are kept (a standalone sketch of this follows the method).
    def merge_memory(self, old_mem, new_mem):
If it's configured not to use memory
        if self.mem_len == 0:
            return []
Concatenate with old memory
        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem
Truncate old memories
        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]
        return mem
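As a quick sanity check of the logic above, here is a standalone sketch with made-up shapes (a single layer and a short mem_len); it mirrors only the concatenation and truncation, not the full configuration class.
import torch

mem_len = 4
batch, d_model = 2, 8

# One memory tensor per layer, each of shape [time_steps, batch, d_model]
old_mem = [torch.zeros(3, batch, d_model)]  # 3 time steps already stored
new_mem = [torch.zeros(2, batch, d_model)]  # 2 new time steps from this batch

# Concatenate along the time dimension, then keep only the newest mem_len steps
mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
if len(mem[0]) > mem_len:
    mem = [m[-mem_len:] for m in mem]

print(mem[0].shape)  # torch.Size([4, 2, 8])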
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get memories
            mem = self.memory.get()
Run the model
            output, new_mem = self.model(data, mem)
Merge memory
            mem = self.merge_memory(mem, new_mem)
Update memories
            self.memory.set(mem)
Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Start with an empty memory
        mem = []
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
Move it to the device
            data = data.to(self.device)
Get the model output
            output, new_mem = self.model(data, mem)
Get the greedy model prediction (a small shape sketch follows this method)
            output = output.argmax(dim=-1).squeeze(1)
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Only feed the last character to the model in the next iteration; the rest will go in as memory
            prompt = prompt[-1:]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Update memory
            mem = self.merge_memory(mem, new_mem)
Print the sampled output
        logger.log(log)
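To clarify the tensor shapes in the greedy step above, here is a tiny standalone sketch with hypothetical sizes (a sequence of 5 tokens, batch of 1, vocabulary of 65): the logits have shape [seq_len, batch, n_vocab], and taking the argmax over the vocabulary and squeezing the batch dimension leaves one predicted token id per position.
import torch

logits = torch.randn(5, 1, 65)            # [seq_len=5, batch=1, n_vocab=65]
preds = logits.argmax(dim=-1).squeeze(1)  # [seq_len=5]
print(preds.shape)                        # torch.Size([5])
# preds[-1] is the prediction for the last input token,
# i.e. the next character to append to the prompt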
Create the autoregressive model and move it to the configured device.
@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)
def main():
Create experiment
    experiment.create(name="transformer_xl", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training and validation loop (TrainValidConfigs.run)
        conf.run()


if __name__ == '__main__':
    main()