これは、プライマーEZトランスフォーマーをトレーニングするための注釈付きのPyTorch実験です。
これは私たちのバニラトランスフォーマー実験に基づいています。同じ実験を行い、Primer EZの修正を加えます
。15from labml import experiment
16from labml.configs import option
17from labml_nn.transformers import TransformerConfigs
18from labml_nn.transformers.basic.autoregressive_experiment import Configs
19from labml_nn.transformers.configs import FeedForwardConfigs
20from labml_nn.transformers.primer_ez import SquaredReLU
23@option(FeedForwardConfigs.activation, 'SquaredReLU')
24def _squared_relu():
30 return SquaredReLU()
33@option(TransformerConfigs.encoder_attn, 'MultiDConvHeadAttention')
34def _d_conv_mha(c: TransformerConfigs):
40 from labml_nn.transformers.primer_ez import MultiDConvHeadAttention
41 return MultiDConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
44@option(TransformerConfigs.encoder_attn, 'MultiDSharedConvHeadAttention')
45def _d_shared_conv_mha(c: TransformerConfigs):
53 from labml_nn.transformers.primer_ez.variations import MultiDSharedConvHeadAttention
54 return MultiDSharedConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
57@option(TransformerConfigs.encoder_attn, 'MultiDPHConvHeadAttention')
58def _d_per_head_conv_mha(c: TransformerConfigs):
66 from labml_nn.transformers.primer_ez.variations import MultiDPHConvHeadAttention
67 return MultiDPHConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
70def main():
実験を作成
72 experiment.create(name="primer_ez")
コンフィグの作成
74 conf = Configs()
オーバーライド設定
76 experiment.configs(conf, {
キャラクターレベルのトークナイザーを使う
78 'tokenizer': 'character',
プロンプトセパレータが空白
80 'prompt_separator': '',
サンプリングの開始プロンプト
82 'prompt': 'It is ',
タイニー・シェイクスピア・データセットを使う
84 'text': 'tiny_shakespeare',
コンテキストサイズを次の値にしてください
87 'seq_len': 256,
時代に合わせた列車
89 'epochs': 128,
バッチサイズ
91 'batch_size': 32,
エポックごとにトレーニングと検証を切り替える
94 'inner_iterations': 10,
モデルサイズ
97 'd_model': 512,
98 'transformer.ffn.d_ff': 2048,
Adam オプティマイザーを使う
101 'optimizer.optimizer': 'Adam',
102 'optimizer.learning_rate': 2.5e-4,
107 'transformer.ffn.activation': 'SquaredReLU',
112 'transformer.encoder_attn': 'MultiDConvHeadAttention',
113 })
保存および読み込み用のモデルを設定する
116 experiment.add_pytorch_models({'model': conf.model})
実験を始める
119 with experiment.start():
トレーニングを実行
121 conf.run()
125if __name__ == '__main__':
126 main()