11from labml import experiment
12from labml.configs import calculate
13from labml_nn.experiments.arithmetic_dataset import ArithmeticAutoregression
14from labml_nn.transformers import TransformerConfigs
15from labml_nn.transformers.rope.experiment import Configs as RoPEConfigs
We inherit the RoPE experiment and use it for the arithmetic addition task.
Below we add the option to switch the attention to Rotary Positional Embeddings with Relative distance (RoPER).
18class Configs(RoPEConfigs, ArithmeticAutoregression):
    """
    Combines the RoPE experiment configurations with the arithmetic
    addition autoregression task. The attention variant (RoPE vs. RoPER)
    is chosen through configuration options rather than new code here.
    """
26 pass
29def _rotary_value_pe_mha(c: TransformerConfigs):
33 from labml_nn.transformers.rope.value_pe import RotaryValuePEMultiHeadAttention
34 return RotaryValuePEMultiHeadAttention(c.n_heads, c.d_model, 1., 1.)
Configuration options
# Register the 'rotary_value' option for every attention slot of the
# transformer configs, all backed by the same RoPER constructor.
for _attn_option in (TransformerConfigs.encoder_attn,
                     TransformerConfigs.decoder_attn,
                     TransformerConfigs.decoder_mem_attn):
    calculate(_attn_option, 'rotary_value', _rotary_value_pe_mha)
43def main():
Create experiment
45 experiment.create(name="roper_addition", comment="rotary value 7", writers={'screen', 'labml'})
Create configs
47 conf = Configs()
Override configurations
49 experiment.configs(conf, {
50 'max_digits': 7,
No fixed positional embeddings
53 'transformer.src_embed': 'no_pos',
54 'transformer.tgt_embed': 'no_pos',
Encoder with RoPER attention
57 'transformer.encoder_attn': 'rotary_value',
Encoder with RoPE attention 'transformer.encoder_attn': 'rotary',
62 'model': 'rotary_pe_transformer',
Use a context size of
65 'seq_len': 512,
Train for 32 epochs
67 'epochs': 20,
Batch size
69 'batch_size': 16,
Model size
72 'd_model': 128,
73 'transformer.ffn.d_ff': 512,
74 'transformer.n_heads': 4,
75 'transformer.dropout': 0.0,
Use Adam optimizer
78 'optimizer.optimizer': 'Adam',
79 'optimizer.learning_rate': 2.5e-4,
80 })
Set models for saving and loading
83 experiment.add_pytorch_models({'model': conf.model})
Start the experiment
86 with experiment.start():
Run training
88 conf.run()
92if __name__ == '__main__':
93 main()