This is an annotated PyTorch experiment to train an FNet model.
This is based on the general training loop and configurations for the AG News classification task.
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
from labml_nn.transformers import Encoder
from labml_nn.transformers import TransformerConfigs
class TransformerClassifier(nn.Module):
encoder is the transformer Encoder, src_embed is the token embedding module (with positional encodings), and generator is the final fully connected layer that gives the logits.

    def __init__(self, encoder: Encoder, src_embed: Module, generator: nn.Linear):
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
    def forward(self, x: torch.Tensor):
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, None)
Get logits for classification. We set the [CLS] token at the last position of the sequence. This is extracted by x[-1], where x is of shape [seq_len, batch_size, d_model].
        x = self.generator(x[-1])
Return results (the second value is for state, since our trainer is also used with RNNs)
        return x, None
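As a rough sanity check of the shapes, the classifier can be exercised with stand-in PyTorch modules in place of the labml Encoder and embedding; the sizes below are hypothetical and not the experiment's configuration.

d_model, n_classes, vocab_size = 512, 4, 1000
# Stand-in embedding and encoder that accept [seq_len, batch_size] token ids
# and [seq_len, batch_size, d_model] activations, matching the shapes above
embed = nn.Embedding(vocab_size, d_model)
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead=8), num_layers=2)
classifier = TransformerClassifier(encoder, embed, nn.Linear(d_model, n_classes))

tokens = torch.randint(0, vocab_size, (128, 32))  # [seq_len, batch_size]
logits, _ = classifier(tokens)
print(logits.shape)  # torch.Size([32, 4]), i.e. [batch_size, n_classes]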
class Configs(NLPClassificationConfigs):
Classification model
    model: TransformerClassifier
Transformer
    transformer: TransformerConfigs
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    return conf
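An aside on how these @option functions work: the decorated function is registered as a calculator for the config item, and its name doubles as the option's name, which is how fnet_mix is selected in the configurations below. As a purely hypothetical illustration (not part of the experiment), an alternative transformer configuration could be registered the same way and chosen with 'transformer': 'small_transformer' in the experiment.configs(...) overrides.

@option(Configs.transformer)
def small_transformer(c: Configs):
    # Hypothetical smaller transformer for quick experiments
    conf = TransformerConfigs()
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    conf.d_model = 256
    conf.n_layers = 2
    return conf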
Create the FNetMix module that can replace the self-attention in the transformer encoder layer.
@option(TransformerConfigs.encoder_attn)
def fnet_mix():
    from labml_nn.transformers.fnet import FNetMix
    return FNetMix()
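At its core, FNet replaces learned self-attention with parameter-free Fourier transforms: an FFT along the hidden dimension, an FFT along the sequence dimension, keeping only the real part. The sketch below shows just this core mixing operation under the [seq_len, batch_size, d_model] shape convention used above; it is not the FNetMix module itself, which wraps the operation behind the same call interface as the self-attention it replaces.

def fourier_mix(x: torch.Tensor) -> torch.Tensor:
    # x has shape [seq_len, batch_size, d_model]
    x = torch.fft.fft(x, dim=2)  # mix along the hidden dimension
    x = torch.fft.fft(x, dim=0)  # mix along the sequence dimension
    # Keep the real component, as in the FNet paper
    return x.real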
Create classification model
@option(Configs.model)
def _model(c: Configs):
    m = TransformerClassifier(c.transformer.encoder,
                              c.transformer.src_embed,
                              nn.Linear(c.d_model, c.n_classes)).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="fnet")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use word level tokenizer
        'tokenizer': 'basic_english',
Train for 32 epochs
        'epochs': 32,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Transformer configurations (same as defaults)
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
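Use the FNetMix module defined above in place of self-attention in the encoder layers
        'transformer.encoder_attn': 'fnet_mix',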
Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
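A note on the learning rate of 1.0: with the Noam optimizer the configured value effectively acts as a multiplicative factor on the usual warmup-then-decay schedule, lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)), rather than as an absolute step size.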
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()