12from labml import experiment
13from labml.configs import option
14from labml_nn.transformers import TransformerConfigs
15from labml_nn.transformers.configs import FeedForwardConfigs
16from labml_nn.transformers.mlm.experiment import TransformerMLM, Configs as MLMConfigs
19class Configs(MLMConfigs):
混合 MLP 配置
32@option(Configs.mix_mlp)
33def _mix_mlp_configs(c: Configs):
38 conf = FeedForwardConfigs()
MLP 的大小是序列长度,因为它跨令牌应用
40 conf.d_model = c.seq_len
该论文建议激活
42 conf.activation = 'GELU'
45 return conf
48@option(Configs.transformer)
49def _transformer_configs(c: Configs):
设置嵌入和生成 logit 的词汇量大小
58 conf.n_src_vocab = c.n_tokens
59 conf.n_tgt_vocab = c.n_tokens
嵌入大小
61 conf.d_model = c.d_model
63 from labml_nn.transformers.mlp_mixer import MLPMixer
64 conf.encoder_attn = MLPMixer(c.mix_mlp.ffn)
67 return conf
70def main():
创建实验
72 experiment.create(name="mlp_mixer_mlm")
创建配置
74 conf = Configs()
覆盖配置
76 experiment.configs(conf, {
批量大小
78 'batch_size': 64,
序列长度。我们使用较短的序列长度来更快地训练。否则,传销模型需要很长时间才能训练。
81 'seq_len': 32,
训练 1024 个时代。
84 'epochs': 1024,
在训练和验证之间切换每个纪元的次数
87 'inner_iterations': 1,
变压器配置
90 'd_model': 128,
91 'transformer.ffn.d_ff': 256,
92 'transformer.n_heads': 8,
93 'transformer.n_layers': 6,
94 'transformer.ffn.activation': 'GELU',
混音器 MLP 隐藏层大小
97 'mix_mlp.d_ff': 128,
设置用于保存和加载的模型
105 experiment.add_pytorch_models({'model': conf.model})
开始实验
108 with experiment.start():
跑步训练
110 conf.run()
114if __name__ == '__main__':
115 main()