これは、論文「BERT:言語理解のためのディープ双方向トランスフォーマーの事前トレーニング」で紹介されているBERTモデルの事前トレーニングに使用されるマスク言語モデル(MLM)のPyTorch実装です。
BERT モデルはトランスモデルです。この論文では、MLMと次の文章予測を使用してモデルを事前にトレーニングしています。ここではMLMを実装しただけです
。次の文予測では、A
モデルに2つの文が与えられ、B
A
モデルが実際のテキストに続く文かどうかをバイナリ予測します。B
モデルには、50% の確率で実際の文のペアが入力され、50% の確率でランダムなペアが入力されます。この分類はMLMを適用する際に行われます。ここではこれを実装していません。
これにより、トークンのパーセンテージがランダムにマスクされ、マスクされたトークンを予測するようにモデルをトレーニングします。[MASK]
トークンの15%を特別なトークンに置き換えることでマスクします。
損失は、マスクされたトークンの予測のみに基づいて計算されます。[MASK]
その時点ではトークンがないため、微調整や実際の使用中に問題が発生します。したがって、意味のある表現が得られない可能性があります
これを克服するには、マスクされたトークンの 10% が元のトークンに置き換えられ、さらに 10% のマスクされたトークンがランダムなトークンに置き換えられます。これにより、[MASK]
その位置の入力トークンがaであるかどうかに関係なく、実際のトークンについて表現するようにモデルをトレーニングします。また、ランダムなトークンに置き換えると、コンテキストからの情報も含む表現になります。ランダムに置き換えられたトークンを修正するにはコンテキストを使用する必要があるためです。
MLM はトレーニング信号が小さいため、自己回帰モデルよりもトレーニングが困難です。つまり、サンプルごとにトレーニングされる予測の割合はごくわずかです。
もう一つの問題は、モデルが双方向なので、どのトークンも他のトークンを見ることができるということです。これにより、「クレジットの割り当て」が難しくなります。キャラクターレベルのモデルが予測しようとしているとしましょうhome *s where i want to be
。少なくともトレーニングの初期段階では、なぜ置換が必要なのかを理解するのは非常に難しいでしょう。文章全体から何でもかまいません。*
i
一方、自己回帰設定では、h
o
hom
モデルは予測や予測などに使用するだけで済みますe
。そのため、モデルは最初に短いコンテキストで予測を開始し、後で長いコンテキストの使用方法を学習します。MLMにはこの問題があるため、最初は短いシーケンス長から始めて、後で長いシーケンス長を使用する方がトレーニングがはるかに速くなります
65from typing import List
66
67import torch
70class MLM:
padding_token
[PAD]
パディングトークンです。これを使って、損失計算に使用してはいけないラベルにマークを付けます。mask_token
[MASK]
マスキングトークンです。no_mask_tokens
マスクしてはいけないトークンのリストです。これは、分類などの別のタスクで同時にMLMをトレーニングしていて、[CLS]
マスクしてはいけないようなトークンがある場合に便利ですn_tokens
トークンの総数 (ランダムトークンの生成に使用)masking_prob
はマスキング確率ですrandomize_prob
ランダムなトークンに置き換えられる確率ですno_change_prob
は元のトークンと交換する確率です77 def __init__(self, *,
78 padding_token: int, mask_token: int, no_mask_tokens: List[int], n_tokens: int,
79 masking_prob: float = 0.15, randomize_prob: float = 0.1, no_change_prob: float = 0.1,
80 ):
93 self.n_tokens = n_tokens
94 self.no_change_prob = no_change_prob
95 self.randomize_prob = randomize_prob
96 self.masking_prob = masking_prob
97 self.no_mask_tokens = no_mask_tokens + [padding_token, mask_token]
98 self.padding_token = padding_token
99 self.mask_token = mask_token
x
は入力トークンシーケンスのバッチです。long
[seq_len, batch_size]
形のあるタイプのテンソルです101 def __call__(self, x: torch.Tensor):
masking_prob
トークンのマスク
108 full_mask = torch.rand(x.shape, device=x.device) < self.masking_prob
[マスク解除] no_mask_tokens
110 for t in self.no_mask_tokens:
111 full_mask &= x != t
トークンを元のトークンと交換するためのマスク
114 unchanged = full_mask & (torch.rand(x.shape, device=x.device) < self.no_change_prob)
トークンをランダムトークンに置き換えるためのマスク
116 random_token_mask = full_mask & (torch.rand(x.shape, device=x.device) < self.randomize_prob)
ランダムトークンに置き換えられるトークンのインデックス
118 random_token_idx = torch.nonzero(random_token_mask, as_tuple=True)
各場所のランダムトークン
120 random_tokens = torch.randint(0, self.n_tokens, (len(random_token_idx[0]),), device=x.device)
に置き換えられる予定のトークンの最後のセット [MASK]
122 mask = full_mask & ~random_token_mask & ~unchanged
ラベルの入力のクローンを作成
125 y = x.clone()
[MASK]
トークンで置換。元のトークンが変更されないトークンや、ランダムなトークンに置き換えられるトークンは含まれないことに注意してください。
130 x.masked_fill_(mask, self.mask_token)
ランダムトークンの割り当て
132 x[random_token_idx] = random_tokens
[PAD]
ラベル内の他のすべての場所にトークンを割り当てます。[PAD]
と等しいラベルは損失には使用されません。
136 y.masked_fill_(~full_mask, self.padding_token)
マスクされた入力とラベルを返す
139 return x, y