复古模型

这是 RETRO 的模型定义

14import math
15from typing import Set
16
17import torch
18from torch import nn
19
20from labml.logger import inspect

绳索嵌入

我们在自我注意力层中使用旋转位置嵌入。我们假设位置信息被嵌入到嵌入中,因此不会在因果关注中使用它们。非因果的自我注意力需要明确的位置信息,因为它无法推断出来

23class RotaryPositionalEmbeddings(nn.Module):
  • d 是要素的数量
  • base 是用于计算的常数
  • 34    def __init__(self, d: int, base: int = 10_000):
    39        super().__init__()

    41        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
    • x 是位于键或带有形状的查询开头的 Tensor[ batch_size, seq_len, n_heads, d]
    43    def forward(self, x: torch.Tensor):

    提取形状

    48        batch_size, seq_len, n_heads, d = x.shape

    51        d_2 = d // 2

    创建头寸指数[0, 1, ..., seq_len - 1]

    54        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

    计算持仓指数的乘积和

    57        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

    连接这样我们就有 row

    61        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

    计算

    65        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    计算

    对于

    77        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

    80        return rx

    自我注意层

    这适用于因果和非因果的多头自我关注

    83class SelfAttention(nn.Module):
    • d_model 是变压器嵌入中的特征数
    • n_heads 是注意头的数量
    • d_k 是每头特征的数量
    • is_causal 表示这是否是因果关注(屏蔽)
    90    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
    97        super().__init__()
    98
    99        self.is_causal = is_causal
    100        self.n_heads = n_heads
    101        self.d_k = d_k

    在 softmax 之前扩大注意力

    104        self.scale = 1 / math.sqrt(self.d_k)

    用于查询、键和值标头的线性图层。

    107        self.query = nn.Linear(d_model, n_heads * d_k)
    108        self.key = nn.Linear(d_model, n_heads * d_k)
    109        self.value = nn.Linear(d_model, n_heads * d_k)

    预先规范层。本文改为使用 rmsNorm。

    112        self.norm = nn.LayerNorm(d_model)

    Softmax 表示注意力概率

    115        self.softmax = nn.Softmax(dim=-1)

    旋转位置嵌入

    118        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

    最后的线性层

    121        self.output = nn.Linear(n_heads * d_k, d_model)

    遮住注意层以获得因果关注

    • attn 是形状的注意力矩阵[batch_size, n_heads, seq_len, seq_len]
    123    def mask_attention(self, attn: torch.Tensor):

    非因果注意没有遮罩

    131        if not self.is_causal:
    132            return attn

    创建三角形蒙版

    135        mask = torch.tril(attn.new_ones(attn.shape[-2:]))

    按口罩过滤

    137        return attn.masked_fill(mask == 0, float('-inf'))
    • h 变压器嵌入的形状是多少[batch_size, seq_len, d_model]
    139    def forward(self, h: torch.Tensor):

    剩余连接

    145        h_res = h

    规范化前

    148        h = self.norm(h)

    获取查询、键和值,并将它们分成头部。这些会有形状[batch_size, seq_len, n_heads, d_k]

    152        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
    153        q = self.query(h).view(mh_shape)
    154        k = self.key(h).view(mh_shape)
    155        v = self.value(h).view(mh_shape)

    应用旋转位置嵌入

    158        q = self.rotary_pe(q)
    159        k = self.rotary_pe(k)

    计算注意力

    162        attn = torch.einsum('bihd,bjhd->bhij', q, k)

    按比例缩放

    164        attn = attn * self.scale

    如果是因果关系,请戴口罩

    167        attn = self.mask_attention(attn)

    计算注意力概率

    170        attn = self.softmax(attn)

    获取值

    173        h = torch.einsum("bhij,bjhd->bihd", attn, v)

    从形状改[batch_size, seq_len, n_heads, d_k][batch_size, seq_len, n_heads * d_k]

    177        h = h.reshape(*h.shape[:-2], -1)

    应用最后的线性图层。结果将有形状[batch_size, seq_len, d_model]

    181        h = self.output(h)

    添加剩余连接

    184        return h + h_res

    交叉注意力层

    这与上面定义的自我注意层类似,不同之处在于它从与查询不同的嵌入集获取键和值。

    这在编码器中用于根据输入区块对检索到的区块进行编码。

    我们在此处不使用任何显式的位置嵌入。我们假设模型可以在嵌入中隐式表示位置信息。

    187class CrossAttention(nn.Module):
    • d_model 是变压器嵌入中的特征数
    • n_heads 是注意头的数量
    • d_k 是每头特征的数量
    201    def __init__(self, d_model: int, n_heads: int, d_k: int):
    207        super().__init__()
    208
    209        self.n_heads = n_heads
    210        self.d_k = d_k

    在 softmax 之前扩大注意力

    213        self.scale = 1 / math.sqrt(self.d_k)

    用于查询、键和值标头的线性图层。

    216        self.query = nn.Linear(d_model, n_heads * d_k)
    217        self.key = nn.Linear(d_model, n_heads * d_k)
    218        self.value = nn.Linear(d_model, n_heads * d_k)

    查询嵌入的预规范层。本文改为使用 rmsNorm。

    221        self.norm = nn.LayerNorm(d_model)

    Softmax 表示注意力概率

    224        self.softmax = nn.Softmax(dim=-1)

    最后的线性层

    227        self.output = nn.Linear(n_heads * d_k, d_model)
    • e 是带有 shape 的检索的最近邻区块嵌入[batch_size, chunks, neighbors, neighbor_len, d_model]
    • h 是使用 shape 从中检索最近邻域的输入块[batch_size, chunks, chunk_len, d_model] 。这已经规范化了。
    229    def forward(self, e: torch.Tensor, h: torch.Tensor):

    剩余连接

    238        e_res = e

    规范化检索到的区块

    241        e = self.norm(e)

    从检索到的区块中获取查询

    244        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)

    从输入块中获取键和值

    246        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
    247        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

    计算所有区块的注意力分数。每个检索到的邻居都将注意检索到它的原始区块。这将有形状[batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

    252        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)

    缩放注意力分数

    254        attn = attn * self.scale

    计算最后一个维度的 softmax

    257        attn = self.softmax(attn)

    收集值

    260        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)

    从形状改[batch_size, chunks, neighbors, neighbor_len, n_heads, d_k][batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

    264        e = e.reshape(*e.shape[:-2], -1)

    应用最后的线性图层。结果将有形状[batch_size, chunks, neighbors, neighbor_len, d_model]

    268        e = self.output(e)

    添加剩余连接

    271        return e + e_res

    分块交叉注意力层

    这与上面定义的交叉注意力层类似。

    这在解码器中用于关注检索到的邻居块。

    我们在此处不使用任何显式的位置嵌入。我们假设模型可以在嵌入中隐式表示位置信息。

    274class ChunkedCrossAttention(nn.Module):
    • d_model 是变压器嵌入中的特征数
    • n_heads 是注意头的数量
    • d_k 是每头特征的数量
    • chunk_len 是区块的长度
    286    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
    294        super().__init__()
    295
    296        self.chunk_len = chunk_len
    297        self.n_heads = n_heads
    298        self.d_k = d_k

    在 softmax 之前扩大注意力

    301        self.scale = 1 / math.sqrt(self.d_k)

    用于查询、键和值标头的线性图层。

    304        self.query = nn.Linear(d_model, n_heads * d_k)
    305        self.key = nn.Linear(d_model, n_heads * d_k)
    306        self.value = nn.Linear(d_model, n_heads * d_k)

    查询嵌入的预规范层。本文改为使用 rmsNorm。

    309        self.norm = nn.LayerNorm(d_model)

    Softmax 表示注意力概率

    312        self.softmax = nn.Softmax(dim=-1)

    最后的线性层

    315        self.output = nn.Linear(n_heads * d_k, d_model)

    h shape 的输入嵌入[batch_size, seq_len, d_model] e 是检索到的 shape 的最近邻值[batch_size, chunks, neighbors, neighbor_len, d_model]

    317    def forward(self, h: torch.Tensor, e: torch.Tensor):

    塑造身材

    324        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

    如果没有区块,则不注意(采样时用于短输入)

    327        if chunks == 0:
    328            return h

    剩余连接

    331        h_res = h

    移除第一个chunk_len - 1 嵌入。输入只关注使用过去的令牌检索和编码的邻居;这样就不会泄露信息。也就是说,从第一个区块中检索到的邻居将获得来自第一个区块的信息。因此,通过将序列向左移动,chunk_len - 1 我们可以确保信息只向右流动。

    339        h = h[:, self.chunk_len - 1:]

    规范前

    341        h = self.norm(h)

    在末尾追加空嵌入,以便能够将输入拆分为块

    343        if h.shape[1] < chunks * self.chunk_len:
    344            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)

    将输入重塑为块。

    346        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

    从输入中获取查询

    349        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

    从检索到的邻居获取键和值

    351        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
    352        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

    计算输入区块的注意力分数。每个区块都将关注前一个区块检索到的邻居。这将有形状[batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]

    357        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

    缩放注意力分数

    359        attn = attn * self.scale

    在最后两个维度上应用 softmaxneighbors, neighbor_len

    362        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

    收集值

    365        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

    从形状改[batch_size, chunks, chunk_len, n_heads, d_k][batch_size, chunks * chunk_len, n_heads * d_k]

    369        h = h.reshape(batch_size, chunks * self.chunk_len, -1)

    应用最后的线性图层。结果将有形状[batch_size, chunks * chunk_len, d_model]

    373        h = self.output(h)

    向左追加chunk_len - 1 零嵌入;即右移回去

    376        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

    截断并添加剩余连接

    379        return h[:, :h_res.shape[1]] + h_res

    位置前馈层

    它由两个线性层和中间的激活层组成。

    382class FeedForward(nn.Module):
    • d_model 是变压器嵌入中的特征数
    • d_ff 是隐藏图层中的数字要素
    389    def __init__(self, d_model: int, d_ff: int):
    395        super().__init__()

    两个线性层

    398        self.lin1 = nn.Linear(d_model, d_ff)
    399        self.lin2 = nn.Linear(d_ff, d_model)

    ReLU 激活

    402        self.act = nn.ReLU()

    规范前层

    405        self.norm = nn.LayerNorm(d_model)

    h 是形状的嵌入[batch_size, seq_len, d_model]

    407    def forward(self, h: torch.Tensor):

    剩余

    413        h_res = h

    规范前

    415        h = self.norm(h)

    第一个线性层

    417        h = self.lin1(h)

    激活

    419        h = self.act(h)

    第二个线性层

    421        h = self.lin2(h)

    添加剩余连接

    424        return h + h_res

    最近邻编码器

    此模块对检索到的最近邻进行编码

    427class NearestNeighborEncoder(nn.Module):
    • chunk_len 是区块的长度
    • n_layer 是编码器中的层数
    • ca_layers 是交叉关注的层次吗
    • d_model 是嵌入中要素的数量
    • n_heads 是注意层中的头部数量
    • d_k 是注意头的大小
    • d_ff 是前馈网络隐藏层的大小
    434    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
    435                 d_model: int, n_heads: int, d_k: int, d_ff: int):
    446        super().__init__()
    447        self.ca_layers = ca_layers
    448        self.chunk_len = chunk_len

    交叉注意层

    450        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

    双向自我关注层

    452        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

    前馈图层

    454        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

    预归一化层

    457        self.norm_h = nn.LayerNorm(d_model)
    • e 是检索到的最近邻的令牌嵌入,形状为[batch_size, chunks, neighbors, neighbor_len, d_model]
    • h is 是形状的输入令牌嵌入[batch_size, seq_len, d_model]

    区块和邻居是并行处理的。

    459    def forward(self, e: torch.Tensor, h: torch.Tensor):

    塑造身材

    472        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

    475        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

    规范前

    478        h_split = self.norm_h(h_split)

    保留交叉关注层的索引

    481        p_ca = 0

    适用于所有图层

    483        for p in range(len(self.attn)):

    双向自我关注

    486            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

    交叉注意如果

    489            if p in self.ca_layers:

    491                e = self.ca[p_ca](e, h_split)

    增加交叉注意力指数

    493                p_ca += 1

    前馈层

    496            e = self.ffw[p](e)

    返回

    499        return e

    复古模特

    这是复古解码器

    502class RetroModel(nn.Module):
    • v_vocab 是词汇表中代币的数量
    • d_model 是嵌入中要素的数量
    • n_layers 是解码器中的层数
    • ca_layers 是交叉关注的层次吗
    • chunk_len 是区块的长度
    • n_heads 是注意层中的头部数量
    • d_k 是注意头的大小
    • d_ff 是前馈网络隐藏层的大小
    • encoder 是最近邻编码器
    509    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
    510                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
    522        super().__init__()
    523
    524        self.ca_layers = ca_layers
    525        self.encoder = encoder

    令牌嵌入层

    528        self.emb = nn.Embedding(n_vocab, d_model)

    分块交叉注意力层

    530        self.cca = nn.ModuleList(
    531            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

    注意层

    533        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

    前馈图层

    535        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

    读出层

    537        self.read = nn.Linear(d_model, n_vocab)

    最近邻嵌入的预归一化层

    541        self.norm_e = nn.LayerNorm(d_model)
    • x 是形状的输入序列[batch_size, seq_len]
  • ret 是检索到的形状邻域[batch_size, chunks, neighbors, neighbor_len]
  • 543    def forward(self, x: torch.Tensor, ret: torch.Tensor):

    获取输入嵌入

    552        h = self.emb(x)

    检索到的邻居的嵌入

    我们对输入和邻居使用相同的嵌入

    558        ret_emb = self.emb(ret)

    保留分块交叉注意层的索引

    561        p_ca = 0

    适用于所有图层

    563        for p in range(len(self.attn)):

    因果自我关注

    565            h = self.attn[p](h)

    在第一层之前获取编码器嵌入

    569            if self.ca_layers and p == min(self.ca_layers):

    我们将的嵌入传递给编码器。

    573                e = self.encoder(ret_emb, h)

    规范化编码器嵌入

    575                e = self.norm_e(e)

    大块交叉注意如果

    578            if p in self.ca_layers:

    580                h = self.cca[p_ca](h, e)

    递增分块交叉注意力指数

    582                p_ca += 1

    585            h = self.ffw[p](h)

    588        return self.read(h)

    使用虚假数据测试模型

    591def _test():
    595    chunk_len = 4
    596    d_model = 8
    597    d_ff = 32
    598    n_heads = 2
    599    d_k = 4
    600
    601    device = torch.device('cuda:0')
    602
    603    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
    604                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))
    605
    606    m.to(device)
    607    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    608    ret = [
    609        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    610        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    611    ]
    612    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))
    613
    614    inspect(res)

    618if __name__ == '__main__':
    619    _test()