开关变压器

这是纸质《开关变形金刚:以简单高效的稀疏度扩展到万亿个参数模型》的微型 PyTorch 实现。我们的实现只有几百万个参数,不对并行分布式训练进行建模。它进行单个 GPU 训练,但我们实现了论文中描述的切换概念。

Switch Transformer 通过根据令牌在参数之间切换,为每个令牌使用不同的参数。因此,只为每个代币选择了一小部分参数。因此,您可以拥有更多参数,但计算成本更低。

切换发生在每个变压器模块的位置前馈网络 (FFN) 上。位置前馈网络由两个按顺序完全连接的层组成。在交换机变压器中,我们有多个 FFN(多位专家),我们根据路由器选择使用哪一个。输出是一组用于选择 FFN 的概率,我们选择概率最高的概率,然后仅对其进行评估。因此,从本质上讲,计算成本与拥有单个 FFN 相同。在我们的实现中,当你有许多或大型 FFN 时,这种并行化效果不佳,因为这一切都发生在单个 GPU 上。在分布式设置中,你会将每个 FFN(每个都很大)放在不同的设备上。

本文引入了另一个损失术语来平衡专家(FFN)之间的负载,并讨论了路由不平衡时丢弃代币的问题。

这是训练代码和一本用于在 Tiny Shakespeare 数据集上训练开关变压器的笔记本。

Open In Colab

39import torch
40from torch import nn
41
42from labml_helpers.module import Module
43from labml_nn.transformers.feed_forward import FeedForward
44from labml_nn.transformers.mha import MultiHeadAttention
45from labml_nn.utils import clone_module_list

在多个 FFN 之间路由

48class SwitchFeedForward(Module):
  • capacity_factor 是每个 EA 的容量作为相对于理想平衡负载的一个因素
  • drop_tokens 指定如果路由到专家的令牌多于容量,是否丢弃令牌
  • is_scale_prob 指定是否将 FFN 的输入乘以路由概率
  • n_experts 是专家的数量
  • expert 是专家层,一个 FFN 模块
  • d_model 是令牌嵌入中的要素数量
  • d_ff 是 FFN 隐藏层中的要素数量
  • dropout FFN 中的辍学概率是多少
  • 53    def __init__(self, *,
    54                 capacity_factor: float,
    55                 drop_tokens: bool,
    56                 is_scale_prob: bool,
    57                 n_experts: int,
    58                 expert: FeedForward,
    59                 d_model: int):
    70        super().__init__()
    71
    72        self.capacity_factor = capacity_factor
    73        self.is_scale_prob = is_scale_prob
    74        self.n_experts = n_experts
    75        self.drop_tokens = drop_tokens

    复制 FFN

    78        self.experts = clone_module_list(expert, n_experts)

    路由层和 softmax

    80        self.switch = nn.Linear(d_model, n_experts)
    81        self.softmax = nn.Softmax(dim=-1)
    • x 是带形状的开关模块的输入[seq_len, batch_size, d_model]
    83    def forward(self, x: torch.Tensor):

    捕获形状以便稍后更改形状

    89        seq_len, batch_size, d_model = x.shape

    展平序列和批次维度

    91        x = x.view(-1, d_model)

    获取每个令牌的路由概率。其中是专家的数量n_experts是令牌嵌入的线性变换。

    97        route_prob = self.softmax(self.switch(x))

    获取最大路由概率和路线。我们以最高的概率路由到智能交易

    系统
    101        route_prob_max, routes = torch.max(route_prob, dim=-1)

    获取每位专家的代币索引

    104        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

    初始化空张量以存储输出

    107        final_output = x.new_zeros(x.shape)

    每位专家的能力。

    113        capacity = int(self.capacity_factor * len(x) / self.n_experts)

    发送给每位专家的代币数量。

    115        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

    初始化已丢弃令牌的空列表

    118        dropped = []

    如果drop_tokens 是,则仅丢弃令牌True

    120        if self.drop_tokens:

    在每位专家身上丢掉代币

    122            for i in range(self.n_experts):

    如果智能交易没有超出容量,请忽略

    124                if len(indexes_list[i]) <= capacity:
    125                    continue

    在丢弃之前随机播放索引

    127                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]

    收集超过容量的代币作为丢弃的令牌

    129                dropped.append(indexes_list[i][capacity:])

    只保留与专家容量相等的代币

    131                indexes_list[i] = indexes_list[i][:capacity]

    获取专家 FFN 的输出

    134        expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]

    分配给最终输出

    137        for i in range(self.n_experts):
    138            final_output[indexes_list[i], :] = expert_output[i]

    通过掉落的代币

    141        if dropped:
    142            dropped = torch.cat(dropped)
    143            final_output[dropped, :] = x[dropped, :]
    144
    145        if self.is_scale_prob:

    将智能交易的输出乘以概率

    147            final_output = final_output * route_prob_max.view(-1, 1)
    148        else:

    不要缩放值,而是乘以渐变流动(这是我们尝试过的)。

    151            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)

    将最终输出的形状改回[seq_len, batch_size, d_model]

    154        final_output = final_output.view(seq_len, batch_size, d_model)

    返回

    • 最终输出
    • 发送给每位专家的代币数量
    • 每个 EA 的概率总和
    • 丢弃的代币数量。
    • 所选 EA 的路由概率

    这些用于负载平衡丢失和日志记录

    165        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max

    开关变压器块

    这与普通变压器模块相同,用于处理开关前馈模块的额外输出。

    168class SwitchTransformerLayer(Module):
    • d_model 是令牌嵌入的大小
    • attn 是注意力模块
    • feed_forward 是前馈模块(在本例中为交换模块)
    • dropout_prob 是自我关注和 FFN 后退学的概率
    176    def __init__(self, *,
    177                 d_model: int,
    178                 attn: MultiHeadAttention,
    179                 feed_forward: SwitchFeedForward,
    180                 dropout_prob: float):
    187        super().__init__()
    188        self.size = d_model
    189        self.attn = attn
    190        self.feed_forward = feed_forward
    191        self.dropout = nn.Dropout(dropout_prob)
    192        self.norm_self_attn = nn.LayerNorm([d_model])
    193        self.norm_ff = nn.LayerNorm([d_model])
    195    def forward(self, *,
    196                x: torch.Tensor,
    197                mask: torch.Tensor):

    在进行自我注意之前对向量进行归一化

    199        z = self.norm_self_attn(x)

    通过自我关注,即关键和价值来自自我

    201        self_attn = self.attn(query=z, key=z, value=z, mask=mask)

    添加自我关注的结果

    203        x = x + self.dropout(self_attn)

    标准化以进行前馈

    206        z = self.norm_ff(x)

    通过交换前馈网络

    208        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)

    将前馈结果添加回来

    210        x = x + self.dropout(ff)
    211
    212        return x, counts, route_prob, n_dropped, route_prob_max

    开关变压器

    215class SwitchTransformer(Module):
    220    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
    221        super().__init__()

    制作变压器层的副本

    223        self.layers = clone_module_list(layer, n_layers)

    最终归一化层

    225        self.norm = nn.LayerNorm([layer.size])
    227    def forward(self, x: torch.Tensor, mask: torch.Tensor):

    穿过每个变压器层

    229        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
    230        for layer in self.layers:
    231            x, f, p, n_d, p_max = layer(x=x, mask=mask)
    232            counts.append(f)
    233            route_prob.append(p)
    234            n_dropped.append(n_d)
    235            route_prob_max.append(p_max)

    最后,对向量进行归一化

    237        x = self.norm(x)

    239        return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)