This experiment trains a fast weights transformer model on an auto-regressive language modeling task.
Here's a Colab notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Fast weights transformer and the final layer that projects to vocabulary logits
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
Embed the tokens
        x = self.src_embed(x)
Run it through the transformer
        res = self.transformer(x)
Generate logits of the next token. The second return value (None) is a placeholder for recurrent state, which this model doesn't use.
        return self.generator(res), None
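As a quick sanity check, here's a hypothetical standalone usage of the wrapper above. nn.Identity() stands in for the real fast weights transformer purely to verify tensor shapes; the vocabulary size of 65 and the [seq_len, batch_size] layout are illustrative assumptions, not values taken from this experiment.

import torch
from torch import nn

# Illustrative only: nn.Identity() replaces the real fast weights transformer
model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
tokens = torch.randint(0, 65, (128, 16))  # [seq_len, batch_size] token ids
logits, state = model(tokens)
assert logits.shape == (128, 16, 65)  # per-position logits over the vocabulary
assert state is None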

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    d_model: int = 512
    # nu of the DPFP feature map (sketched below)
    nu: int = 1
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6
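nu is the ν parameter of DPFP, the deterministic parameter-free projection from Schlag et al. (2021) that the fast weights attention uses as its feature map; it expands a d-dimensional vector into 2νd non-negative features. Below is a minimal functional sketch assuming the standard DPFP definition; the labml_nn DPFP module additionally normalizes the features, so treat this as illustrative.

import torch
import torch.nn.functional as F

def dpfp(x: torch.Tensor, nu: int = 1) -> torch.Tensor:
    # r(x) = [relu(x); relu(-x)] doubles the dimension and keeps features non-negative
    r = F.relu(torch.cat([x, -x], dim=-1))
    # Element-wise products of r(x) with nu rolled copies of itself,
    # concatenated into a 2 * nu * d dimensional feature vector
    rolled = torch.cat([r.roll(shifts=i, dims=-1) for i in range(1, nu + 1)], dim=-1)
    return torch.cat([r] * nu, dim=-1) * rolled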
Create fast weights transformer.
@option(Configs.model)
def fast_weights_transformer(c: Configs):
    from labml_nn.transformers.fast_weights import FastWeightsAttentionTransformer, \
        FastWeightsAttentionTransformerLayer, FastWeightsAttention, FeedForward, DPFP

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightsAttentionTransformer(
            FastWeightsAttentionTransformerLayer(
                d_model=c.d_model,
                attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                dropout_prob=c.dropout),
            c.n_layers)).to(c.device)
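The @option(Configs.model) decorator registers this function with labml's configuration system as the default way to compute Configs.model: it runs lazily, with c bound to the calculated configuration values, the first time the experiment needs the model.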
def main():
Create experiment
    experiment.create(name="fast_weights_transformer")
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})
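The 'optimizer.optimizer': 'Noam' override selects the learning-rate schedule from the original Transformer paper, with 'optimizer.learning_rate': 1.0 acting as a multiplicative factor on it. Here's a sketch of that schedule; the warmup of 4,000 steps is an assumed default, not something set explicitly above.

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # Linear warmup followed by inverse square-root decay,
    # scaled by d_model ** -0.5 and the configured factor
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)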
Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))
Start the experiment
    with experiment.start():
Run the training loop
        conf.run()


if __name__ == '__main__':
    main()