Train Fast Weights Transformer

This trains a fast weights transformer model for auto-regressive language modeling.
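As background, fast weights attention (Schlag et al., 2021) replaces softmax attention with a recurrent outer-product update of a per-head fast weight matrix. Below is a minimal sketch of one step of the delta-rule update, assuming keys and queries have already been passed through a feature map such as DPFP; the function name and shapes are illustrative, not the labml_nn API.

import torch

def fast_weights_step(W: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      q: torch.Tensor, beta: torch.Tensor):
    # W: (d_v, d_k) fast weight matrix carried across time steps
    # k, q: (d_k,) feature-mapped key and query; v: (d_v,) value; beta: scalar gate
    v_old = W @ k                              # value currently stored under this key
    W = W + beta * torch.outer(v - v_old, k)   # delta rule: overwrite toward the new value
    return W, W @ q                            # updated weights and the attention output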

Here’s a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Autoregressive model

class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer

Final linear layer to generate logits

        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):

Embed the tokens

        x = self.src_embed(x)

Run it through the transformer

        res = self.transformer(x)

Generate logits of the next token; the second element of the returned tuple is recurrent state, which this model doesn't use

        return self.generator(res), None
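A quick way to sanity-check the wrapper is to substitute nn.Identity() for the transformer. This is a hypothetical smoke test, assuming the sequence-first [seq_len, batch_size] input layout used by other labml_nn NLP experiments, with a vocabulary of 65 characters as an illustrative size.

model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
tokens = torch.randint(0, 65, (128, 16))  # (seq_len, batch_size)
logits, state = model(tokens)
assert logits.shape == (128, 16, 65)      # next-token logits at every position
assert state is None                      # this model keeps no recurrent state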

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    d_model: int = 512
    nu: int = 1  # expansion factor ν for the DPFP feature map
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6
Create the fast weights transformer model

@option(Configs.model)
def fast_weights_transformer(c: Configs):
    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward, DPFP

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightsAttentionTransformer(
            FastWeightsAttentionTransformerLayer(d_model=c.d_model,
                                                 attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                                                 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                 dropout_prob=c.dropout),
            c.n_layers)).to(c.device)
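DPFP(nu=c.nu) is the deterministic parameter-free projection feature map applied to keys and queries; nu controls how far the feature dimension is expanded (to 2 · d · ν). A rough standalone sketch of the idea, not the labml_nn implementation:

def dpfp(x: torch.Tensor, nu: int = 1) -> torch.Tensor:
    # x: (..., d) -> non-negative features of shape (..., 2 * d * nu)
    r = torch.cat([torch.relu(x), torch.relu(-x)], dim=-1)
    # Multiply with rolled copies so every feature is a product of two components
    return torch.cat([r * torch.roll(r, shifts=i, dims=-1) for i in range(1, nu + 1)], dim=-1)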
def main():

Create experiment

    experiment.create(name="fast_weights_transformer")

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})
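The learning rate of 1.0 is only a base factor here; the 'Noam' option applies the warmup-then-decay schedule from Attention Is All You Need, so the effective rate per step is roughly the following sketch (the warmup of 5000 steps is an assumed default, not a value read from this config):

def noam_lr(step: int, d_model: int = 512, warmup: int = 5000, factor: float = 1.0) -> float:
    # step counts from 1: linear warmup, then inverse-square-root decay
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)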

Set models for saving and loading

    experiment.add_pytorch_models(get_modules(conf))

Start the experiment

    with experiment.start():

Run the training loop

        conf.run()


if __name__ == '__main__':
    main()