蒙面语言模型 (MLM)

这是 PyTorch 实现的掩码语言模型 (MLM),用于预训白文《BERT:深度双向转换器预训练 BERT》中介绍的 BERT 模型。

BERT 预训练

BERT 模型是变压器模型。本文使用 MLM 和下一句预测对模型进行了预训练。我们只在这里实施了传销。

下一句预测

下一个句子预测中,给出两个句子,A B 然后模型对实际文本A 中后面的句子是否B 是后面的句子进行二进制预测。该模型有 50% 的时间为实际句子对,50% 的时间为随机句对。这种分类是在应用传销时完成的。我们还没有在这里实现这一点。

Masked LM

这会随机掩盖一定比例的代币,并训练模型预测被掩码的代币。他们通过用特殊代币替换15%的代[MASK] 币来掩盖它们。

损失仅通过预测被掩码的代币来计算。这在微调和实际使用过程中会导致问题,因为当时没有[MASK] 令牌。因此,我们可能得不到任何有意义的陈述。

为了克服这个问题,10%的蒙面代币被替换为原始代币,另外 10%的蒙面代币被随机代币所取代。无论该位置的输入代币是否为,这都会训练模型给出有关实际代币的表现形式[MASK] 。用随机代币替换会使它给出的表现形式也包含来自上下文的信息;因为它必须使用上下文来修复随机替换的标记。

训练

MLM 比自回归模型更难训练,因为它们的训练信号较小。也就是说,每个样本只训练了一小部分的预测。

另一个问题是,由于该模型是双向的,因此任何代币都可以看到任何其他代币。这使得 “信用分配” 变得更加困难。假设你有角色等级模型想要预测home *s where i want to be 。至少在训练的早期阶段,很难弄清楚为什么要用* 它来代替i ,可能是整句话中的任何东西。而在自回归环境中,模型只h 需要用于预测o ehom 预测等等。因此,该模型最初将首先在较短的上下文中开始预测,然后学会使用较长的上下文进行预测。由于 MLM 有这个问题,如果你一开始使用较小的序列长度,然后再使用更长的序列长度,那么训练速度会快得多。

这是简单 MLM 模型的训练代码

65from typing import List
66
67import torch

Masked LM (传销)

该类实现给定批次令牌序列的掩码过程。

70class MLM:
  • padding_token 是填充令牌[PAD] 。我们将用它来标记不应用于损失计算的标签。
  • mask_token 是掩码令牌[MASK]
  • no_mask_tokens 是不应该被屏蔽的令牌的列表。如果我们同时使用分类等其他任务来训练 MLM,并且我们有这样的令牌不应该被掩盖,那么[CLS] 这很有用。
  • n_tokens 代币总数(用于生成随机令牌)
  • masking_prob 是掩蔽概率
  • randomize_prob 是用随机代币替换的概率
  • no_change_prob 是用原始代币替换的概率
77    def __init__(self, *,
78                 padding_token: int, mask_token: int, no_mask_tokens: List[int], n_tokens: int,
79                 masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
80                 ):
93        self.n_tokens = n_tokens
94        self.no_change_prob = no_change_prob
95        self.randomize_prob = randomize_prob
96        self.masking_prob = masking_prob
97        self.no_mask_tokens = no_mask_tokens + [padding_token, mask_token]
98        self.padding_token = padding_token
99        self.mask_token = mask_token
  • x 是一批输入令牌序列。它是一个long 带形状的张量[seq_len, batch_size]
101    def __call__(self, x: torch.Tensor):

代币masking_prob 面具

108        full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob

揭露面具no_mask_tokens

110        for t in self.no_mask_tokens:
111            full_mask &= x != t

将令牌替换为原始令牌的掩码

114        unchanged = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)

将令牌替换为随机令牌的掩码

116        random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)

要替换为随机令牌的令牌索引

118        random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)

每个地点的随机代币

120        random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)

最后一组将被替换的令牌[MASK]

122        mask = full_mask & ~random_token_mask & ~unchanged

克隆标签的输入

125        y = x.clone()

替换为[MASK] 令牌;请注意,这不包括原始令牌保持不变的令牌和被随机令牌替换的令牌。

130        x.masked_fill_(mask, self.mask_token)

分配随机代币

132        x[random_token_idx] = random_tokens

将令牌[PAD] 分配给标签中的所有其他位置。等于的标签[PAD] 将不会用于亏损。

136        y.masked_fill_(~full_mask, self.padding_token)

返回屏蔽的输入和标签

139        return x, y