这是 PyTorch 实现的掩码语言模型 (MLM),用于预训白文《BERT:深度双向转换器预训练 BERT》中介绍的 BERT 模型。
BERT 模型是变压器模型。本文使用 MLM 和下一句预测对模型进行了预训练。我们只在这里实施了传销。
在下一个句子预测中,给出两个句子,A
B
然后模型对实际文本A
中后面的句子是否B
是后面的句子进行二进制预测。该模型有 50% 的时间为实际句子对,50% 的时间为随机句对。这种分类是在应用传销时完成的。我们还没有在这里实现这一点。
这会随机掩盖一定比例的代币,并训练模型预测被掩码的代币。他们通过用特殊代币替换15%的代[MASK]
币来掩盖它们。
损失仅通过预测被掩码的代币来计算。这在微调和实际使用过程中会导致问题,因为当时没有[MASK]
令牌。因此,我们可能得不到任何有意义的陈述。
为了克服这个问题,10%的蒙面代币被替换为原始代币,另外 10%的蒙面代币被随机代币所取代。无论该位置的输入代币是否为,这都会训练模型给出有关实际代币的表现形式[MASK]
。用随机代币替换会使它给出的表现形式也包含来自上下文的信息;因为它必须使用上下文来修复随机替换的标记。
MLM 比自回归模型更难训练,因为它们的训练信号较小。也就是说,每个样本只训练了一小部分的预测。
另一个问题是,由于该模型是双向的,因此任何代币都可以看到任何其他代币。这使得 “信用分配” 变得更加困难。假设你有角色等级模型想要预测home *s where i want to be
。至少在训练的早期阶段,很难弄清楚为什么要用*
它来代替i
,可能是整句话中的任何东西。而在自回归环境中,模型只h
需要用于预测o
e
和hom
预测等等。因此,该模型最初将首先在较短的上下文中开始预测,然后学会使用较长的上下文进行预测。由于 MLM 有这个问题,如果你一开始使用较小的序列长度,然后再使用更长的序列长度,那么训练速度会快得多。
这是简单 MLM 模型的训练代码。
65from typing import List
66
67import torch
70class MLM:
padding_token
是填充令牌[PAD]
。我们将用它来标记不应用于损失计算的标签。mask_token
是掩码令牌[MASK]
。no_mask_tokens
是不应该被屏蔽的令牌的列表。如果我们同时使用分类等其他任务来训练 MLM,并且我们有这样的令牌不应该被掩盖,那么[CLS]
这很有用。n_tokens
代币总数(用于生成随机令牌)masking_prob
是掩蔽概率randomize_prob
是用随机代币替换的概率no_change_prob
是用原始代币替换的概率77 def __init__(self, *,
78 padding_token: int, mask_token: int, no_mask_tokens: List[int], n_tokens: int,
79 masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
80 ):
93 self.n_tokens = n_tokens
94 self.no_change_prob = no_change_prob
95 self.randomize_prob = randomize_prob
96 self.masking_prob = masking_prob
97 self.no_mask_tokens = no_mask_tokens + [padding_token, mask_token]
98 self.padding_token = padding_token
99 self.mask_token = mask_token
x
是一批输入令牌序列。它是一个long
带形状的张量[seq_len, batch_size]
。101 def __call__(self, x: torch.Tensor):
代币masking_prob
面具
108 full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
揭露面具no_mask_tokens
110 for t in self.no_mask_tokens:
111 full_mask &= x != t
将令牌替换为原始令牌的掩码
114 unchanged = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)
将令牌替换为随机令牌的掩码
116 random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)
要替换为随机令牌的令牌索引
118 random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
每个地点的随机代币
120 random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
最后一组将被替换的令牌[MASK]
122 mask = full_mask & ~random_token_mask & ~unchanged
克隆标签的输入
125 y = x.clone()
替换为[MASK]
令牌;请注意,这不包括原始令牌保持不变的令牌和被随机令牌替换的令牌。
130 x.masked_fill_(mask, self.mask_token)
分配随机代币
132 x[random_token_idx] = random_tokens
将令牌[PAD]
分配给标签中的所有其他位置。等于的标签[PAD]
将不会用于亏损。
136 y.masked_fill_(~full_mask, self.padding_token)
返回屏蔽的输入和标签
139 return x, y