19from typing import Optional
20
21import torch
22from torch import nn25class GMLPBlock(nn.Module):d_model
是的维度 ()d_ffn
是的维度seq_len
是令牌序列的长度 ()46 def __init__(self, d_model: int, d_ffn: int, seq_len: int):52 super().__init__()Pre-Norm 的标准化层
54 self.norm = nn.LayerNorm([d_model])激活功能
56 self.activation = nn.GELU()投影层
58 self.proj1 = nn.Linear(d_model, d_ffn)空间门控单元
60 self.sgu = SpacialGatingUnit(d_ffn, seq_len)投影层
62 self.proj2 = nn.Linear(d_ffn // 2, d_model)x
是形状的输入嵌入张量[seq_len, batch_size, d_model]
mask
是形状的布尔掩码[seq_len, seq_len, 1]
,用于控制标记在彼此之间的可见性。68 def forward(self, *, x: torch.Tensor, mask: Optional[torch.Tensor] = None):保留一份用于快捷方式连接的副本
75 shortcut = x规范化
77 x = self.norm(x)投射和激活
79 z = self.activation(self.proj1(x))空间门控单元
81 z = self.sgu(z, mask)最终投影
83 z = self.proj2(z)添加快捷方式连接
86 return z + shortcut89class SpacialGatingUnit(nn.Module):d_z
是的维度seq_len
是序列长度99 def __init__(self, d_z: int, seq_len: int):104 super().__init__()应用之前的标准化层
106 self.norm = nn.LayerNorm([d_z // 2])111 self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)115 self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)z
是形状的输入[seq_len, batch_size, d_z]
mask
is 是形状的布尔掩码[seq_len, seq_len, 1]
,用于控制标记在彼此之间的可见性。尺寸的最后一个维度1
是批次,这是我们在其他变压器实现中使用的,为了兼容性而留下。117 def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):获取序列长度
126 seq_len = z.shape[0]拆分为和
128 z1, z2 = torch.chunk(z, 2, dim=-1)检查口罩
131 if mask is not None:mask
有形状[seq_len_q, seq_len_k, batch_size]
。批次维度应为 size,1
因为此实现仅支持批次中所有样本的相同掩码。
135 assert mask.shape[0] == 1 or mask.shape[0] == seq_len
136 assert mask.shape[1] == seq_len这里我们只支持所有样本使用相同的掩码
138 assert mask.shape[2] == 1移除批量维度
140 mask = mask[:, :, 0]之前进行标准化
143 z2 = self.norm(z2)获取权重矩阵;如果大于seq_len
145 weight = self.weight[:seq_len, :seq_len]150 if mask is not None:
151 weight = weight * mask154 z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]157 return z1 * z2