19from typing import Optional
20
21import torch
22from torch import nn
25class GMLPBlock(nn.Module):
d_model
是的维度 ()d_ffn
是的维度seq_len
是令牌序列的长度 ()46 def __init__(self, d_model: int, d_ffn: int, seq_len: int):
52 super().__init__()
Pre-Norm 的标准化层
54 self.norm = nn.LayerNorm([d_model])
激活功能
56 self.activation = nn.GELU()
投影层
58 self.proj1 = nn.Linear(d_model, d_ffn)
空间门控单元
60 self.sgu = SpacialGatingUnit(d_ffn, seq_len)
投影层
62 self.proj2 = nn.Linear(d_ffn // 2, d_model)
x
是形状的输入嵌入张量[seq_len, batch_size, d_model]
mask
是形状的布尔掩码[seq_len, seq_len, 1]
,用于控制标记在彼此之间的可见性。68 def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
保留一份用于快捷方式连接的副本
75 shortcut = x
规范化
77 x = self.norm(x)
投射和激活
79 z = self.activation(self.proj1(x))
空间门控单元
81 z = self.sgu(z, mask)
最终投影
83 z = self.proj2(z)
添加快捷方式连接
86 return z + shortcut
89class SpacialGatingUnit(nn.Module):
d_z
是的维度seq_len
是序列长度99 def __init__(self, d_z: int, seq_len: int):
104 super().__init__()
应用之前的标准化层
106 self.norm = nn.LayerNorm([d_z // 2])
111 self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
115 self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)
z
是形状的输入[seq_len, batch_size, d_z]
mask
is 是形状的布尔掩码[seq_len, seq_len, 1]
,用于控制标记在彼此之间的可见性。尺寸的最后一个维度1
是批次,这是我们在其他变压器实现中使用的,为了兼容性而留下。117 def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
获取序列长度
126 seq_len = z.shape[0]
拆分为和
128 z1, z2 = torch.chunk(z, 2, dim=-1)
检查口罩
131 if mask is not None:
mask
有形状[seq_len_q, seq_len_k, batch_size]
。批次维度应为 size,1
因为此实现仅支持批次中所有样本的相同掩码。
135 assert mask.shape[0] == 1 or mask.shape[0] == seq_len
136 assert mask.shape[1] == seq_len
这里我们只支持所有样本使用相同的掩码
138 assert mask.shape[2] == 1
移除批量维度
140 mask = mask[:, :, 0]
之前进行标准化
143 z2 = self.norm(z2)
获取权重矩阵;如果大于seq_len
145 weight = self.weight[:seq_len, :seq_len]
150 if mask is not None:
151 weight = weight * mask
154 z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
157 return z1 * z2