11from typing import List
12
13import torch
14from torch import nn
15
16from labml import experiment, tracker, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import Encoder, Generator
24from labml_nn.transformers import TransformerConfigs
25from labml_nn.transformers.mlm import MLM
28class TransformerMLM(nn.Module):
encoder
是变压器编码器src_embed
是令牌嵌入模块(带有位置编码)generator
是给出 logit 的最后一个完全连接的层。33 def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
40 super().__init__()
41 self.generator = generator
42 self.src_embed = src_embed
43 self.encoder = encoder
45 def forward(self, x: torch.Tensor):
使用位置编码获取令牌嵌入
47 x = self.src_embed(x)
变压器编码
49 x = self.encoder(x, None)
输出的对数
51 y = self.generator(x)
返回结果(第二个值用于状态,因为我们的训练器也与 RNN 一起使用)
55 return y, None
58class Configs(NLPAutoRegressionConfigs):
传销模型
69 model: TransformerMLM
变压器
71 transformer: TransformerConfigs
代币数量
74 n_tokens: int = 'n_tokens_mlm'
不应该被掩盖的代币
76 no_mask_tokens: List[int] = []
掩盖代币的概率
78 masking_prob: float = 0.15
用随机令牌替换掩码的概率
80 randomize_prob: float = 0.1
用原始令牌替换掩码的概率
82 no_change_prob: float = 0.1
84 mlm: MLM
[MASK]
令牌
87 mask_token: int
[PADDING]
令牌
89 padding_token: int
提示采样
92 prompt: str = [
93 "We are accounted poor citizens, the patricians good.",
94 "What authority surfeits on would relieve us: if they",
95 "would yield us but the superfluity, while it were",
96 "wholesome, we might guess they relieved us humanely;",
97 "but they think we are too dear: the leanness that",
98 "afflicts us, the object of our misery, is as an",
99 "inventory to particularise their abundance; our",
100 "sufferance is a gain to them Let us revenge this with",
101 "our pikes, ere we become rakes: for the gods know I",
102 "speak this in hunger for bread, not in thirst for revenge.",
103 ]
105 def init(self):
[MASK]
令牌
111 self.mask_token = self.n_tokens - 1
[PAD]
令牌
113 self.padding_token = self.n_tokens - 2
116 self.mlm = MLM(padding_token=self.padding_token,
117 mask_token=self.mask_token,
118 no_mask_tokens=self.no_mask_tokens,
119 n_tokens=self.n_tokens,
120 masking_prob=self.masking_prob,
121 randomize_prob=self.randomize_prob,
122 no_change_prob=self.no_change_prob)
精度度量度(忽略等于的标签[PAD]
)
125 self.accuracy = Accuracy(ignore_index=self.padding_token)
交叉熵损失(忽略等于的标签[PAD]
)
127 self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
129 super().init()
131 def step(self, batch: any, batch_idx: BatchIndex):
将输入移至设备
137 data = batch[0].to(self.device)
在训练模式下更新全局步长(处理的令牌数)
140 if self.mode.is_train:
141 tracker.add_global_step(data.shape[0] * data.shape[1])
获取屏蔽的输入和标签
144 with torch.no_grad():
145 data, labels = self.mlm(data)
是否捕获模型输出
148 with self.mode.update(is_log_activations=batch_idx.is_last):
获取模型输出。它在使用 RNN 时返回状态的元组。这尚未实现。
152 output, *_ = self.model(data)
计算并记录损失
155 loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
156 tracker.add("loss.", loss)
计算和记录精度
159 self.accuracy(output, labels)
160 self.accuracy.track()
训练模型
163 if self.mode.is_train:
计算梯度
165 loss.backward()
剪辑渐变
167 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
采取优化器步骤
169 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
171 if batch_idx.is_last:
172 tracker.add('model', self.model)
清除渐变
174 self.optimizer.zero_grad()
保存跟踪的指标
177 tracker.save()
179 @torch.no_grad()
180 def sample(self):
填充的数据为空张量[PAD]
。
186 data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
逐个添加提示
188 for i, p in enumerate(self.prompt):
获取代币索引
190 d = self.text.text_to_i(p)
添加到张量中
192 s = min(self.seq_len, len(d))
193 data[:s, i] = d[:s]
将张量移到当前设备
195 data = data.to(self.device)
获取屏蔽的输入和标签
198 data, labels = self.mlm(data)
获取模型输出
200 output, *_ = self.model(data)
打印生成的样本
203 for j in range(data.shape[1]):
从打印中收集输出
205 log = []
对于每个代币
207 for i in range(len(data)):
如果标签不是[PAD]
209 if labels[i, j] != self.padding_token:
获取预测
211 t = output[i, j].argmax().item()
如果是可打印的字符
213 if t < len(self.text.itos):
正确的预测
215 if t == labels[i, j]:
216 log.append((self.text.itos[t], Text.value))
预测不正确
218 else:
219 log.append((self.text.itos[t], Text.danger))
如果它不是可打印的字符
221 else:
222 log.append(('*', Text.danger))
如果标签是[PAD]
(未遮罩),请打印原件。
224 elif data[i, j] < len(self.text.itos):
225 log.append((self.text.itos[data[i, j]], Text.subtle))
打印
228 logger.log(log)
包括[PAD]
和在内的代币数量[MASK]
231@option(Configs.n_tokens)
232def n_tokens_mlm(c: Configs):
236 return c.text.n_tokens + 2
239@option(Configs.transformer)
240def _transformer_configs(c: Configs):
设置嵌入和生成 logit 的词汇量大小
249 conf.n_src_vocab = c.n_tokens
250 conf.n_tgt_vocab = c.n_tokens
嵌入大小
252 conf.d_model = c.d_model
255 return conf
创建分类模型
258@option(Configs.model)
259def _model(c: Configs):
263 m = TransformerMLM(encoder=c.transformer.encoder,
264 src_embed=c.transformer.src_embed,
265 generator=c.transformer.generator).to(c.device)
266
267 return m
270def main():
创建实验
272 experiment.create(name="mlm")
创建配置
274 conf = Configs()
覆盖配置
276 experiment.configs(conf, {
批量大小
278 'batch_size': 64,
序列长度。我们使用较短的序列长度来更快地训练。否则训练需要很长时间。
281 'seq_len': 32,
训练 1024 个时代。
284 'epochs': 1024,
在训练和验证之间切换每个纪元的次数
287 'inner_iterations': 1,
变压器配置(与默认值相同)
290 'd_model': 128,
291 'transformer.ffn.d_ff': 256,
292 'transformer.n_heads': 8,
293 'transformer.n_layers': 6,
设置用于保存和加载的模型
301 experiment.add_pytorch_models({'model': conf.model})
开始实验
304 with experiment.start():
跑步训练
306 conf.run()
310if __name__ == '__main__':
311 main()