Train Feedback Transformer

This trains a feedback transformer model for auto-regressive language modeling. You can pick either the original feedback transformer or the updated version, where the keys and values are precalculated.
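The key idea is that every layer attends to a single shared memory, where each past step contributes one vector: a learned softmax-weighted sum of all layer outputs (including the embedding) at that step. Here is a minimal sketch of that memory update, simplified for illustration and not the labml_nn implementation:

import torch

# One decoding step with n_layers layers; layer_outputs[0] is the token
# embedding and layer_outputs[i] is the output of layer i at this step.
n_layers, d_model = 4, 16
layer_outputs = [torch.randn(d_model) for _ in range(n_layers + 1)]

# Softmax-normalized mixing weights (learned parameters in the real model)
weights = torch.softmax(torch.randn(n_layers + 1), dim=0)

# The memory vector for this step; it is appended to the memory that
# every layer attends to at all future steps.
memory_vector = sum(w * h for w, h in zip(weights, layer_outputs))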

Here's a Colab notebook for training a feedback transformer on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_helpers.module import Module

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
from labml_nn.transformers.utils import subsequent_mask

Autoregressive model

class AutoregressiveModel(Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):

Embed the tokens

        x = self.src_embed(x)

Run it through the transformer

        res = self.transformer(x)

Generate logits of the next token

        return self.generator(res), None
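forward returns a (logits, state) tuple because the NLPAutoRegressionConfigs training loop unpacks both; this model carries no recurrent state between calls, so the second element is always None. A quick shape check, using nn.Identity as a hypothetical stand-in for the transformer stack:

import torch
from torch import nn

model = AutoregressiveModel(n_vocab=65, d_model=512, transformer=nn.Identity())
x = torch.randint(0, 65, (128, 4))   # [seq_len, batch_size] token ids
logits, state = model(x)
assert logits.shape == (128, 4, 65)  # per-position logits over the vocabulary
assert state is None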

Configurations

The default configurations can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel

    d_model: int = 512
    heads: int = 8
    dropout: float = 0.0
    d_ff: int = 2048
    n_layers: int = 6
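These defaults match the base model from Attention Is All You Need (a 512-dimensional model, 8 heads, a 2048-unit feed-forward layer, 6 layers). Any of them can be overridden through the dictionary passed to experiment.configs, as main below does; for instance, a smaller model (hypothetical values):

conf = Configs()
experiment.configs(conf, {'d_model': 256, 'n_layers': 4, 'heads': 4})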
Create the original feedback transformer.

@option(Configs.model)
def feedback_transformer(c: Configs):
    from labml_nn.transformers.feedback import FeedbackTransformer, FeedbackTransformerLayer, \
        FeedbackAttention, FeedForward

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FeedbackTransformer(
            FeedbackTransformerLayer(d_model=c.d_model,
                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout),
                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                     dropout_prob=c.dropout),
            c.n_layers)).to(c.device)

Create the updated feedback transformer, with precalculated keys and values.

@option(Configs.model)
def feedback_transformer_kv(c: Configs):
    from labml_nn.transformers.feedback import FeedbackTransformerKV, FeedbackTransformerLayer, \
        FeedbackAttention, FeedForward

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FeedbackTransformerKV(
            FeedbackTransformerLayer(d_model=c.d_model,
                                     attn=FeedbackAttention(c.heads, c.d_model, c.dropout,
                                                            is_kv_precomputed=True),
                                     feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                                     dropout_prob=c.dropout),
            c.n_layers, c.d_model, c.heads)).to(c.device)
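The difference in this version is that the key and value projections are shared across layers and applied to the memory once per step, so every layer reuses the same precomputed keys and values instead of recomputing them layer by layer. A rough sketch of the idea, with illustrative names rather than the labml_nn API:

import torch
from torch import nn

d_model, steps, batch = 512, 10, 4
key_proj = nn.Linear(d_model, d_model)    # shared by all layers
value_proj = nn.Linear(d_model, d_model)  # shared by all layers

# Memory vectors accumulated from past steps
memory = torch.randn(steps, batch, d_model)

# Computed once per step and reused by every layer's attention
keys, values = key_proj(memory), value_proj(memory)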
def main():

Create experiment

    experiment.create(name="feedback_transformer")

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

Use 'feedback_transformer' for the original feedback transformer.

                        'model': 'feedback_transformer_kv',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 64,
                        'inner_iterations': 25})
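'Noam' is the warmup-then-inverse-square-root schedule from Attention Is All You Need; 'optimizer.learning_rate' acts as a multiplying factor on the schedule rather than an absolute rate, which is why 1.0 is a sensible value here. A sketch of the schedule, assuming the conventional 4,000-step warmup:

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # Linear warmup followed by inverse square-root decay; step starts at 1
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)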

Set models for saving and loading

    experiment.add_pytorch_models(get_modules(conf))

Start the experiment

    with experiment.start():

Run the training loop

        conf.run()


if __name__ == '__main__':
    main()