14import math
15from typing import Set
16
17import torch
18from torch import nn
19
20from labml.logger import inspect
23class RotaryPositionalEmbeddings(nn.Module):
34 def __init__(self, d: int, base: int = 10_000):
39 super().__init__()
41 self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
x
是位于键或带有形状的查询开头的 Tensor[ batch_size, seq_len, n_heads, d]
43 def forward(self, x: torch.Tensor):
提取形状
48 batch_size, seq_len, n_heads, d = x.shape
51 d_2 = d // 2
创建头寸指数[0, 1, ..., seq_len - 1]
54 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
计算持仓指数的乘积和
57 idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
连接这样我们就有 row
61 idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
计算
65 neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
77 rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
80 return rx
d_model
是变压器嵌入中的特征数n_heads
是注意头的数量d_k
是每头特征的数量is_causal
表示这是否是因果关注(屏蔽)90 def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
97 super().__init__()
98
99 self.is_causal = is_causal
100 self.n_heads = n_heads
101 self.d_k = d_k
在 softmax 之前扩大注意力
104 self.scale = 1 / math.sqrt(self.d_k)
用于查询、键和值标头的线性图层。
107 self.query = nn.Linear(d_model, n_heads * d_k)
108 self.key = nn.Linear(d_model, n_heads * d_k)
109 self.value = nn.Linear(d_model, n_heads * d_k)
预先规范层。本文改为使用 rmsNorm。
112 self.norm = nn.LayerNorm(d_model)
Softmax 表示注意力概率
115 self.softmax = nn.Softmax(dim=-1)
旋转位置嵌入
118 self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
最后的线性层
121 self.output = nn.Linear(n_heads * d_k, d_model)
123 def mask_attention(self, attn: torch.Tensor):
非因果注意没有遮罩
131 if not self.is_causal:
132 return attn
创建三角形蒙版
135 mask = torch.tril(attn.new_ones(attn.shape[-2:]))
按口罩过滤
137 return attn.masked_fill(mask == 0, float('-inf'))
h
变压器嵌入的形状是多少[batch_size, seq_len, d_model]
139 def forward(self, h: torch.Tensor):
剩余连接
145 h_res = h
规范化前
148 h = self.norm(h)
获取查询、键和值,并将它们分成头部。这些会有形状[batch_size, seq_len, n_heads, d_k]
152 mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
153 q = self.query(h).view(mh_shape)
154 k = self.key(h).view(mh_shape)
155 v = self.value(h).view(mh_shape)
应用旋转位置嵌入
158 q = self.rotary_pe(q)
159 k = self.rotary_pe(k)
计算注意力
162 attn = torch.einsum('bihd,bjhd->bhij', q, k)
按比例缩放
164 attn = attn * self.scale
如果是因果关系,请戴口罩
167 attn = self.mask_attention(attn)
计算注意力概率
170 attn = self.softmax(attn)
获取值
173 h = torch.einsum("bhij,bjhd->bihd", attn, v)
从形状改[batch_size, seq_len, n_heads, d_k]
为[batch_size, seq_len, n_heads * d_k]
177 h = h.reshape(*h.shape[:-2], -1)
应用最后的线性图层。结果将有形状[batch_size, seq_len, d_model]
181 h = self.output(h)
添加剩余连接
184 return h + h_res
这与上面定义的自我注意层类似,不同之处在于它从与查询不同的嵌入集获取键和值。
这在编码器中用于根据输入区块对检索到的区块进行编码。
我们在此处不使用任何显式的位置嵌入。我们假设模型可以在嵌入中隐式表示位置信息。
187class CrossAttention(nn.Module):
d_model
是变压器嵌入中的特征数n_heads
是注意头的数量d_k
是每头特征的数量201 def __init__(self, d_model: int, n_heads: int, d_k: int):
207 super().__init__()
208
209 self.n_heads = n_heads
210 self.d_k = d_k
在 softmax 之前扩大注意力
213 self.scale = 1 / math.sqrt(self.d_k)
用于查询、键和值标头的线性图层。
216 self.query = nn.Linear(d_model, n_heads * d_k)
217 self.key = nn.Linear(d_model, n_heads * d_k)
218 self.value = nn.Linear(d_model, n_heads * d_k)
查询嵌入的预规范层。本文改为使用 rmsNorm。
221 self.norm = nn.LayerNorm(d_model)
Softmax 表示注意力概率
224 self.softmax = nn.Softmax(dim=-1)
最后的线性层
227 self.output = nn.Linear(n_heads * d_k, d_model)
e
是带有 shape 的检索的最近邻区块嵌入[batch_size, chunks, neighbors, neighbor_len, d_model]
h
是使用 shape 从中检索最近邻域的输入块[batch_size, chunks, chunk_len, d_model]
。这已经规范化了。229 def forward(self, e: torch.Tensor, h: torch.Tensor):
剩余连接
238 e_res = e
规范化检索到的区块
241 e = self.norm(e)
从检索到的区块中获取查询
244 q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
从输入块中获取键和值
246 k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
247 v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
计算所有区块的注意力分数。每个检索到的邻居都将注意检索到它的原始区块。这将有形状[batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
252 attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
缩放注意力分数
254 attn = attn * self.scale
计算最后一个维度的 softmax
257 attn = self.softmax(attn)
收集值
260 e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
从形状改[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k]
为[batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
264 e = e.reshape(*e.shape[:-2], -1)
应用最后的线性图层。结果将有形状[batch_size, chunks, neighbors, neighbor_len, d_model]
268 e = self.output(e)
添加剩余连接
271 return e + e_res
274class ChunkedCrossAttention(nn.Module):
d_model
是变压器嵌入中的特征数n_heads
是注意头的数量d_k
是每头特征的数量chunk_len
是区块的长度286 def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
294 super().__init__()
295
296 self.chunk_len = chunk_len
297 self.n_heads = n_heads
298 self.d_k = d_k
在 softmax 之前扩大注意力
301 self.scale = 1 / math.sqrt(self.d_k)
用于查询、键和值标头的线性图层。
304 self.query = nn.Linear(d_model, n_heads * d_k)
305 self.key = nn.Linear(d_model, n_heads * d_k)
306 self.value = nn.Linear(d_model, n_heads * d_k)
查询嵌入的预规范层。本文改为使用 rmsNorm。
309 self.norm = nn.LayerNorm(d_model)
Softmax 表示注意力概率
312 self.softmax = nn.Softmax(dim=-1)
最后的线性层
315 self.output = nn.Linear(n_heads * d_k, d_model)
h
shape 的输入嵌入[batch_size, seq_len, d_model]
e
是检索到的 shape 的最近邻值[batch_size, chunks, neighbors, neighbor_len, d_model]
317 def forward(self, h: torch.Tensor, e: torch.Tensor):
塑造身材
324 batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
如果没有区块,则不注意(采样时用于短输入)
327 if chunks == 0:
328 return h
剩余连接
331 h_res = h
移除第一个chunk_len - 1
嵌入。输入只关注使用过去的令牌检索和编码的邻居;这样就不会泄露信息。也就是说,从第一个区块中检索到的邻居将获得来自第一个区块的信息。因此,通过将序列向左移动,chunk_len - 1
我们可以确保信息只向右流动。
339 h = h[:, self.chunk_len - 1:]
规范前
341 h = self.norm(h)
在末尾追加空嵌入,以便能够将输入拆分为块
343 if h.shape[1] < chunks * self.chunk_len:
344 h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
将输入重塑为块。
346 h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
从输入中获取查询
349 q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
从检索到的邻居获取键和值
351 k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
352 v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
计算输入区块的注意力分数。每个区块都将关注前一个区块检索到的邻居。这将有形状[batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
357 attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
缩放注意力分数
359 attn = attn * self.scale
在最后两个维度上应用 softmaxneighbors, neighbor_len
362 attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
收集值
365 h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
从形状改[batch_size, chunks, chunk_len, n_heads, d_k]
为[batch_size, chunks * chunk_len, n_heads * d_k]
369 h = h.reshape(batch_size, chunks * self.chunk_len, -1)
应用最后的线性图层。结果将有形状[batch_size, chunks * chunk_len, d_model]
373 h = self.output(h)
向左追加chunk_len - 1
零嵌入;即右移回去
376 h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
截断并添加剩余连接
379 return h[:, :h_res.shape[1]] + h_res
382class FeedForward(nn.Module):
d_model
是变压器嵌入中的特征数d_ff
是隐藏图层中的数字要素389 def __init__(self, d_model: int, d_ff: int):
395 super().__init__()
两个线性层
398 self.lin1 = nn.Linear(d_model, d_ff)
399 self.lin2 = nn.Linear(d_ff, d_model)
ReLU 激活
402 self.act = nn.ReLU()
规范前层
405 self.norm = nn.LayerNorm(d_model)
h
是形状的嵌入[batch_size, seq_len, d_model]
407 def forward(self, h: torch.Tensor):
剩余
413 h_res = h
规范前
415 h = self.norm(h)
第一个线性层
417 h = self.lin1(h)
激活
419 h = self.act(h)
第二个线性层
421 h = self.lin2(h)
添加剩余连接
424 return h + h_res
427class NearestNeighborEncoder(nn.Module):
chunk_len
是区块的长度n_layer
是编码器中的层数ca_layers
是交叉关注的层次吗d_model
是嵌入中要素的数量n_heads
是注意层中的头部数量d_k
是注意头的大小d_ff
是前馈网络隐藏层的大小434 def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
435 d_model: int, n_heads: int, d_k: int, d_ff: int):
446 super().__init__()
447 self.ca_layers = ca_layers
448 self.chunk_len = chunk_len
交叉注意层
450 self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
双向自我关注层
452 self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
前馈图层
454 self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
预归一化层
457 self.norm_h = nn.LayerNorm(d_model)
e
是检索到的最近邻的令牌嵌入,形状为[batch_size, chunks, neighbors, neighbor_len, d_model]
h
is 是形状的输入令牌嵌入[batch_size, seq_len, d_model]
区块和邻居是并行处理的。
459 def forward(self, e: torch.Tensor, h: torch.Tensor):
塑造身材
472 batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
475 h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
规范前
478 h_split = self.norm_h(h_split)
保留交叉关注层的索引
481 p_ca = 0
适用于所有图层
483 for p in range(len(self.attn)):
双向自我关注
486 e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
交叉注意如果
489 if p in self.ca_layers:
491 e = self.ca[p_ca](e, h_split)
增加交叉注意力指数
493 p_ca += 1
前馈层
496 e = self.ffw[p](e)
返回
499 return e
502class RetroModel(nn.Module):
v_vocab
是词汇表中代币的数量d_model
是嵌入中要素的数量n_layers
是解码器中的层数ca_layers
是交叉关注的层次吗chunk_len
是区块的长度n_heads
是注意层中的头部数量d_k
是注意头的大小d_ff
是前馈网络隐藏层的大小encoder
是最近邻编码器509 def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
510 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
522 super().__init__()
523
524 self.ca_layers = ca_layers
525 self.encoder = encoder
令牌嵌入层
528 self.emb = nn.Embedding(n_vocab, d_model)
分块交叉注意力层
530 self.cca = nn.ModuleList(
531 [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
注意层
533 self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
前馈图层
535 self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
读出层
537 self.read = nn.Linear(d_model, n_vocab)
最近邻嵌入的预归一化层
541 self.norm_e = nn.LayerNorm(d_model)
543 def forward(self, x: torch.Tensor, ret: torch.Tensor):
获取输入嵌入
552 h = self.emb(x)
558 ret_emb = self.emb(ret)
保留分块交叉注意层的索引
561 p_ca = 0
适用于所有图层
563 for p in range(len(self.attn)):
因果自我关注
565 h = self.attn[p](h)
在第一层之前获取编码器嵌入
569 if self.ca_layers and p == min(self.ca_layers):
573 e = self.encoder(ret_emb, h)
规范化编码器嵌入
575 e = self.norm_e(e)
大块交叉注意如果
578 if p in self.ca_layers:
580 h = self.cca[p_ca](h, e)
递增分块交叉注意力指数
582 p_ca += 1
585 h = self.ffw[p](h)
588 return self.read(h)
591def _test():
595 chunk_len = 4
596 d_model = 8
597 d_ff = 32
598 n_heads = 2
599 d_k = 4
600
601 device = torch.device('cuda:0')
602
603 m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
604 encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
605
606 m.to(device)
607 x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
608 ret = [
609 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
610 [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
611 ]
612 res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
613
614 inspect(res)
618if __name__ == '__main__':
619 _test()