This is an annotated PyTorch experiment to train a compressive transformer model.
from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression
class CompressedMemory(NamedTuple):
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]
class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
Masks
        self.mask_x = None
        self.mask_mem = None
    def forward(self, x: torch.Tensor, mem: CompressedMemory):
Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []
Total length of the memory and compressed memory (for masks)
        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])
Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
Create an all ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
Use only the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]
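To make the mask shapes concrete, here is a small illustrative example (the numbers are assumptions, not from the original text): with 8 new tokens and 12 total memories (m_len = 12), mask_mem[:8, :12] is an 8 × 12 × 1 all-ones mask, mask_x[:8, :8] is the 8 × 8 × 1 lower-triangular subsequent mask, and the concatenation along dim=1 gives a mask of shape [8, 20, 1]: every new token can attend to all memories, but only to itself and earlier new tokens.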
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, mem = self.transformer(x, mem, c_mem, mask)
Generate logits of the next token
        res = self.generator(res)

        return res, mem
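For reference (inferred from the code above, not stated in the original text): x is expected to be of shape [seq_len, batch_size], the embedding turns it into [seq_len, batch_size, d_model], the returned logits res have shape [seq_len, batch_size, n_vocab], and mem is the list of new per-layer memories produced by the transformer.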
The default configurations can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of memories to keep
    mem_len: int = 8
State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
Attention reconstruction loss
    attention_reconstruction_loss: AttentionReconstructionLoss
Compression rate
    compression_rate: int = 4
Compressed memory length
    c_mem_len: int = 128
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Do not print the attention reconstruction loss in the terminal
        tracker.set_scalar("ar_loss.*", False)
Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
Concatenate new memories and compress the oldest memories.
    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:
If the configurations specify not to use memory
        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []
Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []
Concatenate new memories with old memory
        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem
Compress the oldest memories if there are more memories than mem_len
        if len(mem[0]) > self.mem_len:
Calculate the number of compressed memories to make, $n_{cm} = \left\lceil \frac{n_m - N_m}{c} \right\rceil$, where $n_m$ is the number of memories we have, $N_m$ is the maximum number of memories we maintain (mem_len), and $c$ is the compression rate.
            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
Number of memories to compress, $c \, n_{cm}$
            n_old = n_c_mem * self.compression_rate
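As a quick illustration with the default values above (mem_len = 8, compression_rate = 4): if 16 memories have accumulated, then $n_{cm} = \lceil (16 - 8) / 4 \rceil = 2$ and n_old = 8, so the oldest 8 memories of each layer will be squeezed into 2 compressed memories.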
A list to keep memories that need to be compressed for each layer.
            mem_to_compress = []
A list to keep the memories that do not get compressed for each layer.
            uncompressed_mem = []
Iterate through memories of each layer.
            for m in mem:
Split the memories at n_old, so that the oldest n_old memories get compressed
                cm, m = torch.split(m, [n_old, len(m) - n_old])
Collect memories to compress
                mem_to_compress.append(cm)
Collect remaining memories
                uncompressed_mem.append(m)
Update the memories
            mem = uncompressed_mem
Compress the memories with each layer's compression function
            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(mem_to_compress[i]))
Concatenate newly compressed memories with old compressed memories
            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
If there are no old compressed memories
            else:
                c_mem = new_c_mem
Truncate old compressed memories
            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]
No memories are compressed if the number of memories does not exceed mem_len
        else:
            mem_to_compress = []
Return memories and the memories that were compressed. The compressed memories are needed for the reconstruction loss computation.
        return CompressedMemory(mem, c_mem), mem_to_compress
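To make the flow concrete, here is a rough walk-through under assumed settings of seq_len = 8, mem_len = 8, compression_rate = 4, and c_mem_len = 128: the first training step appends 8 new memories per layer, which is not more than mem_len, so nothing is compressed. On the next step 16 memories have accumulated, so the oldest 8 are compressed into 2 entries and the newest 8 are kept as regular memory. The compressed memory then grows by 2 entries per step until it reaches 128, after which the oldest compressed entries are truncated away.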
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get memories
            mem = self.memory.get()
Run the model
            output, new_mem = self.model(data, mem)
Merge and compress memory
            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
Update memories
            self.memory.set(mem)
Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate attention reconstruction loss if memories were compressed in this step
        if mem_to_compress:
Get attention reconstruction loss
            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
Track attention reconstruction loss
            tracker.add("ar_loss.", ar_loss)
Add attention reconstruction loss to loss
            loss = loss + ar_loss
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()

    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Start with an empty memory
        mem = CompressedMemory([], [])
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
Move it to the device
            data = data.to(self.device)
Get the model output
            output, new_mem = self.model(data, mem)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Only feed the last character to the model in the next iteration; the rest will go in as memories
            prompt = prompt[-1:]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Update and compress memory
            mem, _ = self.merge_compress_memory(mem, new_mem)
Print the sampled output
        logger.log(log)
@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)


@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    return AttentionReconstructionLoss(c.model.transformer.layers)
def main():
Create experiment
    experiment.create(name="compressive_transformer", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training loop (TrainValidConfigs.run)
        conf.run()


if __name__ == '__main__':
    main()