13import math
14
15import torch
16import torch.nn as nn
17
18from labml_nn.utils import clone_module_list
19from .feed_forward import FeedForward
20from .mha import MultiHeadAttention
21from .positional_encoding import get_positional_encoding
24class EmbeddingsWithPositionalEncoding(nn.Module):
31 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
32 super().__init__()
33 self.linear = nn.Embedding(n_vocab, d_model)
34 self.d_model = d_model
35 self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))
37 def forward(self, x: torch.Tensor):
38 pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
39 return self.linear(x) * math.sqrt(self.d_model) + pe
42class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
49 def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
50 super().__init__()
51 self.linear = nn.Embedding(n_vocab, d_model)
52 self.d_model = d_model
53 self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
55 def forward(self, x: torch.Tensor):
56 pe = self.positional_encodings[:x.shape[0]]
57 return self.linear(x) * math.sqrt(self.d_model) + pe
これは、エンコーダ層またはデコーダ層として機能できます。
🗒 論文を含む一部の実装では、層の正規化が行われる場所に違いがあるようです。ここでは、アテンションネットワークとフィードフォワードネットワークの前に層の正規化を行い、元の残差ベクトルを追加します。別の方法は、残差を追加した後に層の正規化を行うことです。しかし、トレーニング中は安定性が低いことがわかりました。これについての詳細な議論は、「トランスフォーマーアーキテクチャにおける層正規化について」という論文に記載されています
。60class TransformerLayer(nn.Module):
d_model
トークンの埋め込みサイズですself_attn
セルフアテンションモジュールですsrc_attn
ソース・アテンション・モジュールです (これをデコーダで使用する場合)feed_forward
フィードフォワードモジュールですdropout_prob
セルフアテンションとFFNの後に脱落する確率です78 def __init__(self, *,
79 d_model: int,
80 self_attn: MultiHeadAttention,
81 src_attn: MultiHeadAttention = None,
82 feed_forward: FeedForward,
83 dropout_prob: float):
91 super().__init__()
92 self.size = d_model
93 self.self_attn = self_attn
94 self.src_attn = src_attn
95 self.feed_forward = feed_forward
96 self.dropout = nn.Dropout(dropout_prob)
97 self.norm_self_attn = nn.LayerNorm([d_model])
98 if self.src_attn is not None:
99 self.norm_src_attn = nn.LayerNorm([d_model])
100 self.norm_ff = nn.LayerNorm([d_model])
入力をフィードフォワード層に保存するかどうか
102 self.is_save_ff_input = False
104 def forward(self, *,
105 x: torch.Tensor,
106 mask: torch.Tensor,
107 src: torch.Tensor = None,
108 src_mask: torch.Tensor = None):
セルフアテンションを行う前にベクトルを正規化してください
110 z = self.norm_self_attn(x)
自己注意を向ける。つまり、キーと値は自己からのものだ
112 self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
セルフアテンションの結果を追加
114 x = x + self.dropout(self_attn)
ソースが提供されている場合は、ソースに注目して結果を取得します。これは、エンコーダー出力に注目するデコーダーレイヤーがある場合です
。119 if src is not None:
ベクトルを正規化
121 z = self.norm_src_attn(x)
ソースに注意。つまり、キーと値はソースからのものです
123 attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
ソースアテンション結果の追加
125 x = x + self.dropout(attn_src)
フィードフォワード用に正規化
128 z = self.norm_ff(x)
指定されている場合、入力をフィードフォワード層に保存します
130 if self.is_save_ff_input:
131 self.ff_input = z.clone()
フィードフォワードネットワークを通過
133 ff = self.feed_forward(z)
フィードフォワードの結果を追加し直す
135 x = x + self.dropout(ff)
136
137 return x
140class Encoder(nn.Module):
147 def __init__(self, layer: TransformerLayer, n_layers: int):
148 super().__init__()
トランスレイヤーのコピーを作成
150 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
152 self.norm = nn.LayerNorm([layer.size])
154 def forward(self, x: torch.Tensor, mask: torch.Tensor):
各変圧器層に通す
156 for layer in self.layers:
157 x = layer(x=x, mask=mask)
最後に、ベクトルを正規化します。
159 return self.norm(x)
162class Decoder(nn.Module):
169 def __init__(self, layer: TransformerLayer, n_layers: int):
170 super().__init__()
トランスレイヤーのコピーを作成
172 self.layers = clone_module_list(layer, n_layers)
最終正規化レイヤー
174 self.norm = nn.LayerNorm([layer.size])
176 def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
各変圧器層に通す
178 for layer in self.layers:
179 x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
最後に、ベクトルを正規化します。
181 return self.norm(x)
184class Generator(nn.Module):
194 def __init__(self, n_vocab: int, d_model: int):
195 super().__init__()
196 self.projection = nn.Linear(d_model, n_vocab)
198 def forward(self, x):
199 return self.projection(x)
202class EncoderDecoder(nn.Module):
209 def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
210 super().__init__()
211 self.encoder = encoder
212 self.decoder = decoder
213 self.src_embed = src_embed
214 self.tgt_embed = tgt_embed
215 self.generator = generator
これは彼らのコードからすると重要でした。Glorot /fan_avg を使用してパラメーターを初期化します
。219 for p in self.parameters():
220 if p.dim() > 1:
221 nn.init.xavier_uniform_(p)
223 def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
ソースをエンコーダで実行
225 enc = self.encode(src, src_mask)
デコーダーによるエンコーディングとターゲットの実行
227 return self.decode(enc, src_mask, tgt, tgt_mask)
229 def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
230 return self.encoder(self.src_embed(src), src_mask)
232 def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
233 return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)