This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network. The reusable & configurable parts are defined in configs.py.
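For orientation, these variants gate the hidden layer of the feedforward network as in "GLU Variants Improve Transformer" (Shazeer, 2020). Below is a minimal standalone sketch of that idea; the layer and parameter names are illustrative and this is not the configurable implementation used by the experiment.

import torch
from torch import nn

class GLUFeedForward(nn.Module):
    # FFN(x) = (g(x W1) * (x V)) W2, where the gate activation g picks the variant:
    # sigmoid -> GLU, identity -> Bilinear, ReLU -> ReGLU, GELU -> GEGLU, SiLU -> SwiGLU
    def __init__(self, d_model: int, d_ff: int, activation: nn.Module = nn.SiLU()):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.v = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        return self.w2(self.activation(self.w1(x)) * self.v(x))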
import torch
from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator, TransformerConfigs
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
class AutoregressiveModel(nn.Module):
    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: Generator):
        super().__init__()
Token embedding module
        self.src_embed = src_embed
Transformer-based encoder
        self.encoder = encoder
Next-token generation layer; this gives the logits of the next token
        self.generator = generator
This will be initialized on the first call
        self.src_mask = None
    def forward(self, src: torch.Tensor):
Create the subsequent mask, so that the transformer can only pay attention to past tokens (a standalone sketch of such a mask follows the class).
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
Embed the tokens (src) and run them through the transformer
        res = self.encoder(self.src_embed(src), self.src_mask)
Generate logits of the next token
        return self.generator(res), None
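For intuition, the subsequent (causal) mask is lower-triangular: position i may only attend to positions up to i. Here is a standalone illustration of an equivalent mask; the exact shape conventions of labml's subsequent_mask may differ.

def causal_mask(seq_len: int) -> torch.Tensor:
    # True where attention is allowed: query position i sees key positions <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# causal_mask(4) gives:
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]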
class Configs(NLPAutoRegressionConfigs):
    transformer: TransformerConfigs
    model: AutoregressiveModel
Initialize the auto-regressive model
@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.transformer.src_embed, c.transformer.encoder, c.transformer.generator)
    return m.to(c.device)
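The @option decorator registers this function as the calculator for Configs.model: labml's config system calls it lazily, the first time conf.model is needed. A rough conceptual sketch of that mechanism is below; the names are hypothetical and this is not labml's actual implementation.

class LazyConfigs:
    # option name -> function that builds it
    _calculators = {}

    def __init__(self):
        # cache of computed option values
        self._computed = {}

    @classmethod
    def option(cls, name: str):
        def register(func):
            cls._calculators[name] = func
            return func
        return register

    def __getattr__(self, name: str):
        # only called when normal attribute lookup fails
        if name not in self._computed:
            self._computed[name] = self._calculators[name](self)
        return self._computed[name]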
Initialize the configurable transformer encoder for our autoregressive model.
@option(Configs.transformer)
def transformer_c(c: Configs):
    tc = TransformerConfigs()
    tc.n_src_vocab = c.n_tokens
    tc.n_tgt_vocab = c.n_tokens

    return tc
def main():
Create experiment
    experiment.create(name="glu_variants")
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'prompt_separator': '',
                        'prompt': 'It is ',
                        'text': 'tiny_shakespeare',

                        'optimizer.optimizer': 'Noam',
                        'optimizer.learning_rate': 1.,
                        'optimizer.d_model': 256,

                        'seq_len': 1024,
                        'epochs': 128,
                        'batch_size': 6,
                        'inner_iterations': 10,
GLU Variant, one of GLU, Bilinear, ReGLU, GEGLU, SwiGLU
These are defined in the configurable FFN implementation
                        'transformer.ffn.glu_variant': 'Bilinear',
Transformer configurations
                        'transformer.d_model': 256,
                        'transformer.ffn.d_ff': 1024,
                        'transformer.n_heads': 8,
                        'transformer.n_layers': 6})
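For reference, the 'Noam' optimizer option follows the learning-rate schedule from "Attention Is All You Need": linear warmup followed by inverse-square-root decay, scaled by d_model; the learning_rate of 1. above presumably acts as a constant factor on it. A sketch of the schedule itself (the warmup value here is illustrative):

def noam_lr(step: int, d_model: int = 256, warmup: int = 4000, factor: float = 1.) -> float:
    # factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)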
This is needed to initialize models
    conf.n_tokens = conf.text.n_tokens
Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))
Start the experiment
    with experiment.start():
Run training (TrainValidConfigs.run)
        conf.run()


if __name__ == '__main__':
    main()