This is an annotated PyTorch experiment to train a compressive transformer model.
from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression
class CompressedMemory(NamedTuple):
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]
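Each list holds one tensor per transformer layer. As a rough, hypothetical sketch (the shapes depend on the batch size and the current memory lengths; none of these numbers come from the experiment), the memories of a 6-layer model might look like this:

# Hypothetical illustration of the memory layout; not part of the experiment code.
n_layers, mem_len, c_mem_len, batch_size, d_model = 6, 8, 16, 32, 128
mem = [torch.zeros(mem_len, batch_size, d_model) for _ in range(n_layers)]
c_mem = [torch.zeros(c_mem_len, batch_size, d_model) for _ in range(n_layers)]
memory = CompressedMemory(mem=mem, c_mem=c_mem)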
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
Masks
        self.mask_x = None
        self.mask_mem = None
    def forward(self, x: torch.Tensor, mem: CompressedMemory):
Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []
Total length of the memory and compressed memory (for masks)
        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])
Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
Create an all ones (full visibility) mask for memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
Use only the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]
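As an aside, here is a minimal sketch of the combined mask, assuming subsequent_mask returns a lower-triangular boolean mask of shape [seq_len, seq_len, 1] (an assumption about that helper; only the shapes are taken from the code above). Memory positions are fully visible, while current tokens are causally masked:

# Hypothetical sketch of the combined attention mask (illustration only).
seq_len, m_len = 4, 6
mask_x = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(-1)  # causal mask over current tokens (assumed helper behavior)
mask_mem = mask_x.new_ones(seq_len, m_len, 1)  # memory is visible to every token
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask.shape)  # torch.Size([4, 10, 1])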
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, mem = self.transformer(x, mem, c_mem, mask)
Generate logits of the next token
        res = self.generator(res)

        return res, mem
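As a hedged usage sketch (the sizes are made up, and model is assumed to be an AutoregressiveModel built by the autoregressive_model configuration option further below): the model takes token ids of shape [seq_len, batch_size] and the previous CompressedMemory, and returns logits of shape [seq_len, batch_size, n_vocab] along with the new per-layer memories.

# Hypothetical shapes; assumes `model` is an AutoregressiveModel instance.
n_vocab, seq_len, batch_size, n_layers = 65, 8, 32, 6
x = torch.randint(0, n_vocab, (seq_len, batch_size))  # token ids, [seq_len, batch_size]
logits, new_mem = model(x, None)  # no memories on the very first step
assert logits.shape == (seq_len, batch_size, n_vocab)
assert len(new_mem) == n_layers  # one new memory tensor per layer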
The default configurations can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of memories to keep
    mem_len: int = 8
State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()
Attention Reconstruction Loss
    attention_reconstruction_loss: AttentionReconstructionLoss
Compression rate
    compression_rate: int = 4
Compressed memory length
    c_mem_len: int = 128
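With these defaults, attention at each step can reach the current tokens, mem_len uncompressed memories, and c_mem_len compressed memories that each summarize compression_rate time steps, so the theoretical maximum context is roughly seq_len + mem_len + c_mem_len * compression_rate tokens. A quick back-of-the-envelope check with the default values above (seq_len taken as 8, matching the training configuration in main):

# Rough upper bound on the temporal range of attention (illustration only).
seq_len, mem_len, c_mem_len, compression_rate = 8, 8, 128, 4
max_context = seq_len + mem_len + c_mem_len * compression_rate
print(max_context)  # 528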
    def init(self):
Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
Do not print the attention reconstruction loss in the terminal
        tracker.set_scalar("ar_loss.*", False)
This will keep the accuracy metric stats and memories separate for training and validation.
        self.state_modules = [self.accuracy, self.memory]
Concatenate new memories and compress the oldest memories.
    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:
If the configurations specify not to use memory
        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []
Get memory and compressed memory
        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []
Concatenate new memories with old memory
        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem
Compress the oldest memories if there are more memories than mem_len
        if len(mem[0]) > self.mem_len:
Calculate the number of compressed memories to make, $n_{cm} = \lceil \frac{n - N}{c} \rceil$, where $n$ is the number of memories we have, $N$ is the maximum number of memories we maintain (mem_len), and $c$ is the compression rate. For example, with 16 memories, mem_len = 8 and a compression rate of 4, we create $\lceil 8/4 \rceil = 2$ compressed memories.
            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
Number of memories to compress
            n_old = n_c_mem * self.compression_rate
A list to keep memories that need to be compressed for each layer.
            mem_to_compress = []
A list to keep the memories that do not get compressed for each layer.
            uncompressed_mem = []
Iterate through memories of each layer.
            for m in mem:
Split the memories at $n_{old}$, so the oldest memories get compressed
                cm, m = torch.split(m, [n_old, len(m) - n_old])
Collect memories to compress
                mem_to_compress.append(cm)
Collect remaining memories
                uncompressed_mem.append(m)
Update the memories
            mem = uncompressed_mem
Compress the memories
            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(mem_to_compress[i]))
Concatenate newly compressed memories with old compressed memories
            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
If there are no old compressed memories
            else:
                c_mem = new_c_mem
Truncate old compressed memories if there are more than c_mem_len
            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]
No memories are compressed if the number of memories does not exceed mem_len
        else:
            mem_to_compress = []
Return memories and the memories that were compressed. Memories that were compressed are needed for the reconstruction loss computation.
        return CompressedMemory(mem, c_mem), mem_to_compress
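The per-layer compress modules used here are the Conv1dCompression instances passed to each CompressiveTransformerLayer (see the autoregressive_model option below). As a minimal sketch of the idea, and not necessarily the library's exact implementation, compressing by a rate $c$ can be done with a 1D convolution of kernel size and stride $c$ over the time dimension, so $n_{old}$ memories become $n_{old} / c$ compressed memories:

# Minimal sketch of convolutional memory compression (illustration only).
d_model, compression_rate = 128, 4
conv = nn.Conv1d(d_model, d_model, kernel_size=compression_rate, stride=compression_rate)
mem_to_compress = torch.randn(16, 32, d_model)  # [n_old, batch, d_model]
x = mem_to_compress.permute(1, 2, 0)  # Conv1d expects [batch, channels, time]
c_mem = conv(x).permute(2, 0, 1)  # back to [n_old / rate, batch, d_model]
print(c_mem.shape)  # torch.Size([4, 32, 128])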
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get memories
        mem = self.memory.get()
Run the model
        output, new_mem = self.model(data, mem)
Merge and compress memory
        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
Update memories
        self.memory.set(mem)
Calculate and log cross entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
Calculate attention reconstruction loss if memories were compressed in this step
        if mem_to_compress:
Get attention reconstruction loss
            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
Track attention reconstruction loss
            tracker.add("ar_loss.", ar_loss)
Add attention reconstruction loss to loss
            loss = loss + ar_loss
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
    def sample(self):
Starting prompt
        prompt = self.prompt
Collect output for printing
        log = [(prompt, Text.subtle)]
Start with an empty memory
        mem = CompressedMemory([], [])
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
Move to device
            data = data.to(self.device)
Get the model output
            output, new_mem = self.model(data, mem)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
Only feed the last character to the model in the next iteration; the rest will go in as memories
            prompt = prompt[-1:]
Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
Update and compress memory
            mem, _ = self.merge_compress_memory(mem, new_mem)
Print the sampled output
        logger.log(log)
@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)
@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    return AttentionReconstructionLoss(c.model.transformer.layers)
def main():
Create experiment
    experiment.create(name="compressive_transformer", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training and validation loop (TrainValidConfigs.run)
        conf.run()

if __name__ == '__main__':
    main()