Masked Language Model (MLM) Experiment

This is an annotated PyTorch experiment to train a Masked Language Model.

11from typing import List
13import torch
14from torch import nn
16from labml import experiment, tracker, logger
17from labml.configs import option
18from labml.logger import Text
19from labml_helpers.metrics.accuracy import Accuracy
20from labml_helpers.module import Module
21from labml_helpers.train_valid import BatchIndex
22from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
23from labml_nn.transformers import Encoder, Generator
24from labml_nn.transformers import TransformerConfigs
25from labml_nn.transformers.mlm import MLM

Transformer-based model for MLM

class TransformerMLM(nn.Module):
    """
    ## Transformer based model for MLM

    Embeds token ids (with positional encodings), runs them through a
    transformer encoder without an attention mask, and projects the encoder
    output to logits with the generator.
    """

    def __init__(self, *, encoder: Encoder, src_embed: Module, generator: Generator):
        """
        * `encoder` is the transformer encoder
        * `src_embed` is the token embedding module (with positional encodings)
        * `generator` is the final layer that produces the output logits
        """
        super().__init__()
        # Sub-modules are registered in this exact order; it fixes the
        # ordering of `parameters()` (and hence optimizer state), so keep it.
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):
        """
        `x` holds token ids — presumably `[seq_len, batch_size]`, matching how
        `Configs.sample` lays out its input; confirm for other callers.

        Returns `(logits, None)`; the second element stands in for the
        recurrent state, since the shared trainer is also used with RNNs.
        """
        # Token embeddings with positional encodings, fed to the encoder with no mask
        hidden = self.encoder(self.src_embed(x), None)
        # Logits for the output
        logits = self.generator(hidden)
        return logits, None


This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.

class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    This inherits from `NLPAutoRegressionConfigs` because it has the data
    pipeline implementations that we reuse here. We implement a custom
    training step for MLM.
    """

    # MLM model
    model: TransformerMLM
    # Transformer
    transformer: TransformerConfigs

    # Number of tokens; the string names the option function that computes it
    # (`n_tokens_mlm`)
    n_tokens: int = 'n_tokens_mlm'
    # Tokens that shouldn't be masked.
    # NOTE(review): class-level mutable default — shared across instances
    # unless overridden; do not mutate it in place.
    no_mask_tokens: List[int] = []
    # Probability of masking a token
    masking_prob: float = 0.15
    # Probability of replacing the mask with a random token
    randomize_prob: float = 0.1
    # Probability of replacing the mask with original token
    no_change_prob: float = 0.1
    # Masked Language Model (MLM) class to generate the mask
    mlm: MLM

    # `[MASK]` token
    mask_token: int
    # `[PADDING]` token
    padding_token: int

    # Prompts to sample (each entry becomes one column of the sample batch)
    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]

    def init(self):
        """
        ### Initialization

        Reserves the last two token ids for `[MASK]` and `[PAD]`, builds the
        MLM masker, the accuracy metric and the loss, then runs the parent's
        initialization.
        """
        # `[MASK]` token
        self.mask_token = self.n_tokens - 1
        # `[PAD]` token
        self.padding_token = self.n_tokens - 2

        # Masked Language Model (MLM) class to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

        # Accuracy metric (ignore the labels equal to `[PAD]`)
        self.accuracy = Accuracy(ignore_index=self.padding_token)
        # Cross entropy loss (ignore the labels equal to `[PAD]`)
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        # Parent initialization runs after the MLM-specific fields are set
        super().init()

    def step(self, batch: any, batch_idx: BatchIndex):
        """
        ### Training or validation step

        Masks the batch, computes loss and accuracy against the MLM labels,
        and, in training mode, takes an optimizer step.
        """
        # Move the input to the device
        data = batch[0].to(self.device)

        # Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

        # Get the masked input and labels (no gradients needed for masking)
        with torch.no_grad():
            data, labels = self.mlm(data)

        # Capture model outputs only on the last batch
        with self.mode.update(is_log_activations=batch_idx.is_last):
            # Get model outputs. It returns a tuple because of the state slot
            # used with RNNs; that state is not implemented here.
            output, *_ = self.model(data)

        # Calculate and log the loss over all positions (flattened)
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)

        # Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()

    @torch.no_grad()
    def sample(self):
        """
        ### Sampling function to generate samples periodically while training

        Encodes every prompt into one column of a `[PAD]`-filled batch, masks
        it, runs the model and logs each prompt: masked positions show the
        prediction (highlighted when correct or wrong), unmasked positions show
        the original character.
        """
        # Empty tensor for data filled with `[PAD]`; shape `[seq_len, n_prompts]`
        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
        # Add the prompts one by one
        for i, p in enumerate(self.prompt):
            # Get token indexes
            d = self.text.text_to_i(p)
            # Add to the tensor, truncating prompts longer than `seq_len`
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]
        # Move the tensor to the current device
        data = data.to(self.device)

        # Get masked input and labels
        data, labels = self.mlm(data)
        # Get model outputs
        output, *_ = self.model(data)

        # Print the samples generated
        for j in range(data.shape[1]):
            # Collect output for printing
            log = []
            # For each token
            for i in range(len(data)):
                # If the label is not `[PAD]`, this position was selected for prediction
                if labels[i, j] != self.padding_token:
                    # Get the prediction
                    t = output[i, j].argmax().item()
                    # If it's a printable character
                    if t < len(self.text.itos):
                        # Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))
                        # Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))
                    # If it's not a printable character (e.g. `[PAD]`/`[MASK]`)
                    else:
                        log.append(('*', Text.danger))
                # If the label is `[PAD]` (unmasked) print the original
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))

            # Print
            logger.log(log)

Number of tokens including [PAD] and [MASK]

@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    """
    Number of tokens including `[PAD]` and `[MASK]`.

    The two extra ids are the ones `Configs.init` reserves as
    `n_tokens - 2` (`[PAD]`) and `n_tokens - 1` (`[MASK]`).
    """
    return c.text.n_tokens + 2

Transformer configurations

@option(Configs.transformer)
def _transformer_configs(c: Configs):
    """
    ### Transformer configurations

    Builds a `TransformerConfigs` whose source/target vocabulary sizes and
    embedding size come from the experiment configs.
    """
    conf = TransformerConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    # (includes the extra `[PAD]` and `[MASK]` tokens)
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
    # Embedding size
    conf.d_model = c.d_model

    return conf

Create the Transformer MLM model

@option(Configs.model)
def _model(c: Configs):
    """
    Create the Transformer MLM model.

    Wires the configured transformer's encoder, source embedding and
    generator into a `TransformerMLM` and moves it to `c.device`.
    """
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m


def main():
    """
    ## Run the MLM experiment

    Creates the experiment, overrides the configurations, registers the
    model for checkpointing and runs training.
    """
    # Create experiment
    experiment.create(name="mlm")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Batch size
        'batch_size': 64,
        # Sequence length of 32. We use a short sequence length to train faster.
        # Otherwise it takes forever to train.
        'seq_len': 32,

        # Train for 1024 epochs.
        'epochs': 1024,
        # Switch between training and validation 1 time per epoch
        'inner_iterations': 1,

        # Transformer configurations (same as defaults)
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

        # Optimizer settings
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()