Train Fast Weights Transformer

This trains a fast weights transformer model for auto-regression.

Here’s a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Autoregressive model

class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):

Embed the tokens

        x = self.src_embed(x)

Run it through the transformer

        res = self.transformer(x)

Generate logits of the next token. The second value returned is the state, which is not used here.

        return self.generator(res), None
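For a quick sanity check of the shapes, here is a minimal, hypothetical sketch that exercises the model above with random token IDs. The transformer argument is stubbed with nn.Identity(), so it only verifies the embedding and generator plumbing, not the fast weights attention itself.

import torch
from torch import nn

# Stub the transformer so only the embedding -> generator path is exercised
model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
# Random token IDs; the (seq_len, batch_size) layout assumed here matches how
# the NLP auto-regression experiment feeds batches, but any 2D shape works
tokens = torch.randint(0, 65, (128, 16))
logits, _ = model(tokens)
assert logits.shape == (128, 16, 65)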

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    d_model: int = 512
    nu: int = 1
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6
Create the fast weights transformer model

@option(Configs.model)
def fast_weights_transformer(c: Configs):
    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward

    from labml_nn.transformers.fast_weights import DPFP
    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightsAttentionTransformer(
            FastWeightsAttentionTransformerLayer(d_model=c.d_model,
                                                 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                                                 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                 dropout_prob=c.dropout),
            c.n_layers)).to(c.device)
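The DPFP(nu=c.nu) argument is the deterministic parameter-free projection that maps keys and queries into a positive feature space before the fast weight updates. Here is a rough, hedged sketch of that feature map; the real implementation is the DPFP class in labml_nn.transformers.fast_weights, and the final normalization step shown below is an assumption based on it.

import torch

def dpfp_sketch(k: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    # Concatenate ReLU of the input and its negation: d features become 2d
    x = torch.relu(torch.cat([k, -k], dim=-1))
    # Multiply element-wise with copies rolled by 1..nu positions: 2d becomes 2*nu*d
    rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, nu + 1)], dim=-1)
    repeated = torch.cat([x] * nu, dim=-1)
    phi = repeated * rolled
    # Normalize so the features sum to one
    return phi / (phi.sum(dim=-1, keepdim=True) + eps)

With the default nu = 1, each d_k-dimensional head key becomes a 2 * d_k-dimensional non-negative feature vector.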
def main():

Create experiment

    experiment.create(name="fast_weights_transformer")

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})
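The learning_rate of 1.0 looks large, but with the Noam optimizer it acts as a scale factor on the warm-up schedule from the original Transformer paper. A hypothetical sketch of that schedule follows; the warm-up step count shown is an assumption, not necessarily the default used by labml's Noam configuration.

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # Rate from "Attention Is All You Need": linear warm-up, then 1/sqrt(step) decay
    return factor * d_model ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

The rate ramps up roughly linearly for the first warmup steps, peaks near step == warmup, and then decays proportionally to 1/sqrt(step).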

Set models for saving and loading

    experiment.add_pytorch_models(get_modules(conf))

Start the experiment

    with experiment.start():

Run the training loop

        conf.run()


if __name__ == '__main__':
    main()