Masked Language Model (MLM)

This is a PyTorch implementation of the Masked Language Model (MLM) used to pre-train the BERT model introduced in the paper BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding.

BERT Pretraining

BERT is a transformer model. The paper pre-trains it with two objectives: MLM and next sentence prediction. Only MLM is implemented here.

Next sentence prediction

In next sentence prediction, the model is given two sentences A and B and makes a binary prediction of whether B is the sentence that follows A in the actual text. The model is fed actual sentence pairs 50% of the time and random pairs the other 50% of the time. This classification is done at the same time as MLM. We haven’t implemented this here.
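The pair construction described above could be sketched roughly as follows. This is only an illustration in plain Python: `make_nsp_pairs` and its arguments are hypothetical names that are not part of this implementation, and drawing the random sentence from the same list is an assumption made to keep the example small.

```python
import random
from typing import List, Tuple


def make_nsp_pairs(sentences: List[str], rng: random.Random) -> List[Tuple[str, str, int]]:
    """Build (A, B, is_next) triples from an ordered list of at least three sentences."""
    pairs = []
    for i in range(len(sentences) - 1):
        if rng.random() < 0.5:
            # Actual pair: B is the sentence that follows A
            pairs.append((sentences[i], sentences[i + 1], 1))
        else:
            # Random pair: B is any sentence other than A itself or A's true successor
            candidates = [k for k in range(len(sentences)) if k not in (i, i + 1)]
            pairs.append((sentences[i], sentences[rng.choice(candidates)], 0))
    return pairs


sentences = ["s0", "s1", "s2", "s3", "s4"]
pairs = make_nsp_pairs(sentences, random.Random(0))
```

Each triple carries the label a binary classifier would be trained on: 1 for an actual pair and 0 for a random one.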

Masked LM

MLM masks a percentage of tokens at random and trains the model to predict the masked tokens. The paper selects 15% of the tokens and replaces them with a special [MASK] token.

The loss is computed on the predictions for the masked tokens only. This causes a problem during fine-tuning and actual usage, since there are no [MASK] tokens at that time, so the model might not produce meaningful representations for ordinary tokens.

To overcome this, 10% of the masked tokens are replaced with the original token, and another 10% are replaced with a random token. This trains the model to give a representation of the actual token whether or not the input token at that position is a [MASK]. Replacing with a random token also pushes the model to give a representation that carries information from the context, because it has to use the context to fix the randomly replaced tokens.


MLMs are harder to train than autoregressive models because they have a smaller training signal; that is, only a small percentage of the predictions are trained on per sample.

Another problem is that, since the model is bidirectional, any token can see any other token. This makes the “credit assignment” harder. Say you have a character-level model trying to fill in the masked character in “home *s where i want to be”. At least during the early stages of training, it is very hard for the model to figure out why the replacement for * should be i; the evidence could come from anywhere in the sentence. In an autoregressive setting, by contrast, the model only has to use h to predict o, hom to predict e, and so on. So an autoregressive model starts out predicting from a shorter context and learns to use longer contexts later. Because MLMs have this problem, training is a lot faster if you start with a smaller sequence length and move to a longer sequence length later.
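One way to start short and grow the sequence length is a simple step-based schedule. The sketch below is an assumption for illustration only: `seq_len_for_step`, the starting length, the cap, and the doubling interval are all hypothetical and are not taken from the training code referenced here.

```python
def seq_len_for_step(step: int, *, start_len: int = 32, max_len: int = 512,
                     double_every: int = 1000) -> int:
    """Double the sequence length every `double_every` steps, capped at `max_len`."""
    return min(max_len, start_len * 2 ** (step // double_every))


# Sequence lengths at a few training steps
lengths = [seq_len_for_step(s) for s in (0, 999, 1000, 2500, 10_000)]
```

The data loader would then cut or pack samples to the returned length before masking is applied.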

Here is the training code for a simple MLM model.


from typing import List
import torch

Masked LM (MLM)

This class implements the masking procedure for a given batch of token sequences.

class MLM:
  • padding_token is the padding token [PAD]. We will use this to mark the labels that shouldn’t be used for loss calculation.
  • mask_token is the masking token [MASK].
  • no_mask_tokens is a list of tokens that should not be masked. This is useful if we are training the MLM with another task like classification at the same time, and we have tokens such as [CLS] that shouldn’t be masked.
  • n_tokens is the total number of tokens (used for generating random tokens)
  • masking_prob is the masking probability
  • randomize_prob is the probability of replacing with a random token
  • no_change_prob is the probability of replacing with original token
    def __init__(self, *,
                 padding_token: int, mask_token: int, no_mask_tokens: List[int], n_tokens: int,
                 masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
                 ):
        self.n_tokens = n_tokens
        self.no_change_prob = no_change_prob
        self.randomize_prob = randomize_prob
        self.masking_prob = masking_prob
        self.no_mask_tokens = no_mask_tokens + [padding_token, mask_token]
        self.padding_token = padding_token
        self.mask_token = mask_token
  • x is the batch of input token sequences. It’s a tensor of type long with shape [seq_len, batch_size].
    def __call__(self, x: torch.Tensor):

Mask masking_prob of tokens

        full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob

Unmask no_mask_tokens

        for t in self.no_mask_tokens:
            full_mask &= x != t
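To see the effect of the exclusion loop in isolation, here is a small sketch. The token ids for [PAD], [MASK] and [CLS] are hypothetical, and the selection probability is set to 1.0 (instead of masking_prob) so that every position is initially selected and the exclusion becomes visible.

```python
import torch

PAD, MASK, CLS = 0, 1, 2  # hypothetical token ids
# Shape [seq_len, batch_size] = [3, 2]
x = torch.tensor([[CLS, CLS],
                  [5, 6],
                  [7, PAD]])

# torch.rand samples from [0, 1), so comparing with 1.0 selects every position
full_mask = torch.rand(x.shape) < 1.0

# Deselect every position that holds a token that must not be masked
for t in [CLS, PAD, MASK]:
    full_mask &= x != t
```

Only the three positions holding ordinary tokens (5, 6 and 7) remain selected.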

A mask for tokens to be replaced with original tokens

        unchanged = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)

A mask for tokens to be replaced with a random token

        random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)

Indexes of tokens to be replaced with random tokens

        random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)

Random tokens for each of the locations

        random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
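The index tuple and the random tokens fit together as in the following small sketch. The mask, the vocabulary size of 10, and the filler value 99 are made up for illustration.

```python
import torch

random_token_mask = torch.tensor([[True, False],
                                  [False, True],
                                  [False, False]])

# With as_tuple=True we get one index tensor per dimension: (rows, columns)
random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)

# One random token id per selected position
n_tokens = 10  # hypothetical vocabulary size
random_tokens = torch.randint(0, n_tokens, (len(random_token_idx[0]),))

# The tuple can be used directly to write into a tensor of the same shape
x = torch.full((3, 2), 99)
x[random_token_idx] = random_tokens
```

Positions (0, 0) and (1, 1) receive random ids below n_tokens, while every other position keeps the value 99.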

The final set of tokens that are going to be replaced by [MASK]

        mask = full_mask & ~random_token_mask & ~unchanged
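Because unchanged and random_token_mask are drawn independently, a selected position can be flagged by both. With the default probabilities of 0.1 each, the expected share of selected positions that become [MASK] is (1 - 0.1) × (1 - 0.1) = 0.81, slightly above a flat 80%. The sketch below checks this empirically; treating every position as selected is an assumption made only to keep the example small.

```python
import torch

torch.manual_seed(0)
n = 1_000_000
full_mask = torch.ones(n, dtype=torch.bool)  # pretend every position was selected

# Same three lines as in the code above, with the default probabilities
unchanged = full_mask & (torch.rand(n) < 0.1)
random_token_mask = full_mask & (torch.rand(n) < 0.1)
mask = full_mask & ~random_token_mask & ~unchanged

frac_mask = mask.float().mean().item()  # expectation: 0.9 * 0.9 = 0.81
frac_random = random_token_mask.float().mean().item()  # expectation: 0.1
frac_both = (unchanged & random_token_mask).float().mean().item()  # expectation: 0.1 * 0.1 = 0.01
```

Positions flagged by both masks end up with a random token, since the random tokens are assigned regardless of unchanged.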

Make a clone of the input for the labels

        y = x.clone()

Replace with [MASK] tokens; note that this doesn’t include the tokens that keep the original token unchanged or those that get replaced with a random token. The replacement is done in place, so the tensor passed in as x is modified.

        x.masked_fill_(mask, self.mask_token)

Assign random tokens

        x[random_token_idx] = random_tokens

Assign token [PAD] to all the other locations in the labels. The labels equal to [PAD] will not be used in the loss.

        y.masked_fill_(~full_mask, self.padding_token)
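One common way to skip the [PAD] labels is the ignore_index argument of the cross-entropy loss. The sketch below assumes that approach; the padding id, vocabulary size, and the random logits standing in for model output are all made up.

```python
import torch
import torch.nn.functional as F

padding_token = 0  # hypothetical id of [PAD]
n_tokens = 8  # hypothetical vocabulary size
seq_len, batch_size = 4, 2

# Stand-in for the model output, shape [seq_len, batch_size, n_tokens]
logits = torch.randn(seq_len, batch_size, n_tokens)

# Labels: [PAD] everywhere except the two positions selected for masking
y = torch.full((seq_len, batch_size), padding_token)
y[1, 0] = 5
y[3, 1] = 2

# Positions whose label equals padding_token contribute nothing to the loss
loss = F.cross_entropy(logits.view(-1, n_tokens), y.view(-1), ignore_index=padding_token)

# The same value, computed from the two labelled positions only
picked = torch.stack([logits[1, 0], logits[3, 1]])
expected = F.cross_entropy(picked, torch.tensor([5, 2]))
```

With the default mean reduction, the loss is averaged over the labelled positions only, so loss and expected agree.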

Return the masked input and the labels

        return x, y
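For experimenting outside the class, the steps above can be condensed into a single function, roughly as follows. `mask_batch` is a hypothetical name, the token ids and batch contents are made up, and, unlike `__call__`, this sketch works on a copy so the caller's tensor is left untouched.

```python
from typing import List, Tuple

import torch


def mask_batch(x: torch.Tensor, *, padding_token: int, mask_token: int,
               no_mask_tokens: List[int], n_tokens: int, masking_prob: float = 0.15,
               randomize_prob: float = 0.1, no_change_prob: float = 0.1,
               ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function-style restatement of the masking steps; works on a copy of x."""
    x = x.clone()
    # Select positions, then deselect tokens that must never be masked
    full_mask = torch.rand(x.shape, device=x.device) < masking_prob
    for t in no_mask_tokens + [padding_token, mask_token]:
        full_mask &= x != t
    # Split the selection into unchanged, random and [MASK] positions
    unchanged = full_mask & (torch.rand(x.shape, device=x.device) < no_change_prob)
    random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < randomize_prob)
    random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
    random_tokens = torch.randint(0, n_tokens, (len(random_token_idx[0]),), device=x.device)
    mask = full_mask & ~random_token_mask & ~unchanged
    # Labels keep the original token at selected positions and [PAD] elsewhere
    y = x.clone()
    x.masked_fill_(mask, mask_token)
    x[random_token_idx] = random_tokens
    y.masked_fill_(~full_mask, padding_token)
    return x, y


torch.manual_seed(0)
PAD, MASK, CLS = 0, 1, 2  # hypothetical token ids
n_tokens = 100
tokens = torch.randint(3, n_tokens, (16, 4))  # [seq_len, batch_size]
tokens[0] = CLS  # first position of every sequence is [CLS]
tokens[-2:, 1] = PAD  # second sequence ends with padding

x, y = mask_batch(tokens, padding_token=PAD, mask_token=MASK,
                  no_mask_tokens=[CLS], n_tokens=n_tokens)
selected = y != PAD
```

At the selected positions y holds the original tokens, and everywhere else x is identical to the input.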