This is an annotated PyTorch experiment to train a Switch Transformer on the Tiny Shakespeare dataset with character-level tokenization.
import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        self.mask = None

    def forward(self, x: torch.Tensor):
Initialize the subsequent mask
        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)
Token embeddings
        x = self.src_embed(x)
Run it through the transformer
        res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)
Generate logits of the next token
        res = self.generator(res)

        return res, counts, route_prob, n_dropped, route_prob_max
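The subsequent mask built in forward above is a causal mask: position i may attend only to positions 0..i. Below is a minimal illustrative sketch of the idea, not the labml_nn subsequent_mask implementation, whose exact shape and dtype conventions may differ.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: entry [i, j] is True when
    # position i is allowed to attend to position j, i.e. j <= i.
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])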
This extends NLPAutoRegressionConfigs. The default configurations can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
    transformer: nn.Module
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of experts
    n_experts: int = 4
Load balancing coefficient
    load_balancing_loss_ceof = 0.01
Whether to scale the chosen expert outputs by the routing probability
    is_scale_prob: bool = True
Whether to drop tokens
    drop_tokens: bool = False
Capacity factor that determines the capacity of each expert
    capacity_factor: float = 1.0

    def init(self):
        super().init()
Initialize tracking indicators
        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get model outputs.
        output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
Calculate the cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)
Total number of tokens processed, $T$, in the current batch $\mathscr{B}$
        total = counts.sum(dim=-1, keepdims=True)
Fraction of tokens routed to each expert: $$f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \mathbf{1}\{\operatorname{argmax}\, p(x) = i\}$$ That is, $f_i$ is the fraction of tokens in the batch whose routing probabilities $p(x)$ have their argmax equal to $i$.
        route_frac = counts / total
Mean routing probability: $$P_i = \frac{1}{T} \sum_{x \in \mathscr{B}} p_i(x)$$
        route_prob = route_prob / total
Load balancing loss: $$\mathscr{L} = N \sum_{i=1}^{N} f_i \cdot P_i$$ $\mathscr{L}$ is the loss for a single layer, and here we take the sum of losses across all layers. (A small numeric sketch of this loss follows the class below.)
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
Track stats
        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add('route.max_prob.', route_prob_max)
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)
Combined loss. The load balancing loss is multiplied by a coefficient $c_{lb}$, which is set to something small like $c_{lb} = 0.01$.
        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
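As promised above, here is a small numeric sketch of the load balancing loss, using two experts and made-up routing statistics. This is an aside for intuition, not part of the experiment code.

import torch

# Hypothetical routing statistics for a batch of 8 tokens and 2 experts.
counts = torch.tensor([6., 2.])        # tokens routed to each expert
route_prob = torch.tensor([4.8, 3.2])  # summed routing probabilities per expert
total = counts.sum()                   # T = 8 tokens

route_frac = counts / total            # f = [0.75, 0.25]
mean_prob = route_prob / total         # P = [0.60, 0.40]

n_experts = 2
load_balancing_loss = n_experts * (route_frac * mean_prob).sum()
print(load_balancing_loss)  # tensor(1.1000); a perfectly balanced router gives 1.0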
Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)
Initialize the switch transformer

@option(Configs.transformer)
def switch_transformer(c: Configs):
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)
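The capacity_factor passed to SwitchFeedForward above caps how many tokens each expert can accept when drop_tokens is enabled. Below is a rough sketch of the usual capacity formula from the Switch Transformer paper; the helper function is hypothetical and the exact rounding in labml_nn may differ.

def expert_capacity(n_tokens: int, n_experts: int, capacity_factor: float) -> int:
    # Each expert accepts at most capacity_factor times its fair share of the
    # tokens in a batch; tokens routed beyond this capacity are dropped.
    return int(capacity_factor * n_tokens / n_experts)

# With the settings used in main(): seq_len=64, batch_size=32, n_experts=4, capacity_factor=1.2
print(expert_capacity(n_tokens=64 * 32, n_experts=4, capacity_factor=1.2))  # 614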
def main():
Create experiment
    experiment.create(name="switch_transformer", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training loop (TrainValidConfigs.run)
        conf.run()
if __name__ == '__main__':
    main()
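A closing note on the optimizer settings in main(): 'Noam' refers to the learning-rate schedule from "Attention Is All You Need", and the configured learning rate of 1.0 is best read as a scale factor on that schedule rather than an absolute step size. Here is a sketch of the schedule; the warmup value is an assumption for illustration, not taken from this experiment's defaults.

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, scale: float = 1.0) -> float:
    # Learning rate grows linearly for `warmup` steps, then decays as 1/sqrt(step).
    step = max(step, 1)
    return scale * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr(100))    # early in warmup: still very small
print(noam_lr(4000))   # peak, about 1.4e-3 for d_model=128
print(noam_lr(40000))  # decayed by roughly 1/sqrt(10) from the peak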