这是纸质《开关变形金刚:以简单高效的稀疏度扩展到万亿个参数模型》的微型 PyTorch 实现。我们的实现只有几百万个参数,不对并行分布式训练进行建模。它进行单个 GPU 训练,但我们实现了论文中描述的切换概念。
Switch Transformer 通过根据令牌在参数之间切换,为每个令牌使用不同的参数。因此,只为每个代币选择了一小部分参数。因此,您可以拥有更多参数,但计算成本更低。
切换发生在每个变压器模块的位置前馈网络 (FFN) 上。位置前馈网络由两个按顺序完全连接的层组成。在交换机变压器中,我们有多个 FFN(多位专家),我们根据路由器选择使用哪一个。输出是一组用于选择 FFN 的概率,我们选择概率最高的概率,然后仅对其进行评估。因此,从本质上讲,计算成本与拥有单个 FFN 相同。在我们的实现中,当你有许多或大型 FFN 时,这种并行化效果不佳,因为这一切都发生在单个 GPU 上。在分布式设置中,你会将每个 FFN(每个都很大)放在不同的设备上。
本文引入了另一个损失术语来平衡专家(FFN)之间的负载,并讨论了路由不平衡时丢弃代币的问题。
这是训练代码和一本用于在 Tiny Shakespeare 数据集上训练开关变压器的笔记本。
39import torch
40from torch import nn
41
42from labml_helpers.module import Module
43from labml_nn.transformers.feed_forward import FeedForward
44from labml_nn.transformers.mha import MultiHeadAttention
45from labml_nn.utils import clone_module_list
48class SwitchFeedForward(Module):
capacity_factor
是每个 EA 的容量作为相对于理想平衡负载的一个因素drop_tokens
指定如果路由到专家的令牌多于容量,是否丢弃令牌is_scale_prob
指定是否将 FFN 的输入乘以路由概率n_experts
是专家的数量expert
是专家层,一个 FFN 模块d_model
是令牌嵌入中的要素数量d_ff
是 FFN 隐藏层中的要素数量dropout
FFN 中的辍学概率是多少53 def __init__(self, *,
54 capacity_factor: float,
55 drop_tokens: bool,
56 is_scale_prob: bool,
57 n_experts: int,
58 expert: FeedForward,
59 d_model: int):
70 super().__init__()
71
72 self.capacity_factor = capacity_factor
73 self.is_scale_prob = is_scale_prob
74 self.n_experts = n_experts
75 self.drop_tokens = drop_tokens
复制 FFN
78 self.experts = clone_module_list(expert, n_experts)
路由层和 softmax
80 self.switch = nn.Linear(d_model, n_experts)
81 self.softmax = nn.Softmax(dim=-1)
x
是带形状的开关模块的输入[seq_len, batch_size, d_model]
83 def forward(self, x: torch.Tensor):
捕获形状以便稍后更改形状
89 seq_len, batch_size, d_model = x.shape
展平序列和批次维度
91 x = x.view(-1, d_model)
获取每个令牌的路由概率。其中是专家的数量n_experts
,是令牌嵌入的线性变换。
97 route_prob = self.softmax(self.switch(x))
获取最大路由概率和路线。我们以最高的概率路由到智能交易
系统101 route_prob_max, routes = torch.max(route_prob, dim=-1)
获取每位专家的代币索引
104 indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
初始化空张量以存储输出
107 final_output = x.new_zeros(x.shape)
每位专家的能力。
113 capacity = int(self.capacity_factor * len(x) / self.n_experts)
发送给每位专家的代币数量。
115 counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])
初始化已丢弃令牌的空列表
118 dropped = []
如果drop_tokens
是,则仅丢弃令牌True
。
120 if self.drop_tokens:
在每位专家身上丢掉代币
122 for i in range(self.n_experts):
如果智能交易没有超出容量,请忽略
124 if len(indexes_list[i]) <= capacity:
125 continue
在丢弃之前随机播放索引
127 indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
收集超过容量的代币作为丢弃的令牌
129 dropped.append(indexes_list[i][capacity:])
只保留与专家容量相等的代币
131 indexes_list[i] = indexes_list[i][:capacity]
获取专家 FFN 的输出
134 expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]
分配给最终输出
137 for i in range(self.n_experts):
138 final_output[indexes_list[i], :] = expert_output[i]
通过掉落的代币
141 if dropped:
142 dropped = torch.cat(dropped)
143 final_output[dropped, :] = x[dropped, :]
144
145 if self.is_scale_prob:
将智能交易的输出乘以概率
147 final_output = final_output * route_prob_max.view(-1, 1)
148 else:
不要缩放值,而是乘以渐变流动(这是我们尝试过的)。
151 final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)
将最终输出的形状改回[seq_len, batch_size, d_model]
154 final_output = final_output.view(seq_len, batch_size, d_model)
165 return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
d_model
是令牌嵌入的大小attn
是注意力模块feed_forward
是前馈模块(在本例中为交换模块)dropout_prob
是自我关注和 FFN 后退学的概率176 def __init__(self, *,
177 d_model: int,
178 attn: MultiHeadAttention,
179 feed_forward: SwitchFeedForward,
180 dropout_prob: float):
187 super().__init__()
188 self.size = d_model
189 self.attn = attn
190 self.feed_forward = feed_forward
191 self.dropout = nn.Dropout(dropout_prob)
192 self.norm_self_attn = nn.LayerNorm([d_model])
193 self.norm_ff = nn.LayerNorm([d_model])
195 def forward(self, *,
196 x: torch.Tensor,
197 mask: torch.Tensor):
在进行自我注意之前对向量进行归一化
199 z = self.norm_self_attn(x)
通过自我关注,即关键和价值来自自我
201 self_attn = self.attn(query=z, key=z, value=z, mask=mask)
添加自我关注的结果
203 x = x + self.dropout(self_attn)
标准化以进行前馈
206 z = self.norm_ff(x)
通过交换前馈网络
208 ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
将前馈结果添加回来
210 x = x + self.dropout(ff)
211
212 return x, counts, route_prob, n_dropped, route_prob_max
215class SwitchTransformer(Module):
220 def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
221 super().__init__()
制作变压器层的副本
223 self.layers = clone_module_list(layer, n_layers)
最终归一化层
225 self.norm = nn.LayerNorm([layer.size])
227 def forward(self, x: torch.Tensor, mask: torch.Tensor):
穿过每个变压器层
229 counts, route_prob, n_dropped, route_prob_max = [], [], [], []
230 for layer in self.layers:
231 x, f, p, n_d, p_max = layer(x=x, mask=mask)
232 counts.append(f)
233 route_prob.append(p)
234 n_dropped.append(n_d)
235 route_prob_max.append(p_max)
最后,对向量进行归一化
237 x = self.norm(x)
239 return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)