15from labml import experiment
16from labml.configs import option
17from labml_nn.transformers import TransformerConfigs
18from labml_nn.transformers.basic.autoregressive_experiment import Configs
19from labml_nn.transformers.configs import FeedForwardConfigs
20from labml_nn.transformers.primer_ez import SquaredReLU
23@option(FeedForwardConfigs.activation, 'SquaredReLU')
24def _squared_relu():
30    return SquaredReLU()
33@option(TransformerConfigs.encoder_attn, 'MultiDConvHeadAttention')
34def _d_conv_mha(c: TransformerConfigs):
40    from labml_nn.transformers.primer_ez import MultiDConvHeadAttention
41    return MultiDConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
44@option(TransformerConfigs.encoder_attn, 'MultiDSharedConvHeadAttention')
45def _d_shared_conv_mha(c: TransformerConfigs):
53    from labml_nn.transformers.primer_ez.variations import MultiDSharedConvHeadAttention
54    return MultiDSharedConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
57@option(TransformerConfigs.encoder_attn, 'MultiDPHConvHeadAttention')
58def _d_per_head_conv_mha(c: TransformerConfigs):
66    from labml_nn.transformers.primer_ez.variations import MultiDPHConvHeadAttention
67    return MultiDPHConvHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
70def main():

実験を作成

72    experiment.create(name="primer_ez")

コンフィグの作成

74    conf = Configs()

オーバーライド設定

76    experiment.configs(conf, {

キャラクターレベルのトークナイザーを使う

78        'tokenizer': 'character',

プロンプトセパレータが空白

80        'prompt_separator': '',

サンプリングの開始プロンプト

82        'prompt': 'It is ',

タイニー・シェイクスピア・データセットを使う

84        'text': 'tiny_shakespeare',

コンテキストサイズを次の値にしてください

87        'seq_len': 256,

時代に合わせた列車

89        'epochs': 128,

バッチサイズ

91        'batch_size': 32,

エポックごとにトレーニングと検証を切り替える

94        'inner_iterations': 10,

モデルサイズ

97        'd_model': 512,
98        'transformer.ffn.d_ff': 2048,

Adam オプティマイザーを使う

101        'optimizer.optimizer': 'Adam',
102        'optimizer.learning_rate': 2.5e-4,
107        'transformer.ffn.activation': 'SquaredReLU',

⭐️ エンコーダアテンションにはマルチコンバーションヘッドアテンションを使用してください

mha これをオリジナルのマルチヘッドアテンション用と交換してください。

112        'transformer.encoder_attn': 'MultiDConvHeadAttention',
113    })

保存および読み込み用のモデルを設定する

116    experiment.add_pytorch_models({'model': conf.model})

実験を始める

119    with experiment.start():

トレーニングを実行

121        conf.run()

125if __name__ == '__main__':
126    main()