Pay Attention to MLPs (gMLP) Experiment

This is an annotated PyTorch experiment to train a gMLP model. The paper also applies Stochastic Depth regularization, where some layers are dropped at random during training. We have not implemented that here.
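
For reference, stochastic depth wraps a residual branch and skips it entirely with some probability during training. Below is a minimal sketch of such a wrapper; the class name, the wrapper structure, and the inverted-scaling convention are our assumptions, not code from this repository.

import torch
from torch import nn


class StochasticDepthResidual(nn.Module):
    def __init__(self, branch: nn.Module, drop_prob: float):
        super().__init__()
        # Residual branch to be randomly skipped, and the probability of skipping it
        self.branch = branch
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Evaluation: always apply the branch
            return x + self.branch(x)
        if torch.rand(()).item() < self.drop_prob:
            # Training: drop the whole branch for this batch, keeping only the identity path
            return x
        # Keep the branch, rescaled so its expected contribution matches evaluation
        return x + self.branch(x) / (1.0 - self.drop_prob)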

This is based on the training loop and configurations for a simple transformer auto-regressive NLP task.


from labml import experiment
from labml.configs import option
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.basic.autoregressive_experiment import Configs as BasicAutoRegressionConfigs
from labml_nn.transformers.gmlp import GMLPBlock


class Configs(BasicAutoRegressionConfigs):

Transformer

    transformer: TransformerConfigs = 'gMLP'

gMLP Block

    gmlp: GMLPBlock

d_ffn for the gMLP projection layer

    d_ffn: int = 2048

Create a gMLP block

@option(Configs.gmlp, 'gMLP')
def _gmlp_configs(c: Configs):
    return GMLPBlock(c.d_model, c.d_ffn, c.seq_len)
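
The block takes the model dimension, the projection size d_ffn, and the sequence length because its spatial gating unit mixes information across token positions with a learned seq_len-by-seq_len projection. The following is a rough, batch-first sketch of that computation; the class names are hypothetical and the causal masking needed for autoregressive training is omitted, so refer to the annotated implementation in labml_nn.transformers.gmlp for the real block.

import torch
from torch import nn


class SpatialGatingUnitSketch(nn.Module):
    def __init__(self, d_ffn: int, seq_len: int):
        super().__init__()
        # Normalize the gating half and mix it across the sequence dimension
        self.norm = nn.LayerNorm(d_ffn // 2)
        self.proj = nn.Linear(seq_len, seq_len)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: [batch_size, seq_len, d_ffn]; split into a content half and a gating half
        z1, z2 = z.chunk(2, dim=-1)
        z2 = self.norm(z2)
        # Linear mixing across token positions (the only place where tokens interact)
        z2 = self.proj(z2.transpose(1, 2)).transpose(1, 2)
        return z1 * z2


class GMLPBlockSketch(nn.Module):
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.proj_in = nn.Linear(d_model, d_ffn)
        self.act = nn.GELU()
        self.sgu = SpatialGatingUnitSketch(d_ffn, seq_len)
        self.proj_out = nn.Linear(d_ffn // 2, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, d_model]
        z = self.act(self.proj_in(self.norm(x)))
        z = self.sgu(z)
        # Project back to the model dimension and add the residual connection
        return x + self.proj_out(z)

The paper additionally initializes the spatial projection weights near zero and its bias to one, so each block starts out close to an identity mapping.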

Transformer configurations

@option(Configs.transformer, 'gMLP')
def _transformer_configs(c: Configs):
    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

Set model size

    conf.d_model = c.d_model

Replace the encoder layer with a gMLP layer

    conf.encoder_layer = c.gmlp

    return conf
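
Only a single gMLP block is assigned here; the transformer configurations build the encoder by copying this layer for each of the configured encoder layers. A hypothetical helper illustrating that idea (the labml configurations handle this internally, so this is not code from the experiment):

import copy
from torch import nn


def repeat_layer(layer: nn.Module, n_layers: int) -> nn.ModuleList:
    # Make independent deep copies of one configured layer, one per encoder layer
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
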
def main():

Create experiment

    experiment.create(name="gMLP")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use a character-level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of $256$

        'seq_len': 256,

Train for $128$ epochs

        'epochs': 128,

Batch size $32$

        'batch_size': 32,

Switch between training and validation $10$ times per epoch

        'inner_iterations': 10,

Model size

        'd_model': 512,
        'd_ffn': 2048,

Use the Noam optimizer with a learning rate of $1.$ (the schedule is sketched at the end of this page)

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()
if __name__ == '__main__':
    main()
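
As set in the configurations above, the optimizer is Noam with a learning rate of $1.$, which effectively acts as a multiplicative factor on the schedule: the learning rate warms up linearly and then decays with the inverse square root of the step count. A small sketch of that schedule follows; the warmup of $4000$ steps is an assumed default, not a value set by this experiment.

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000, factor: float = 1.0) -> float:
    # factor * d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)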