from typing import List

import torch
import torch.nn as nn
from labml.logger import Text

from labml import experiment, tracker, monit, logger
from labml.configs import option
from labml_helpers.metrics.simple_state import SimpleStateModule
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex, hook_model_outputs
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers.xl import TransformerXL, TransformerXLLayer


class AutoregressiveModel(Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, transformer: TransformerXL):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Transformer
        self.transformer = transformer
        # Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        # Masks
        self.mask_x = None
        self.mask_mem = None

    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
        # Length of the memory
        m_len = len(mem[0]) if mem else 0
        # Create a subsequent (causal) mask for the current tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all-ones (full visibility) mask for the memory
        if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
        # Otherwise, use the subsequent mask only
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
        # Generate logits for the next token
        res = self.generator(res)

        return res, mem
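

# The helper below is not part of the original experiment; it is a minimal,
# hedged sketch of how the combined attention mask in
# `AutoregressiveModel.forward` is laid out: memory positions are fully
# visible to every query, while the current tokens get a causal mask.
# The name `_mask_layout_demo` is hypothetical.
def _mask_layout_demo(seq_len: int = 3, m_len: int = 2) -> torch.Tensor:
    from labml_nn.transformers.utils import subsequent_mask
    # Causal mask over the current tokens; shape `[seq_len, seq_len, 1]`
    mask_x = subsequent_mask(seq_len)
    # Full-visibility mask over the memories; shape `[seq_len, m_len, 1]`
    mask_mem = mask_x.new_ones(seq_len, m_len, 1)
    # Combined mask; shape `[seq_len, m_len + seq_len, 1]`
    return torch.cat((mask_mem, mask_x), dim=1)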


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configurations can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel

    # Token embedding size
    d_model: int = 128
    # Number of attention heads
    heads: int = 4
    # Dropout probability
    dropout: float = 0.0
    # Number of features in the FFN hidden layer
    d_ff: int = 256
    # Number of transformer layers
    n_layers: int = 6
    # Number of memories to keep
    mem_len: int = 128
    # State module to maintain memories when switching between training and validation
    memory = SimpleStateModule()

    def init(self):
        # Set tracker configurations
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
        # Add a hook to log module outputs
        hook_model_outputs(self.mode, self.model, 'model')
        # This keeps the accuracy metric statistics and the memories separate
        # for training and validation.
        self.state_modules = [self.accuracy, self.memory]

    def merge_memory(self, old_mem, new_mem):
        """
        Concatenate memories and remove old memories to keep a maximum of
        `mem_len` memories.
        """

        # If it's configured not to use memory
        if self.mem_len == 0:
            return []

        # Concatenate with the old memory
        if old_mem:
            mem = [torch.cat((m, x), dim=0) for m, x in zip(old_mem, new_mem)]
        else:
            mem = new_mem

        # Truncate the oldest memories
        if len(mem[0]) > self.mem_len:
            mem = [m[-self.mem_len:] for m in mem]

        return mem
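
    # A hedged worked example (not in the original source), assuming `mem_len = 4`
    # and one-element feature vectors per step:
    #
    #     old_mem = [torch.tensor([[1.], [2.], [3.]])]   # three steps kept so far
    #     new_mem = [torch.tensor([[4.], [5.]])]         # two steps from this batch
    #     merge_memory(old_mem, new_mem)                 # -> [tensor([[2.], [3.], [4.], [5.]])]
    #
    # Concatenation gives five steps; truncation keeps only the last `mem_len` of them.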

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training/validation step
        """

        # Move the data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update the global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get the memories
            mem = self.memory.get()
            # Run the model
            output, new_mem = self.model(data, mem)
            # Merge the memories
            mem = self.merge_memory(mem, new_mem)
            # Update the stored memories
            self.memory.set(mem)

        # Calculate and log the cross-entropy loss
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

        # Calculate and log the accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take an optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    def sample(self):
        """
        ### Sampling function to generate samples periodically while training
        """

        # Starting prompt
        prompt = self.prompt
        # Collect the output for printing
        log = [(prompt, Text.subtle)]
        # Memory
        mem = []
        # Sample 25 tokens
        for i in monit.iterate('Sample', 25):
            # Tokenize the prompt
            data = self.text.text_to_i(prompt).unsqueeze(-1)
            # Move it to the device
            data = data.to(self.device)
            # Get the model output
            output, new_mem = self.model(data, mem)
            # Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze(1)
            # Add the prediction to the prompt
            prompt += self.prompt_separator + self.text.itos[output[-1]]
            # Only feed the last character to the model in the next iteration;
            # the rest will go in as memories
            prompt = prompt[-1:]
            # Add the prediction for logging
            log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
            # Update the memory
            mem = self.merge_memory(mem, new_mem)

        # Print the sampled output
        logger.log(log)


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    m = AutoregressiveModel(c.n_tokens, c.d_model, TransformerXL(
        TransformerXLLayer(d_model=c.d_model,
                           self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
                           feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
                           dropout_prob=c.dropout), c.n_layers))
    return m.to(c.device)
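

# The sketch below is not part of the original experiment; it is a minimal,
# hedged demonstration of the Transformer XL recurrence used in
# `Configs.sample`: after the first step only the latest token is fed to the
# model, while earlier context is carried in the memories. The helper name
# `_recurrence_demo` and the toy sizes are assumptions for illustration; the
# constructor signatures match the ones used in `autoregressive_model` above.
def _recurrence_demo():
    from labml_nn.transformers.xl import RelativeMultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward
    # A tiny untrained model: vocabulary of 10 tokens, 16-dimensional embeddings
    model = AutoregressiveModel(10, 16, TransformerXL(
        TransformerXLLayer(d_model=16,
                           self_attn=RelativeMultiHeadAttention(4, 16, 0.),
                           feed_forward=FeedForward(16, 32, 0.),
                           dropout_prob=0.), 2))
    # Start with a single token; shape `[seq_len = 1, batch_size = 1]`
    x = torch.zeros(1, 1, dtype=torch.long)
    mem = []
    for _ in range(5):
        # Run one token through the model; earlier tokens are in `mem`
        output, new_mem = model(x, mem)
        # Greedily pick the next token and feed only it in the next iteration
        x = output.argmax(dim=-1)[-1:]
        # Grow the memories (unbounded here; `Configs.merge_memory` would truncate)
        mem = [torch.cat((m, n), dim=0) for m, n in zip(mem, new_mem)] if mem else new_mem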


def main():
    # Create the experiment
    experiment.create(name="transformer_xl", comment='')
    # Create the configs
    conf = Configs()
    # Load the configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'train_loader': 'sequential_train_loader',
                        'valid_loader': 'sequential_valid_loader',

                        'seq_len': 2,
                        'mem_len': 32,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })

    # Set the models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training and validation (`TrainValidConfigs.run`)
        conf.run()


if __name__ == '__main__':
    main()