This is an annotated PyTorch experiment to train a Masked Language Model.
from typing import List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_helpers.metrics.accuracy import Accuracy
from labml_helpers.module import Module
from labml_helpers.train_valid import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM
class TransformerMLM(nn.Module):
encoder is the transformer Encoder
src_embed is the token embedding module (with positional encodings)
generator is the final fully connected layer that gives the logits
    def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder
        x = self.encoder(x, None)
Logits for the output
        y = self.generator(x)
Return results (second value is for state, since our trainer is used with RNNs also)
        return y, None
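To make the shape contract explicit, here is a minimal, self-contained sketch with plain torch.nn stand-ins for the three sub-modules (the stand-ins and the hyper-parameter values are illustrative assumptions, not the labml_nn implementations): token indices of shape [seq_len, batch_size] go in, per-token logits of shape [seq_len, batch_size, n_tokens] come out.

import torch
from torch import nn

seq_len, batch_size, d_model, n_tokens = 32, 4, 128, 67  # illustrative values only

# Stand-ins for src_embed, encoder and generator, just to check the shapes
src_embed = nn.Embedding(n_tokens, d_model)              # [seq, batch] -> [seq, batch, d_model]
layer = nn.TransformerEncoderLayer(d_model, nhead=8)     # keeps [seq, batch, d_model]
encoder = nn.TransformerEncoder(layer, num_layers=2)
generator = nn.Linear(d_model, n_tokens)                 # [..., d_model] -> [..., n_tokens]

x = torch.randint(0, n_tokens, (seq_len, batch_size))    # a batch of token indices
y = generator(encoder(src_embed(x)))                     # logits, as in forward() above
assert y.shape == (seq_len, batch_size, n_tokens)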
This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.
class Configs(NLPAutoRegressionConfigs):
MLM model
    model: TransformerMLM
Transformer
    transformer: TransformerConfigs
Number of tokens
    n_tokens: int = 'n_tokens_mlm'
Tokens that shouldn't be masked
    no_mask_tokens: List[int] = []
Probability of masking a token
    masking_prob: float = 0.15
Probability of replacing the mask with a random token
    randomize_prob: float = 0.1
Probability of replacing the mask with the original token
    no_change_prob: float = 0.1
Masked Language Model (MLM) class to generate the mask
    mlm: MLM
[MASK] token
    mask_token: int
[PADDING] token
    padding_token: int
Prompt to sample
    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]
    def init(self):
[MASK] token
        self.mask_token = self.n_tokens - 1
[PAD] token
        self.padding_token = self.n_tokens - 2
Masked Language Model (MLM) class to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)
Accuracy metric (ignore the labels equal to [PAD])
        self.accuracy = Accuracy(ignore_index=self.padding_token)
Cross entropy loss (ignore the labels equal to [PAD])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
        super().init()
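To make the masking concrete, here is a hand-written toy illustration of what the pair returned by self.mlm(data) looks like (not produced by the MLM class itself; it assumes a 5-character vocabulary, so n_tokens = 7, [PAD] = 5 and [MASK] = 6):

original = [2, 0, 3, 3, 1]   # token ids of a short character sequence
masked   = [2, 0, 6, 3, 1]   # ~15% of positions masked; of those, 10% get a random
                             # token and 10% keep the original token instead of [MASK]
labels   = [5, 5, 3, 5, 5]   # original token at masked positions, [PAD] everywhere else,
                             # so the accuracy and loss above (ignore_index=[PAD]) only
                             # score the masked positions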
    def step(self, batch: any, batch_idx: BatchIndex):
Move the input to the device
        data = batch[0].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get the masked input and labels
        with torch.no_grad():
            data, labels = self.mlm(data)
Whether to capture model outputs
        with self.mode.update(is_log_activations=batch_idx.is_last):
Get model outputs. It returns a tuple for states when using RNNs; this is not implemented yet.
            output, *_ = self.model(data)
Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)
Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
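The view calls in the loss flatten the sequence and batch dimensions so that nn.CrossEntropyLoss scores one prediction per row, skipping rows whose label is [PAD]; a minimal sketch with illustrative sizes:

import torch
from torch import nn

seq_len, batch_size, n_tokens, padding_token = 32, 4, 67, 65  # illustrative values
output = torch.randn(seq_len, batch_size, n_tokens)           # model logits
labels = torch.randint(0, n_tokens, (seq_len, batch_size))    # [PAD] at unmasked positions

loss_func = nn.CrossEntropyLoss(ignore_index=padding_token)
# [seq, batch, n_tokens] -> [seq * batch, n_tokens] and [seq, batch] -> [seq * batch]
loss = loss_func(output.view(-1, n_tokens), labels.view(-1))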
    @torch.no_grad()
    def sample(self):
Empty tensor for the data, filled with [PAD].
        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
Add the prompts one by one
        for i, p in enumerate(self.prompt):
Get token indexes
            d = self.text.text_to_i(p)
Add to the tensor
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]
Move the tensor to the current device
        data = data.to(self.device)
Get masked input and labels
        data, labels = self.mlm(data)
Get model outputs
        output, *_ = self.model(data)
Print the samples generated
        for j in range(data.shape[1]):
Collect output for printing
            log = []
For each token
            for i in range(len(data)):
If the label is not [PAD]
                if labels[i, j] != self.padding_token:
Get the prediction
                    t = output[i, j].argmax().item()
If it's a printable character
                    if t < len(self.text.itos):
Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))
Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))
If it's not a printable character
                    else:
                        log.append(('*', Text.danger))
If the label is [PAD] (unmasked), print the original character
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))
            logger.log(log)
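The log is a list of (text, style) pairs that labml's logger renders in color: Text.value for correct predictions at masked positions, Text.danger for wrong or unprintable predictions, and Text.subtle for the unmasked context. A small illustrative call (the characters are made up):

logger.log([('h', Text.value), ('e', Text.subtle), ('l', Text.value),
            ('x', Text.danger), ('o', Text.subtle)])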
Number of tokens including [PAD] and [MASK]
@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
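To see where the two extra ids land, here is a small sketch (the character vocabulary size of 65 is an illustrative assumption) matching the assignments made in init above:

text_n_tokens = 65               # assumed size of the character vocabulary
n_tokens = text_n_tokens + 2     # as returned by n_tokens_mlm
padding_token = n_tokens - 2     # id 65: the [PAD] token
mask_token = n_tokens - 1        # id 66: the [MASK] token
# ids 0..64 map to real characters via text.itos; 65 and 66 never occur in the raw text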
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Embedding size
    conf.d_model = c.d_model
    return conf
Create the MLM model
@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="mlm")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Batch size
        'batch_size': 64,
Sequence length of 32. We use a short sequence length to train faster; otherwise it takes forever to train.
        'seq_len': 32,
Train for 1024 epochs.
        'epochs': 1024,
Switch between training and validation once per epoch
        'inner_iterations': 1,
Transformer configurations (same as defaults)
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
Use the Noam optimizer (a sketch of its schedule follows at the end of the file)
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()


if __name__ == '__main__':
    main()
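For reference, 'Noam' is the learning-rate schedule from Attention Is All You Need: a linear warmup followed by inverse square-root decay, scaled by d_model. A minimal sketch of that formula (the warmup value of 4000 comes from the paper and is an assumption here, not necessarily this experiment's default):

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, factor: float = 1.0) -> float:
    """Learning rate at a given step under the Noam schedule."""
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# The rate peaks around `warmup` steps and then decays as 1/sqrt(step)
assert noam_lr(16000) < noam_lr(4000)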