Compressive Transformer Experiment

This is an annotated PyTorch experiment to train a compressive transformer model.

from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression

class CompressedMemory(NamedTuple):
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]
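
For intuition, each of these lists holds one tensor per transformer layer, with the sequence dimension first. A minimal sketch with purely hypothetical sizes (using the torch imported above):

n_layers, mem_len, c_mem_len, batch_size, d_model = 6, 8, 16, 32, 128  # hypothetical sizes
# One memory tensor and one compressed-memory tensor per layer
mem = [torch.zeros(mem_len, batch_size, d_model) for _ in range(n_layers)]
c_mem = [torch.zeros(c_mem_len, batch_size, d_model) for _ in range(n_layers)]
memory = CompressedMemory(mem=mem, c_mem=c_mem)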

Autoregressive model

class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)

Masks

        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: CompressedMemory):

Get memory and compressed memory

        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []

Total length of the memory and compressed memory (for masks)

        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])

Create a subsequent mask for tokens

        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)

Create an all-ones (full visibility) mask for memory

        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

Concatenate the masks if there is memory

        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

Use only the subsequent mask otherwise

        else:
            mask = self.mask_x[:len(x), :len(x)]

Token embeddings

        x = self.src_embed(x)

Run it through the transformer

        res, mem = self.transformer(x, mem, c_mem, mask)

Generate logits of the next token

        res = self.generator(res)

        return res, mem
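
To make the masking above concrete, here is a standalone sketch (with hypothetical sizes) of how the full-visibility memory mask and the subsequent mask combine. It mimics the shapes used in forward and is an illustration, not the library code:

import torch

seq_len, m_len = 4, 6  # hypothetical sizes

# Causal (subsequent) mask: token i may attend to tokens j <= i
mask_x = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)
# Memory positions are fully visible to every token
mask_mem = torch.ones(seq_len, m_len, 1)
# Memory positions come first, followed by the current tokens
mask = torch.cat((mask_mem, mask_x), dim=1)  # shape [seq_len, m_len + seq_len, 1]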

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4

Dropout probability

    dropout: float = 0.0

Number of features in FFN hidden layer

    d_ff: int = 256

Number of transformer layers

    n_layers: int = 6

Number of memories to keep

    mem_len: int = 8

State module to maintain memories when switching between training and validation

    memory = SimpleStateModule()

Attention Reconstruction Loss

    attention_reconstruction_loss: AttentionReconstructionLoss

Compression rate

    compression_rate: int = 4

Compressed memory length

    c_mem_len: int = 128

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Do not print the attention reconstruction loss in the terminal

        tracker.set_scalar("ar_loss.*", False)

Add a hook to log module outputs

        hook_model_outputs(self.mode, self.model, 'model')

This will keep the accuracy metric stats and memories separate for training and validation.

        self.state_modules = [self.accuracy, self.memory]

Concatenate new memories and compress the oldest memories.

    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:

If the configurations specify not to use memory

        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []

Get memory and compressed memory

        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []

Concatenate new memories with old memory

        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem

Compress the oldest memories if there are more memories than mem_len

        if len(mem[0]) > self.mem_len:

Calculate the number of compressed memories to make, $n_{cm} = \left\lceil \frac{n - N_m}{c} \right\rceil$, where $n$ is the number of memories we have, $N_m$ is the maximum number of memories we maintain (mem_len), and $c$ is the compression rate.

            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate

Number of memories to compress

            n_old = n_c_mem * self.compression_rate

A list to keep memories that need to be compressed for each layer.

            mem_to_compress = []

A list to keep the memories that do not get compressed for each layer.

            uncompressed_mem = []

Iterate through memories of each layer.

            for m in mem:

Split the memories at $n_{cm} c$; the first (oldest) $n_{cm} c$ memories will be compressed

                cm, m = torch.split(m, [n_old, len(m) - n_old])

Collect memories to compress

                mem_to_compress.append(cm)

Collect remaining memories

                uncompressed_mem.append(m)

Update the memories

            mem = uncompressed_mem

Compress the memories

            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(mem_to_compress[i]))

Concatenate newly compressed memories with old compressed memories

            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]

If there are no old compressed memories

            else:
                c_mem = new_c_mem

Truncate old compressed memories if there are more than c_mem_len

            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]

No memories are compressed if the number of memories is not greater than mem_len

        else:
            mem_to_compress = []

Return the memories and the memories that were compressed; the latter are needed to compute the attention reconstruction loss.

        return CompressedMemory(mem, c_mem), mem_to_compress
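
As a small worked example of the bookkeeping above, assuming the default mem_len = 8 and compression_rate = 4 and a hypothetical count of accumulated memories:

mem_len, compression_rate = 8, 4  # defaults from the configurations above
n_mem = 13                        # hypothetical number of memories after concatenation

# Ceiling division: how many compressed memories to create
n_c_mem = (n_mem - mem_len + compression_rate - 1) // compression_rate  # = 2
# How many of the oldest memories get compressed
n_old = n_c_mem * compression_rate  # = 8
# Memories left uncompressed
remaining = n_mem - n_old  # = 5

So 8 of the oldest memories are compressed into 2 compressed memories, and 5 recent memories remain uncompressed.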

Training/validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Whether to capture model outputs

        with self.mode.update(is_log_activations=batch_idx.is_last):

Get memories

            mem = self.memory.get()

Run the model

            output, new_mem = self.model(data, mem)

Merge and compress memory

            mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)

Update memories

            self.memory.set(mem)

Calculate and log cross entropy loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate attention reconstruction loss if memories were compressed in this step

        if mem_to_compress:

Get attention reconstruction loss

            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)

Track attention reconstruction loss

            tracker.add("ar_loss.", ar_loss)

Add attention reconstruction loss to loss

            loss = loss + ar_loss

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Start with an empty memory

        mem = CompressedMemory([], [])

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)

Move to device

            data = data.to(self.device)

Get the model output

            output, new_mem = self.model(data, mem)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze(1)

Add the prediction to the prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Only feed the last character to the model in the next iteration; the rest will go in as memories

            prompt = prompt[-1:]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Update and compress memory

            mem, _ = self.merge_compress_memory(mem, new_mem)

Print the sampled output

        logger.log(log)

Initialize the autoregressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)
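
For intuition about the compression function, here is a minimal sketch in the spirit of Conv1dCompression, assuming a 1D convolution whose kernel size and stride both equal the compression rate, so that every compression_rate memories are reduced to one. This is an illustration, not the library implementation:

import torch
import torch.nn as nn

class SimpleConv1dCompression(nn.Module):
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # mem has shape [seq_len, batch_size, d_model]
        mem = mem.permute(1, 2, 0)     # -> [batch_size, d_model, seq_len]
        c_mem = self.conv(mem)         # -> [batch_size, d_model, seq_len // compression_rate]
        return c_mem.permute(2, 0, 1)  # -> [seq_len // compression_rate, batch_size, d_model]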

Initialize the attention reconstruction loss

@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    return AttentionReconstructionLoss(c.model.transformer.layers)
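
Conceptually (following the Compressive Transformer paper), the attention reconstruction loss measures how well attending over the compressed memories reproduces attending over the original memories; in the paper the attention parameters are frozen for this loss so that only the compression network is trained by it. A rough sketch of the idea with a hypothetical generic attention function, not the library implementation:

import torch
import torch.nn.functional as F

def attention_reconstruction_loss_sketch(attn, compress, h, old_mem):
    # attn(query, key, value): a generic attention function
    # compress: the compression network
    # h: the layer's hidden states; old_mem: the memories that were just compressed
    with torch.no_grad():
        target = attn(h, old_mem, old_mem)  # attention over the original memories (no gradient)
    # Attention over the compressed memories should reconstruct the target
    reconstruction = attn(h, compress(old_mem), compress(old_mem))
    return F.mse_loss(reconstruction, target)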

Run the experiment

def main():

Create experiment

    experiment.create(name="compressive_transformer", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training and validation (TrainValidConfigs.run)

        conf.run()

if __name__ == '__main__':
    main()