This is a miniature PyTorch implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation has only a few million parameters and doesn't do model-parallel distributed training. It trains on a single GPU, but it implements the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among parameters based on the token. Therefore, only a fraction of the parameters is used for each token. This lets you have more parameters at roughly the same computational cost per token.
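To make the parameter/compute trade-off concrete, here is a back-of-the-envelope count. It assumes each expert is a two-layer FFN with biases and that the router is a single linear layer with bias. The sizes d_model=128, d_ff=512 and n_experts=4 are hypothetical and chosen only for illustration.

```python
def ffn_params(d_model: int, d_ff: int) -> int:
    """Parameters of a two-layer FFN with biases: d_model -> d_ff -> d_model."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def switch_ffn_params(d_model: int, d_ff: int, n_experts: int) -> tuple[int, int]:
    """Return (total parameters, parameters used per token) for a switch FFN.

    Assumes a linear router (d_model -> n_experts, with bias) and top-1 routing,
    so each token touches the router plus exactly one expert.
    """
    router = d_model * n_experts + n_experts
    expert = ffn_params(d_model, d_ff)
    total = router + n_experts * expert
    active = router + expert
    return total, active


# Hypothetical sizes, for illustration only
total, active = switch_ffn_params(d_model=128, d_ff=512, n_experts=4)
print(total, active)  # 527364 132228
```

With four experts the layer holds about four times the parameters of a single FFN, but each token still passes through only one expert plus the small router.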
The switching happens at the position-wise feed-forward network (FFN) of each transformer block. A position-wise feed-forward network consists of two sequential fully connected layers. In the Switch Transformer we have multiple FFNs (multiple experts), and a router chooses which one to use. The router outputs a set of probabilities for picking each FFN. We pick the one with the highest probability and evaluate only that FFN, so the computational cost is essentially the same as having a single FFN. In our implementation this doesn't parallelize well when you have many or large FFNs, since everything happens on a single GPU. In a distributed setup, each FFN (each very large) would be on a different device.
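The routing step can be sketched in plain PyTorch. This is a minimal illustration and not the SwitchFeedForward module implemented below. The experts here are hypothetical single nn.Linear layers standing in for FFNs, and the tokens are assumed to be already flattened to shape [n_tokens, d_model].

```python
import torch
from torch import nn

torch.manual_seed(0)

d_model, n_experts, n_tokens = 8, 4, 10  # hypothetical toy sizes

router = nn.Linear(d_model, n_experts)  # one logit per expert
# Hypothetical toy "FFNs": single linear layers
experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_experts))

x = torch.randn(n_tokens, d_model)

# Probability distribution over experts for every token
route_prob = torch.softmax(router(x), dim=-1)
# Top-1 routing: keep only the most probable expert for each token
route_prob_max, routes = route_prob.max(dim=-1)

out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    idx = (routes == i).nonzero(as_tuple=True)[0]  # tokens sent to expert i
    if idx.numel() > 0:
        out[idx] = expert(x[idx])  # each token is evaluated by exactly one expert

print(routes.tolist())
```

Each token is multiplied by only one expert's weights. This is why the per-token compute stays at roughly one FFN as n_experts grows.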
The paper introduces another loss term to balance load among the experts (FFNs) and discusses dropping tokens when routing is not balanced.
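The auxiliary loss from the paper can be sketched as $\alpha \cdot N \cdot \sum_i f_i P_i$. Here $f_i$ is the fraction of tokens dispatched to expert $i$, $P_i$ is the mean router probability for expert $i$, and $N$ is the number of experts. The helper name load_balancing_loss and the default alpha=0.01 are assumptions for illustration. The inputs correspond to the counts and route_prob.sum(0) values that SwitchFeedForward returns.

```python
import torch


def load_balancing_loss(counts: torch.Tensor, route_prob_sum: torch.Tensor,
                        n_tokens: int, alpha: float = 0.01) -> torch.Tensor:
    """Hypothetical helper computing alpha * N * sum_i f_i * P_i.

    counts[..., i]:         number of tokens routed (top-1) to expert i
    route_prob_sum[..., i]: sum over tokens of the router probability for expert i
    alpha=0.01 is an assumed default.
    """
    n_experts = counts.shape[-1]
    f = counts.float() / n_tokens   # fraction of tokens dispatched to each expert
    p = route_prob_sum / n_tokens   # mean router probability for each expert
    return alpha * n_experts * (f * p).sum(-1)


# Perfectly uniform routing of 8 tokens over 4 experts
balanced = load_balancing_loss(torch.tensor([2, 2, 2, 2]), torch.tensor([2., 2., 2., 2.]), n_tokens=8)
# Every token confidently sent to expert 0
skewed = load_balancing_loss(torch.tensor([8, 0, 0, 0]), torch.tensor([8., 0., 0., 0.]), n_tokens=8)
print(balanced.item(), skewed.item())
```

The loss equals alpha under uniform routing and grows as routing concentrates on fewer experts. Because the reduction is over the last dimension, the same helper also works on per-layer statistics stacked to shape [n_layers, n_experts].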
Here are the training code and a notebook for training a Switch Transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.utils import clone_module_list

class SwitchFeedForward(nn.Module):
capacity_factor is the capacity of each expert as a factor relative to an ideally balanced load
drop_tokens specifies whether to drop tokens if more tokens are routed to an expert than its capacity
is_scale_prob specifies whether to multiply the FFN output by the routing probability
n_experts is the number of experts
expert is the expert layer, an FFN module
d_model is the number of features in a token embedding
The number of features in the FFN hidden layer (d_ff) and the FFN dropout probability (dropout) are not arguments here; they are configured on the expert module.
    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
Make copies of the FFNs
        self.experts = clone_module_list(expert, n_experts)
Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)
x is the input to the switching module with shape [seq_len, batch_size, d_model]
    def forward(self, x: torch.Tensor):
Capture the shape to change shapes later
        seq_len, batch_size, d_model = x.shape
Flatten the sequence and batch dimensions
        x = x.view(-1, d_model)
Get routing probabilities for each of the tokens: $p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$, where $N$ is the number of experts n_experts and $h(\cdot)$ is the linear transformation of token embeddings.
        route_prob = self.softmax(self.switch(x))
Get the maximum routing probabilities and the routes. We route to the expert with the highest probability.
        route_prob_max, routes = torch.max(route_prob, dim=-1)
Get indexes of tokens going to each expert
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
Initialize an empty tensor to store outputs
        final_output = x.new_zeros(x.shape)
Capacity of each expert: capacity_factor × (number of tokens / number of experts), truncated to an integer.
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
Number of tokens routed to each expert.
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])
Initialize an empty list of dropped tokens
        dropped = []
Only drop tokens if drop_tokens is True.
        if self.drop_tokens:
Drop tokens in each of the experts
            for i in range(self.n_experts):
Ignore if the expert is not over capacity
                if len(indexes_list[i]) <= capacity:
                    continue
Shuffle indexes before dropping
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
Collect the tokens over capacity as dropped tokens
                dropped.append(indexes_list[i][capacity:])
Keep only the tokens up to the capacity of the expert
                indexes_list[i] = indexes_list[i][:capacity]
Get outputs of the expert FFNs
        expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]
Assign to final output
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = expert_output[i]
Pass through the dropped tokens
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        if self.is_scale_prob:
Multiply the expert outputs by the probabilities
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
Don't scale the values, but multiply by $\frac{p}{\hat{p}} = 1$, where $\hat{p}$ is route_prob_max detached from the graph, so that the gradients still flow to the router (this is something we experimented with).
            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)
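Change the shape of the final output back to [seq_len, batch_size, d_model]
        final_output = final_output.view(seq_len, batch_size, d_model)
Return the final output, the number of tokens routed to each expert, the sum of routing probabilities for each expert, the number of dropped tokens, and the maximum routing probability of each token.
These are used for the load balancing loss and logging.
        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max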
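The capacity, token-dropping, and non-scaling steps of forward can be exercised in isolation. The sketch below re-implements them as standalone functions on a hand-made routing of eight tokens. The names expert_capacity and drop_over_capacity are hypothetical and not part of labml_nn, and the routing is invented for illustration.

```python
import torch

torch.manual_seed(0)


def expert_capacity(capacity_factor: float, n_tokens: int, n_experts: int) -> int:
    """Same formula as forward(): capacity_factor * tokens / experts, truncated."""
    return int(capacity_factor * n_tokens / n_experts)


def drop_over_capacity(routes: torch.Tensor, n_experts: int, capacity: int):
    """Hypothetical helper: split token indexes per expert into kept and dropped."""
    kept, dropped = [], []
    for i in range(n_experts):
        idx = torch.eq(routes, i).nonzero(as_tuple=True)[0]
        if len(idx) > capacity:
            idx = idx[torch.randperm(len(idx))]  # random choice of which tokens to drop
            dropped.append(idx[capacity:])
            idx = idx[:capacity]
        kept.append(idx)
    return kept, (torch.cat(dropped) if dropped else routes.new_zeros(0))


# Hypothetical routing of 8 tokens to 4 experts, skewed towards expert 0
routes = torch.tensor([0, 0, 0, 0, 0, 1, 2, 3])
cap = expert_capacity(capacity_factor=1.0, n_tokens=8, n_experts=4)
kept, dropped = drop_over_capacity(routes, n_experts=4, capacity=cap)
print(cap, [len(k) for k in kept], len(dropped))  # 2 [2, 1, 1, 1] 3

# The is_scale_prob=False trick: p / p.detach() equals 1 in value,
# but its gradient with respect to p is 1 / p, so the router still gets a signal.
p = torch.tensor([0.7, 0.4], requires_grad=True)
scale = p / p.detach()
scale.sum().backward()
print(scale.tolist(), p.grad.tolist())
```

All three dropped tokens come from the overloaded expert 0. The shuffle presumably avoids always dropping the tokens that happen to come last in the flattened [seq_len * batch_size] order.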
This is the same as a normal transformer block, except that it handles the extra outputs of the switch feed-forward module.
class SwitchTransformerLayer(nn.Module):
d_model is the token embedding size
attn is the attention module
feed_forward is the feed-forward module (which is the switching module in this case)
dropout_prob is the probability of dropping out after self attention and FFN
    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor):
Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values are from self
        self_attn = self.attn(query=z, key=z, value=z, mask=mask)
Add the self attention results
        x = x + self.dropout(self_attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the switching feed-forward network
        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped, route_prob_max
class SwitchTransformer(nn.Module):
    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
Run through each transformer layer
        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
        for layer in self.layers:
            x, f, p, n_d, p_max = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
            route_prob_max.append(p_max)
Finally, normalize the vectors
        x = self.norm(x)

        return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)