これは、プライマーEZトランスフォーマーをトレーニングするための注釈付きのPyTorch実験です。
これは私たちのバニラトランスフォーマー実験に基づいています。同じ実験を行い、Primer EZの修正を加えます
。15from labml import experiment
16from labml.configs import option
17from labml_nn.transformers import TransformerConfigs
18from labml_nn.transformers.basic.autoregressive_experiment import Configs
19from labml_nn.transformers.configs import FeedForwardConfigs
20from labml_nn.transformers.primer_ez import SquaredReLU23@option(FeedForwardConfigs.activation, 'SquaredReLU')
24def _squared_relu():30 return SquaredReLU()33@option(TransformerConfigs.encoder_attn, 'MultiDConvHeadAttention')
34def _d_conv_mha(c: TransformerConfigs):40 from labml_nn.transformers.primer_ez import MultiDConvHeadAttention
41 return MultiDConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)44@option(TransformerConfigs.encoder_attn, 'MultiDSharedConvHeadAttention')
45def _d_shared_conv_mha(c: TransformerConfigs):53 from labml_nn.transformers.primer_ez.variations import MultiDSharedConvHeadAttention
54 return MultiDSharedConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)57@option(TransformerConfigs.encoder_attn, 'MultiDPHConvHeadAttention')
58def _d_per_head_conv_mha(c: TransformerConfigs):66 from labml_nn.transformers.primer_ez.variations import MultiDPHConvHeadAttention
67 return MultiDPHConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)70def main():実験を作成
72 experiment.create(name="primer_ez")コンフィグの作成
74 conf = Configs()オーバーライド設定
76 experiment.configs(conf, {キャラクターレベルのトークナイザーを使う
78 'tokenizer': 'character',プロンプトセパレータが空白
80 'prompt_separator': '',サンプリングの開始プロンプト
82 'prompt': 'It is ',タイニー・シェイクスピア・データセットを使う
84 'text': 'tiny_shakespeare',コンテキストサイズを次の値にしてください
87 'seq_len': 256,時代に合わせた列車
89 'epochs': 128,バッチサイズ
91 'batch_size': 32,エポックごとにトレーニングと検証を切り替える
94 'inner_iterations': 10,モデルサイズ
97 'd_model': 512,
98 'transformer.ffn.d_ff': 2048,Adam オプティマイザーを使う
101 'optimizer.optimizer': 'Adam',
102 'optimizer.learning_rate': 2.5e-4,107 'transformer.ffn.activation': 'SquaredReLU',112 'transformer.encoder_attn': 'MultiDConvHeadAttention',
113 })保存および読み込み用のモデルを設定する
116 experiment.add_pytorch_models({'model': conf.model})実験を始める
119 with experiment.start():トレーニングを実行
121 conf.run()125if __name__ == '__main__':
126 main()