16import torch
17from torch import nn
18
19from labml import experiment
20from labml.configs import option
21from labml.utils.pytorch import get_modules
22from labml_helpers.module import Module
23from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class AutoregressiveModel(Module):
    """
    Auto-regressive model.

    Embeds the input tokens, passes the embeddings through the supplied
    transformer and projects the result to one logit per vocabulary entry.
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: Module):
        """
        :param n_vocab: vocabulary size; number of embeddings and of output logits
        :param d_model: embedding size, also the input size of the generator
        :param transformer: module applied to the embedded tokens
        """
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        self.transformer = transformer
        # Linear layer that maps transformer outputs to vocabulary logits
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        """
        :param x: token ids that are fed to the embedding layer
        :return: a tuple `(logits, None)`
        """
        # Embed the tokens
        x = self.src_embed(x)
        # Run it through the transformer
        res = self.transformer(x)
        # Generate logits of the next token.
        # The second element is always `None`; presumably a placeholder for a
        # state value expected by the caller - confirm against the trainer.
        return self.generator(res), None
class Configs(NLPAutoRegressionConfigs):
    """
    Configurations for this experiment.

    Extends `NLPAutoRegressionConfigs` with the model and the
    hyper-parameters used to build it.
    """

    # The auto-regressive model
    model: AutoregressiveModel

    # Model size: embedding size and the `d_model` given to the transformer layer
    d_model: int = 512
    # `nu` argument passed to `DPFP`
    nu: int = 1
    # Number of heads passed to `FastWeightsAttention`
    heads: int = 8
    # Dropout value shared by the attention, feed-forward and transformer layer
    dropout: float = 0.0
    # Second argument of `FeedForward`; presumably the hidden layer size - confirm
    d_ff: int = 2048
    # Number of layers passed to `FastWeightsAttentionTransformer`
    n_layers: int = 6
# NOTE(review): the decorator and `def` line were absent from the reviewed
# text (the body started directly at the imports, leaving `return` at module
# level and `c` undefined). They are reconstructed here from the `option`
# import, the `c.<field>` accesses and `Configs.model`; confirm the function
# name against the upstream file.
@option(Configs.model)
def fast_weights_transformer(c: Configs):
    """
    Build the `AutoregressiveModel` around a fast weights attention transformer.

    :param c: the configurations to read the hyper-parameters and device from
    :return: the model, moved to `c.device`
    """
    # Imported locally, as in the original body
    from labml_nn.transformers.fast_weights import (
        DPFP,
        FastWeightsAttention,
        FastWeightsAttentionTransformer,
        FastWeightsAttentionTransformerLayer,
        FeedForward,
    )

    return AutoregressiveModel(
        c.n_tokens, c.d_model,
        FastWeightsAttentionTransformer(
            FastWeightsAttentionTransformerLayer(
                d_model=c.d_model,
                attn=FastWeightsAttention(c.heads, c.d_model, c.dropout, DPFP(nu=c.nu)),
                feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                dropout_prob=c.dropout),
            c.n_layers)).to(c.device)
def main():
    """
    Create the experiment, load the configurations, register the models
    and run the training loop.
    """
    # Create experiment
    experiment.create(name="fast_weights_transformer")
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.0,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 128,
                        'epochs': 128,
                        'batch_size': 16,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # Run the training loop
        conf.run()
113
114
# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()