注意 MLP (GmLP)

这是 P yTorch 对《注意 MLP》一文的实现。

本文介绍了一种基于多层感知器(MLP)的带有门控的架构,他们将其命名为 gmLP。它由一堆 gmLP 块组成。

这是基于 GmLP 模型的自回归模型的训练代码

19from typing import Optional
20
21import torch
22from torch import nn

gmLP Block

每个模块对输入嵌入进行以下转换,其中是序列长度,是嵌入的维度:

其中是可学习的投影权重。是下面定义的空间门控单元。的输出维度将为的一半是一个激活函数,比如 GeLU

25class GMLPBlock(nn.Module):
  • d_model 是的维度 ()
  • d_ffn 是的维度
  • seq_len 是令牌序列的长度 ()
46    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
52        super().__init__()

Pre-Norm 的标准化层

54        self.norm = nn.LayerNorm([d_model])

激活功能

56        self.activation = nn.GELU()

投影层

58        self.proj1 = nn.Linear(d_model, d_ffn)

空间门控单元

60        self.sgu = SpacialGatingUnit(d_ffn, seq_len)

投影层

62        self.proj2 = nn.Linear(d_ffn // 2, d_model)

嵌入大小(编码器需要。我们使用变压器架构中的编码器模块,并插入 GmLP 模块作为变压器层的替代品。

66        self.size = d_model
  • x 是形状的输入嵌入张量[seq_len, batch_size, d_model]
  • mask 是形状的布尔掩码[seq_len, seq_len, 1] ,用于控制标记在彼此之间的可见性。
68    def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):

保留一份用于快捷方式连接的副本

75        shortcut = x

规范化

77        x = self.norm(x)

投射和激活

79        z = self.activation(self.proj1(x))

空间门控单元

81        z = self.sgu(z, mask)

最终投影

83        z = self.proj2(z)

添加快捷方式连接

86        return z + shortcut

空间门控单元

其中,是沿序列维度的线性变换,是逐元素乘法。被分成两个大小相等的部分,沿着通道尺寸(嵌入维度)。

89class SpacialGatingUnit(nn.Module):
  • d_z 是的维度
  • seq_len 是序列长度
99    def __init__(self, d_z: int, seq_len: int):
104        super().__init__()

应用之前的标准化层

106        self.norm = nn.LayerNorm([d_z // 2])

重量

本文指出,重要的是将权重初始化为较小的值,并将偏差初始化,这样在初始训练期间就接近身份(拆分除外)。

111        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)

重量

本文指出,将偏见初始化为.

115        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
  • z 是形状的输入[seq_len, batch_size, d_z]
  • mask is 是形状的布尔掩码[seq_len, seq_len, 1] ,用于控制标记在彼此之间的可见性。尺寸的最后一个维度1 是批次,这是我们在其他变压器实现中使用的,为了兼容性而留下。
117    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):

获取序列长度

126        seq_len = z.shape[0]

分为

128        z1, z2 = torch.chunk(z, 2, dim=-1)

检查口罩

131        if mask is not None:

mask 有形状[seq_len_q, seq_len_k, batch_size] 。批次维度应为 size,1 因为此实现仅支持批次中所有样本的相同掩码。

135            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
136            assert mask.shape[1] == seq_len

这里我们只支持所有样本使用相同的掩码

138            assert mask.shape[2] == 1

移除批量维度

140            mask = mask[:, :, 0]

之前进行标准化

143        z2 = self.norm(z2)

获取权重矩阵;如果大于seq_len

145        weight = self.weight[:seq_len, :seq_len]

将遮罩应用于权重。

如果,则不会从令牌中获取任何信息

150        if mask is not None:
151            weight = weight * mask

154        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]

157        return z1 * z2