11from typing import List
12
13import torch
14import torch.nn as nn
15from labml.logger import Text
16
17from labml import experiment, tracker, monit, logger
18from labml.configs import option
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer
26class AutoregressiveModel(Module):
31 def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
32 super().__init__()
令牌嵌入模块
34 self.src_embed = nn.Embedding(n_vocab, d_model)
变压器
36 self.transformer = transformer
最后一层
38 self.generator = nn.Linear(d_model, n_vocab)
口罩
40 self.mask_x = None
41 self.mask_mem = None
43 def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
内存的长度
45 m_len = len(mem[0]) if mem else 0
为令牌创建后续掩码
47 if self.mask_x is None or self.mask_x.shape[0] < len(x):
48 from labml_nn.transformers.utils import subsequent_mask
49 self.mask_x = subsequent_mask(len(x)).to(x.device)
为内存创建一个全一(完全可见性)掩码
51 if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
52 self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
如果有内存,则连接掩码
55 if m_len:
56 mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
否则,请使用后续的掩码
58 else:
59 mask = self.mask_x[:len(x), :len(x)]
令牌嵌入
62 x = self.src_embed(x)
用它穿过变压器
64 res, mem = self.transformer(x, mem, mask)
生成下一个令牌的日志
66 res = self.generator(res)
68 return res, mem
71class Configs(NLPAutoRegressionConfigs):
78 model: AutoregressiveModel
令牌嵌入大小
81 d_model: int = 128
注意头数量
83 heads: int = 4
辍学概率
85 dropout: float = 0.0
FFN 隐藏层中的要素数量
87 d_ff: int = 256
变压器层数
89 n_layers: int = 6
要保留的记忆数量
91 mem_len: int = 128
状态模块用于在训练和验证之间切换时保持记忆
93 memory = SimpleStateModule()
95 def init(self):
设置跟踪器配置
97 tracker.set_scalar("accuracy.*", True)
98 tracker.set_scalar("loss.*", True)
向日志模块输出添加钩子
100 hook_model_outputs(self.mode, self.model, 'model')
这将使精度指标统计数据和记忆分开,以便训练和验证。
102 self.state_modules = [self.accuracy, self.memory]
连接记忆并删除旧记忆以最大限度地保留内mem_len
存。
104 def merge_memory(self, old_mem, new_mem):
如果配置为不使用内存
111 if self.mem_len == 0:
112 return []
与旧内存串联
115 if old_mem:
116 mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
117 else:
118 mem = new_mem
截断旧的记忆
121 if len(mem[0]) > self.mem_len:
122 mem = [m[-self.mem_len:] for m in mem]
125 return mem
127 def step(self, batch: any, batch_idx: BatchIndex):
将数据移动到设备
133 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式下更新全局步长(处理的令牌数)
136 if self.mode.is_train:
137 tracker.add_global_step(data.shape[0] * data.shape[1])
是否捕获模型输出
140 with self.mode.update(is_log_activations=batch_idx.is_last):
获得回忆
142 mem = self.memory.get()
运行模型
144 output, new_mem = self.model(data, mem)
合并内存
146 mem = self.merge_memory(mem, new_mem)
更新记忆
148 self.memory.set(mem)
计算和记录交叉熵损失
151 loss = self.loss_func(output, target)
152 tracker.add("loss.", loss)
计算和记录精度
155 self.accuracy(output, target)
156 self.accuracy.track()
训练模型
159 if self.mode.is_train:
计算梯度
161 loss.backward()
剪辑渐变
163 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
采取优化器步骤
165 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
167 if batch_idx.is_last:
168 tracker.add('model', self.model)
清除渐变
170 self.optimizer.zero_grad()
保存跟踪的指标
173 tracker.save()
175 def sample(self):
启动提示
181 prompt = self.prompt
收集输出以进行打印
183 log = [(prompt, Text.subtle)]
记忆
185 mem = []
样本 25 个代币
187 for i in monit.iterate('Sample', 25):
将提示符号化
189 data = self.text.text_to_i(prompt).unsqueeze(-1)
移至设备
191 data = data.to(self.device)
获取模型输出
193 output, new_mem = self.model(data, mem)
获取模型预测(贪婪)
195 output = output.argmax(dim=-1).squeeze(1)
将预测添加到提示符中
197 prompt += self.prompt_separator + self.text.itos[output[-1]]
在下一次迭代中只喂最后一个角色进行建模,其余部分将作为记忆进去
199 prompt = prompt[-1:]
添加日志记录的预测
201 log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
更新内存
203 mem = self.merge_memory(mem, new_mem)
打印采样输出
206 logger.log(log)
209@option(Configs.model)
210def autoregressive_model(c: Configs):
214 from labml_nn.transformers.xl import RelativeMultiHeadAttention
215 from labml_nn.transformers.feed_forward import FeedForward
216 m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
217 TransformerXLLayer(d_model=c.d_model,
218 self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
219 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
220 dropout_prob=c.dropout), c.n_layers))
221 return m.to(c.device)
224def main():
创建实验
229 experiment.create(name="transformer_xl", comment='')
创建配置
231 conf = Configs()
装载配置
233 experiment.configs(conf,
要覆盖的配置字典
235 {'tokenizer': 'character',
236 'text': 'tiny_shakespeare',
237 'optimizer.learning_rate': 1.,
238 'optimizer.optimizer': 'Noam',
239 'prompt': 'It is',
240 'prompt_separator': '',
241
242 'train_loader': 'sequential_train_loader',
243 'valid_loader': 'sequential_valid_loader',
244
245 'seq_len': 2,
246 'mem_len': 32,
247 'epochs': 128,
248 'batch_size': 32,
249 'inner_iterations': 25,
250 })
设置用于保存和加载的模型
253 experiment.add_pytorch_models({'model': conf.model})
开始实验
256 with experiment.start():
TrainValidConfigs.run
258 conf.run()
262if __name__ == '__main__':
263 main()