This is an annotated PyTorch experiment to train an AFT (Attention Free Transformer) model.
It is based on the general training loop and configurations for auto-regressive NLP tasks.
import torch
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
from torch import nn

The auto-regressive model consists of a token embedding layer, a transformer encoder, and a final linear layer that gives token logits.
class AutoregressiveTransformer(nn.Module):
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
        # `encoder` is the transformer Encoder
        # `src_embed` is the token embedding module (with positional encodings)
        # `generator` is the final fully connected layer that gives the logits
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator

        # The mask will be initialized on the first call
        self.mask = None

    def forward(self, x: torch.Tensor):
        # Create the subsequent mask if it is not initialized,
        # or if the size of the mask is different
        if self.mask is None or self.mask.size(0) != len(x):
            # Subsequent mask, will mask out tokens from seeing future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)

        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder
        x = self.encoder(x, self.mask)
        # Get logits
        x = self.generator(x)

        # Return results
        # (second value is for state, since our trainer is used with RNNs also)
        return x, None
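As an aside, the subsequent mask is simply a causal, lower-triangular mask: position t can attend to positions up to and including t, but not to later ones. The snippet below is a minimal illustration built with plain torch.tril rather than the subsequent_mask helper (whose exact output shape may differ):

import torch

seq_len = 5
# mask[i, j] is True when token i is allowed to attend to token j (i.e. j <= i)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal_mask[2])  # tensor([ True,  True,  True, False, False])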
class Configs(NLPAutoRegressionConfigs):
    # GPT model
    model: AutoregressiveTransformer
    # Transformer
    transformer: TransformerConfigs

    local_window_size: int = 32
@option(Configs.transformer, 'Transformer')
def _transformer_configs(c: Configs):
    # We use our configurable transformer implementation
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # Set the embedding size
    conf.d_model = c.d_model
    # Replace self-attention with an AFT Local module
    from labml_nn.transformers.aft import AFTLocal
    conf.encoder_attn = AFTLocal(c.d_model, c.seq_len, c.local_window_size)

    return conf
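For reference, AFT replaces dot-product attention with a gated, position-biased weighted average over the values, and AFT-Local keeps the learned pairwise position biases only inside a window of local_window_size positions. The sketch below illustrates that computation on a single sequence; the function name aft_local_sketch and the bias tensor w are ours, and it is a simplified illustration rather than the library's AFTLocal module, which additionally handles auto-regressive masking and numerical stability:

import torch

def aft_local_sketch(q, k, v, w, window: int):
    # q, k, v: [seq_len, d_model]; w: [seq_len, seq_len] learned pairwise position biases
    seq_len = q.shape[0]
    idx = torch.arange(seq_len)
    # Zero the biases outside the local window |t - t'| < window
    w = w * ((idx[None, :] - idx[:, None]).abs() < window)
    # weights[t, t', :] = exp(K_{t'} + w_{t, t'})
    weights = torch.exp(k.unsqueeze(0) + w.unsqueeze(-1))
    # Position-wise weighted average of values, gated element-wise by sigmoid(Q)
    numerator = (weights * v.unsqueeze(0)).sum(dim=1)
    denominator = weights.sum(dim=1)
    return torch.sigmoid(q) * numerator / denominator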
Create an auto-regressive model.

@option(Configs.model)
def _model(c: Configs):
    m = AutoregressiveTransformer(c.transformer.encoder,
                                  c.transformer.src_embed,
                                  c.transformer.generator).to(c.device)

    return m
def main():
    # Create experiment
    experiment.create(name="aft")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of 256
        'seq_len': 256,
        # Train for 128 epochs
        'epochs': 128,
        # Batch size of 32
        'batch_size': 32,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        # Embedding size
        'd_model': 128,
        # FFN hidden dimension size
        'transformer.ffn.d_ff': 256,

        # Optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
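        # (Note: with the Noam scheduler this value acts as a scale factor rather than a
        #  fixed rate; the effective learning rate roughly follows
        #  d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5}),
        #  as in the original Transformer paper.)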
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()
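Assuming this script sits at labml_nn/transformers/aft/experiment.py (inferred from the imports above, not stated here), it can be run directly, e.g. with python -m labml_nn.transformers.aft.experiment, and the run will be tracked by labml under the experiment name "aft".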