11from typing import List, Tuple, NamedTuple
12
13import torch
14import torch.nn as nn
15
16from labml import experiment, tracker, monit, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.simple_state import SimpleStateModule
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex, hook_model_outputs
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers.compressive import CompressiveTransformer, AttentionReconstructionLoss, \
24 CompressiveTransformerLayer, Conv1dCompression
27class CompressedMemory(NamedTuple):
28 mem: List[torch.Tensor]
29 c_mem: List[torch.Tensor]
32class AutoregressiveModel(Module):
37 def __init__(self, n_vocab: int, d_model: int, transformer: CompressiveTransformer):
38 super().__init__()
令牌嵌入模块
40 self.src_embed = nn.Embedding(n_vocab, d_model)
变压器
42 self.transformer = transformer
最后一层
44 self.generator = nn.Linear(d_model, n_vocab)
口罩
46 self.mask_x = None
47 self.mask_mem = None
49 def forward(self, x: torch.Tensor, mem: CompressedMemory):
获取内存和压缩内存
51 if mem is not None:
52 mem, c_mem = mem.mem, mem.c_mem
53 else:
54 mem = []
55 c_mem = []
内存和压缩内存的总长度(用于掩码)
58 m_len = len(mem[0]) if mem else 0
59 if c_mem:
60 m_len += len(c_mem[0])
为令牌创建后续掩码
63 if self.mask_x is None or self.mask_x.shape[0] < len(x):
64 from labml_nn.transformers.utils import subsequent_mask
65 self.mask_x = subsequent_mask(len(x)).to(x.device)
为内存创建一个全一(完全可见性)掩码
67 if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
68 self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)
如果有内存,则连接掩码
71 if m_len:
72 mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
否则,仅使用后续的掩码
74 else:
75 mask = self.mask_x[:len(x), :len(x)]
令牌嵌入
78 x = self.src_embed(x)
用它穿过变压器
80 res, mem = self.transformer(x, mem, c_mem, mask)
生成下一个令牌的日志
82 res = self.generator(res)
84 return res, mem
87class Configs(NLPAutoRegressionConfigs):
94 model: AutoregressiveModel
令牌嵌入大小
97 d_model: int = 128
注意头数量
99 heads: int = 4
辍学概率
101 dropout: float = 0.0
FFN 隐藏层中的要素数量
103 d_ff: int = 256
变压器层数
105 n_layers: int = 6
要保留的记忆数量
107 mem_len: int = 8
状态模块用于在训练和验证之间切换时保持记忆
109 memory = SimpleStateModule()
注意力重建损失
111 attention_reconstruction_loss: AttentionReconstructionLoss
压缩率
113 compression_rate: int = 4
压缩的内存长度
115 c_mem_len: int = 128
117 def init(self):
设置跟踪器配置
119 tracker.set_scalar("accuracy.*", True)
120 tracker.set_scalar("loss.*", True)
不要在终端中打印注意力重建损失
122 tracker.set_scalar("ar_loss.*", False)
向日志模块输出添加钩子
124 hook_model_outputs(self.mode, self.model, 'model')
这将使精度指标统计数据和记忆分开,以便训练和验证。
126 self.state_modules = [self.accuracy, self.memory]
连接新记忆并压缩最古老的记忆。
128 @torch.no_grad()
129 def merge_compress_memory(self, mem: CompressedMemory, new_mem: List[torch.Tensor]) \
130 -> Tuple[CompressedMemory, List[torch.Tensor]]:
如果配置指定不使用内存
136 if self.mem_len == 0 and self.c_mem_len == 0:
137 return CompressedMemory([], []), []
获取内存和压缩内存
140 if mem is not None:
141 mem, c_mem = mem.mem, mem.c_mem
142 else:
143 mem, c_mem = [], []
将新记忆与旧记忆连接起来
146 if mem:
147 mem = [torch.cat((m, x), dim=0) for m, x in zip(mem, new_mem)]
148 else:
149 mem = new_mem
如果记忆多于最早的记忆,则压缩最早的记忆mem_len
152 if len(mem[0]) > self.mem_len:
计算要制作的压缩记忆的数量,其中是我们拥有的记忆数量,是我们维护的最大记忆数(mem_len
)。
156 n_c_mem = (len(mem[0]) - self.mem_len + self.compression_rate - 1) // self.compression_rate
要压缩的内存数量
158 n_old = n_c_mem * self.compression_rate
用于保存每层需要压缩的内存的列表。
160 mem_to_compress = []
一个列表,用于保存每层未被压缩的记忆。
162 uncompressed_mem = []
遍历每层的记忆。
164 for m in mem:
在以下位置拆分记忆
166 cm, m = torch.split(m, [n_old, len(m) - n_old])
收集记忆进行压缩
168 mem_to_compress.append(cm)
收集剩余的记忆
170 uncompressed_mem.append(m)
更新记忆
172 mem = uncompressed_mem
压缩记忆
175 new_c_mem = []
176 for i, layer in enumerate(self.model.transformer.layers):
177 new_c_mem.append(layer.compress(mem_to_compress[i]))
将新压缩的存储器与旧的压缩存储器连接起来
180 if c_mem:
181 c_mem = [torch.cat((m, nm), dim=0) for m, nm in zip(c_mem, new_c_mem)]
如果没有旧的压缩记忆
183 else:
184 c_mem = new_c_mem
截断旧的记忆
187 if len(c_mem[0]) > self.c_mem_len:
188 c_mem = [m[-self.c_mem_len:] for m in c_mem]
如果内存数量少于mem_len
190 else:
191 mem_to_compress = []
返回被压缩的记忆和记忆。重建损失计算需要被压缩的记忆。
195 return CompressedMemory(mem, c_mem), mem_to_compress
197 def step(self, batch: any, batch_idx: BatchIndex):
将数据移动到设备
203 data, target = batch[0].to(self.device), batch[1].to(self.device)
在训练模式下更新全局步长(处理的令牌数)
206 if self.mode.is_train:
207 tracker.add_global_step(data.shape[0] * data.shape[1])
是否捕获模型输出
210 with self.mode.update(is_log_activations=batch_idx.is_last):
获得回忆
212 mem = self.memory.get()
运行模型
214 output, new_mem = self.model(data, mem)
合并和压缩内存
216 mem, mem_to_compress = self.merge_compress_memory(mem, new_mem)
更新记忆
218 self.memory.set(mem)
计算和记录交叉熵损失
221 loss = self.loss_func(output, target)
222 tracker.add("loss.", loss)
如果在此步骤中记忆被压缩,则计算注意力重建损失
225 if mem_to_compress:
引起注意重建损失
227 ar_loss = self.attention_reconstruction_loss(new_mem, mem_to_compress)
追踪注意力重建损失
229 tracker.add("ar_loss.", ar_loss)
将注意力重建损失增加到损失
231 loss = loss + ar_loss
计算和记录精度
234 self.accuracy(output, target)
235 self.accuracy.track()
训练模型
238 if self.mode.is_train:
计算梯度
240 loss.backward()
剪辑渐变
242 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
采取优化器步骤
244 self.optimizer.step()
记录每个纪元最后一批的模型参数和梯度
246 if batch_idx.is_last:
247 tracker.add('model', self.model)
清除渐变
249 self.optimizer.zero_grad()
保存跟踪的指标
252 tracker.save()
254 def sample(self):
启动提示
260 prompt = self.prompt
收集输出以进行打印
262 log = [(prompt, Text.subtle)]
记忆
264 mem = CompressedMemory([], [])
样本 25 个代币
266 for i in monit.iterate('Sample', 25):
将提示符号化
268 data = self.text.text_to_i(prompt).unsqueeze(-1)
移至设备
270 data = data.to(self.device)
获取模型输出
272 output, new_mem = self.model(data, mem)
获取模型预测(贪婪)
274 output = output.argmax(dim=-1).squeeze(1)
将预测添加到提示符中
276 prompt += self.prompt_separator + self.text.itos[output[-1]]
在下一次迭代中只喂最后一个角色进行建模,其余部分将作为记忆进去
278 prompt = prompt[-1:]
添加日志记录的预测
280 log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]
更新和压缩内存
282 mem, _ = self.merge_compress_memory(mem, new_mem)
打印采样输出
285 logger.log(log)
288@option(Configs.model)
289def autoregressive_model(c: Configs):
293 from labml_nn.transformers.xl import RelativeMultiHeadAttention
294 from labml_nn.transformers.feed_forward import FeedForward
295 m = AutoregressiveModel(c.n_tokens, c.d_model, CompressiveTransformer(
296 CompressiveTransformerLayer(d_model=c.d_model,
297 self_attn=RelativeMultiHeadAttention(c.heads, c.d_model, c.dropout),
298 feed_forward=FeedForward(c.d_model, c.d_ff, c.dropout),
299 dropout_prob=c.dropout,
300 compress=Conv1dCompression(c.compression_rate, c.d_model)), c.n_layers))
301 return m.to(c.device)
304@option(Configs.attention_reconstruction_loss)
305def attention_reconstruction_loss(c: Configs):
309 return AttentionReconstructionLoss(c.model.transformer.layers)
312def main():
创建实验
317 experiment.create(name="compressive_transformer", comment='')
创建配置
319 conf = Configs()
装载配置
321 experiment.configs(conf,
要覆盖的配置字典
323 {'tokenizer': 'character',
324 'text': 'tiny_shakespeare',
325 'optimizer.learning_rate': 2.5e-4,
326 'optimizer.optimizer': 'AdamW',
327 'prompt': 'It is',
328 'prompt_separator': '',
329
330 'train_loader': 'sequential_train_loader',
331 'valid_loader': 'sequential_valid_loader',
332
333 'seq_len': 8,
334 'mem_len': 8,
335 'epochs': 128,
336 'batch_size': 32,
337 'inner_iterations': 25,
338 'compression_rate': 2,
339 })
设置用于保存和加载的模型
342 experiment.add_pytorch_models({'model': conf.model})
开始实验
345 with experiment.start():
TrainValidConfigs.run
347 conf.run()
351if __name__ == '__main__':
352 main()