Compressive Transformer Experiment

This is an annotated PyTorch experiment to train a compressive transformer model.

from typing import List, Tuple, NamedTuple

import torch
import torch.nn as nn
from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.helpers.metrics import SimpleStateModule
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
    CompressiveTransformerLayer, Conv1dCompression


class CompressedMemory(NamedTuple):
    mem: List[torch.Tensor]
    c_mem: List[torch.Tensor]
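mem and c_mem each hold one tensor per transformer layer. As a purely illustrative example (the [length, batch_size, d_model] layout is inferred from how these tensors are concatenated and sliced below, not something this class enforces):

# Purely illustrative: memories for a 6-layer model with d_model = 128,
# a batch size of 32, and 8 uncompressed / 4 compressed memories per layer.
example = CompressedMemory(
    mem=[torch.zeros(8, 32, 128) for _ in range(6)],
    c_mem=[torch.zeros(4, 32, 128) for _ in range(6)],
)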

Autoregressive model

class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)

Masks

        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: CompressedMemory):

Get memory and compressed memory

        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem = []
            c_mem = []

Total length of the memory and compressed memory (for masks)

        m_len = len(mem[0]) if mem else 0
        if c_mem:
            m_len += len(c_mem[0])

Create a subsequent mask for tokens

        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)

Create an all ones (full visibility) mask for memory

        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

Concatenate the masks if there is memory

        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)

Use only the subsequent mask otherwise

        else:
            mask = self.mask_x[:len(x), :len(x)]

Token embeddings

        x = self.src_embed(x)

Run it through the transformer

        res, mem = self.transformer(x, mem, c_mem, mask)

Generate logits of the next token

        res = self.generator(res)

        return res, mem
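To make the mask bookkeeping above concrete, here is a small self-contained sketch of the layout (the sizes are arbitrary, and it uses plain torch.tril rather than the subsequent_mask helper, assuming the same [query, key, 1] mask shape used in forward):

# Hypothetical illustration: 4 new tokens attending over 6 memories plus themselves.
q_len, m_len = 4, 6
mask_x = torch.tril(torch.ones(q_len, q_len)).bool().unsqueeze(-1)  # [4, 4, 1] causal mask
mask_mem = mask_x.new_ones(q_len, m_len, 1)                         # [4, 6, 1] full visibility
mask = torch.cat((mask_mem, mask_x), dim=1)                         # [4, 10, 1] combined mask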

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4

Dropout probability

    dropout: float = 0.0

Number of features in FFN hidden layer

    d_ff: int = 256

Number of transformer layers

    n_layers: int = 6

Number of memories to keep

    mem_len: int = 8

State module to maintain memories when switching between training and validation

    memory = SimpleStateModule()

Attention Reconstruction Loss

    attention_reconstruction_loss: AttentionReconstructionLoss

Compression rate

    compression_rate: int = 4

Compressed memory length

    c_mem_len: int = 128

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Do not print the attention reconstruction loss in the terminal

        tracker.set_scalar("ar_loss.*", False)

This will keep the accuracy metric stats and memories separate for training and validation.

        self.state_modules = [self.accuracy, self.memory]

Concatenate new memories and compress the oldest memories.

    @torch.no_grad()
    def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
            -> Tuple[CompressedMemory, List[torch.Tensor]]:

If the configurations specify not to use memory

        if self.mem_len == 0 and self.c_mem_len == 0:
            return CompressedMemory([], []), []

Get memory and compressed memory

        if mem is not None:
            mem, c_mem = mem.mem, mem.c_mem
        else:
            mem, c_mem = [], []

Concatenate new memories with old memory

        if mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
        else:
            mem = new_mem

Compress the oldest memories if there are more memories than mem_len

        if len(mem[0]) > self.mem_len:

Calculate the number of compressed memories to make, $n_{cm} = \left\lceil \frac{n_M - N_M}{c} \right\rceil$, where $n_M$ is the number of memories we have, $N_M$ is the maximum number of memories we maintain (mem_len), and $c$ is the compression rate.

            n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate

Number of memories to compress

            n_old = n_c_mem * self.compression_rate

A list to keep memories that need to be compressed for each layer.

            mem_to_compress = []

A list to keep the memories that do not get compressed for each layer.

            uncompressed_mem = []

Iterate through memories of each layer.

            for m in mem:

Split the memories at n_old ($n_{cm} \times c$): the first n_old memories will be compressed and the rest kept as they are.

                cm, m = torch.split(m, [n_old, len(m) - n_old])

Collect memories to compress

                mem_to_compress.append(cm)

Collect remaining memories

                uncompressed_mem.append(m)

Update the memories

            mem = uncompressed_mem

Compress the memories

            new_c_mem = []
            for i, layer in enumerate(self.model.transformer.layers):
                new_c_mem.append(layer.compress(mem_to_compress[i]))

Concatenate newly compressed memories with old compressed memories

            if c_mem:
                c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]

If there are no old compressed memories

            else:
                c_mem = new_c_mem

Truncate old memories

            if len(c_mem[0]) > self.c_mem_len:
                c_mem = [m[-self.c_mem_len:] for m in c_mem]

No memories are compressed if the number of memories does not exceed mem_len

        else:
            mem_to_compress = []

Return the new memories and the memories that were compressed; the latter are needed to compute the attention reconstruction loss.

        return CompressedMemory(mem, c_mem), mem_to_compress
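As a sanity check of this bookkeeping, here is a hypothetical dry run that tracks only the lengths, assuming seq_len = 8, mem_len = 8 and compression_rate = 4 (the class defaults above):

# Hypothetical dry run of the memory lengths only.
seq_len, mem_len, rate = 8, 8, 4
n_mem, n_c_mem_total = 0, 0
for step in range(3):
    n_mem += seq_len                                     # concatenate the new memories
    if n_mem > mem_len:
        n_c_mem = (n_mem - mem_len + rate - 1) // rate   # ceil((n_mem - mem_len) / rate)
        n_old = n_c_mem * rate                           # number of memories to compress
        n_mem -= n_old                                   # keep the uncompressed tail
        n_c_mem_total += n_c_mem                         # compressed memories grow by n_c_mem
    print(step, n_mem, n_c_mem_total)                    # prints 0 8 0, then 1 8 2, then 2 8 4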

Training/validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get memories

        mem = self.memory.get()

Run the model

        output, new_mem = self.model(data, mem)

Merge and compress memory

        mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)

Update memories

        self.memory.set(mem)

Calculate and log cross entropy loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate attention reconstruction loss if memories were compressed in this step

        if mem_to_compress:

Get attention reconstruction loss

            ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)

Track attention reconstruction loss

            tracker.add("ar_loss.", ar_loss)

Add attention reconstruction loss to loss

            loss = loss + ar_loss

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = self.prompt

Collect output for printing

        log = [(prompt, Text.subtle)]

Start with an empty memory

        mem = CompressedMemory([], [])

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.text.text_to_i(prompt).unsqueeze(-1)

Move to device

            data = data.to(self.device)

Get the model output

            output, new_mem = self.model(data, mem)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze(1)

Add the prediction to prompt

            prompt += self.prompt_separator + self.text.itos[output[-1]]

Only feed the last character to the model in the next iteration; the rest will go in as memories

            prompt = prompt[-1:]

Add the prediction for logging

            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

Update and compress memory

            mem, _ = self.merge_compress_memory(mem, new_mem)

Print the sampled output

        logger.log(log)

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
        CompressiveTransformerLayer(d_model=c.d_model,
                                    self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                                    feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                    dropout_prob=c.dropout,
                                    compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
    return m.to(c.device)
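Conv1dCompression reduces every compression_rate consecutive memories to a single vector. As a rough sketch of what such a convolutional compression can look like (this is an assumption-based illustration following the compressive transformer paper's 1D-convolution compression function, not the labml_nn implementation itself):

# Hypothetical sketch: compress memories of shape [mem_len, batch, d_model]
# with a 1D convolution whose kernel size and stride equal the compression rate.
class SimpleConv1dCompression(nn.Module):
    def __init__(self, compression_rate: int, d_model: int):
        super().__init__()
        self.conv = nn.Conv1d(d_model, d_model,
                              kernel_size=compression_rate, stride=compression_rate)

    def forward(self, mem: torch.Tensor) -> torch.Tensor:
        # [mem_len, batch, d_model] -> [batch, d_model, mem_len] for Conv1d
        mem = mem.permute(1, 2, 0)
        c_mem = self.conv(mem)
        # back to [mem_len / compression_rate, batch, d_model]
        return c_mem.permute(2, 0, 1)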

Initialize the attention reconstruction loss

@option(Configs.attention_reconstruction_loss)
def attention_reconstruction_loss(c: Configs):
    return AttentionReconstructionLoss(c.model.transformer.layers)
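Roughly, the attention reconstruction loss (from the compressive transformer paper) measures how much attention information is lost by compression: for each layer it attends over the old memories and over their compressed versions, using that layer's own attention parameters with gradients stopped to the main network, and takes the squared error between the two attention outputs,

$\mathcal{L}_{ar} = \sum_i \big\lVert \mathrm{attn}_i(h_i, m_i) - \mathrm{attn}_i\big(h_i, f_c^{(i)}(m_i)\big) \big\rVert_2^2$

where $m_i$ are the memories being compressed at layer $i$ and $f_c^{(i)}$ is that layer's compression function. This is a paraphrase of the paper's definition; the exact details are in labml_nn.transformers.compressive.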

Run the experiment

def main():

Create experiment

    experiment.create(name="compressive_transformer", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'AdamW',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 8,
                        'mem_len': 8,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        'compression_rate': 2,
                        })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

TrainValidConfigs.run

        conf.run()

if __name__ == '__main__':
    main()