import math

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list
from .feed_forward import FeedForward
from .mha import MultiHeadAttention
from .positional_encoding import get_positional_encoding
Embed tokens and add fixed positional encoding
class EmbeddingsWithPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.register_buffer('positional_encodings', get_positional_encoding(d_model, max_len))

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]].requires_grad_(False)
        return self.linear(x) * math.sqrt(self.d_model) + pe
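
A minimal usage sketch, assuming sequence-first token indices of shape [seq_len, batch_size], which is what slicing the positional encodings by x.shape[0] implies:

# Illustrative sketch (not part of the module): embed a batch of token ids
emb = EmbeddingsWithPositionalEncoding(d_model=512, n_vocab=10000)
tokens = torch.randint(0, 10000, (20, 8))  # [seq_len=20, batch_size=8]
out = emb(tokens)                          # [20, 8, 512]: scaled embeddings + fixed positional encodings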
Embed tokens and add learned (trained) positional encodings
class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, n_vocab: int, max_len: int = 5000):
        super().__init__()
        self.linear = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        pe = self.positional_encodings[:x.shape[0]]
        return self.linear(x) * math.sqrt(self.d_model) + pe
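
The learned variant has the same interface; the only difference is that the positional encodings start at zero and are trained with the rest of the model. A short sketch:

# Illustrative sketch: the positional encodings show up among the trainable parameters
emb = EmbeddingsWithLearnedPositionalEncoding(d_model=512, n_vocab=10000, max_len=512)
out = emb(torch.randint(0, 10000, (20, 8)))  # [20, 8, 512]
assert any(p is emb.positional_encodings for p in emb.parameters())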
class TransformerLayer(nn.Module):
This can act as an encoder layer or a decoder layer.
d_model is the token embedding size
self_attn is the self attention module
src_attn is the source attention module (when this is used in a decoder)
feed_forward is the feed forward module
dropout_prob is the probability of dropping out after self attention and the feed forward network
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 src_attn: MultiHeadAttention = None,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        if self.src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
Whether to save the input to the feed forward layer
        self.is_save_ff_input = False
    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor,
                src: torch.Tensor = None,
                src_mask: torch.Tensor = None):
Normalize the vectors before doing self attention
        z = self.norm_self_attn(x)
Run through self attention, i.e. keys and values come from x itself
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
Add the self attention results
        x = x + self.dropout(self_attn)
If a source is provided, get the results of attending to it. This is the case when a decoder layer attends to the encoder outputs.
        if src is not None:
Normalize the vectors
            z = self.norm_src_attn(x)
Attend to the source, i.e. keys and values come from the source
            attn_src = self.src_attn(query=z, key=src, value=src, mask=src_mask)
Add the source attention results
            x = x + self.dropout(attn_src)
Normalize for the feed-forward network
        z = self.norm_ff(x)
Save the input to the feed forward layer if specified
        if self.is_save_ff_input:
            self.ff_input = z.clone()
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
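
Below is a sketch of using this as a pre-norm encoder layer (self attention only, so src_attn is left at its default of None). The MultiHeadAttention and FeedForward constructor arguments and the boolean mask layout ([query_len or 1, key_len, batch_size or 1]) are assumptions about the sibling modules, not something defined in this file:

# Illustrative sketch of an encoder-style layer
d_model, heads, d_ff, dropout = 512, 8, 2048, 0.1
enc_layer = TransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(heads, d_model, dropout),  # assumed signature
                             feed_forward=FeedForward(d_model, d_ff, dropout),       # assumed signature
                             dropout_prob=dropout)
x = torch.randn(20, 8, d_model)                # [seq_len, batch_size, d_model]
mask = torch.ones(1, 20, 1, dtype=torch.bool)  # assumed mask layout, broadcast over queries and batch
y = enc_layer(x=x, mask=mask)                  # [20, 8, d_model]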
class Encoder(nn.Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=mask)
Finally, normalize the vectors
        return self.norm(x)
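
Continuing the sketch above, the same layer can be stacked into an encoder; clone_module_list makes n_layers independent copies, so the layers do not share weights:

# Illustrative sketch: a 6-layer encoder over the example layer from above
encoder = Encoder(enc_layer, n_layers=6)
memory = encoder(x, mask)                      # [20, 8, d_model]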
class Decoder(nn.Module):
    def __init__(self, layer: TransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run through each transformer layer
        for layer in self.layers:
            x = layer(x=x, mask=tgt_mask, src=memory, src_mask=src_mask)
Finally, normalize the vectors
        return self.norm(x)
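
A decoder layer is the same TransformerLayer with src_attn set, so it can attend to the encoder output (memory). The causal target mask and the cross-attention mask below use the same assumed [query_len or 1, key_len, batch_size or 1] layout as before:

# Illustrative sketch: a 6-layer decoder attending to the encoder output from the sketch above
dec_layer = TransformerLayer(d_model=d_model,
                             self_attn=MultiHeadAttention(heads, d_model, dropout),
                             src_attn=MultiHeadAttention(heads, d_model, dropout),
                             feed_forward=FeedForward(d_model, d_ff, dropout),
                             dropout_prob=dropout)
decoder = Decoder(dec_layer, n_layers=6)
tgt = torch.randn(15, 8, d_model)                               # target embeddings [tgt_len, batch, d_model]
tgt_mask = torch.tril(torch.ones(15, 15)).bool().unsqueeze(-1)  # causal mask [15, 15, 1]
src_mask = torch.ones(1, 20, 1, dtype=torch.bool)               # cross-attention mask over the 20 source positions
dec_out = decoder(tgt, memory, src_mask, tgt_mask)              # [15, 8, d_model]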
class Generator(nn.Module):
    def __init__(self, n_vocab: int, d_model: int):
        super().__init__()
        self.projection = nn.Linear(d_model, n_vocab)

    def forward(self, x):
        return self.projection(x)
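
The generator is just a linear projection to vocabulary logits; no softmax is applied here, so the output can be fed directly to a loss such as nn.CrossEntropyLoss, which applies log-softmax itself. Continuing the sketch:

# Illustrative sketch: project decoder outputs to vocabulary logits
generator = Generator(n_vocab=10000, d_model=d_model)
logits = generator(dec_out)                    # [15, 8, 10000], raw (unnormalized) logits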
class EncoderDecoder(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
This is an important part of the code: initialize the parameters with Glorot / fan_avg (Xavier uniform).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor):
Run the source through the encoder
        enc = self.encode(src, src_mask)
Run the encodings and targets through the decoder
        return self.decode(enc, src_mask, tgt, tgt_mask)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
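
Putting the pieces together into a full encoder-decoder model. This is only a sketch: the vocabulary size, layer count, masks and the attention/feed-forward constructor arguments are illustrative assumptions carried over from the sketches above:

# Illustrative sketch: assemble and run the full model
model = EncoderDecoder(encoder=Encoder(enc_layer, n_layers=6),
                       decoder=Decoder(dec_layer, n_layers=6),
                       src_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab=10000),
                       tgt_embed=EmbeddingsWithPositionalEncoding(d_model, n_vocab=10000),
                       generator=Generator(n_vocab=10000, d_model=d_model))
src = torch.randint(0, 10000, (20, 8))         # source token ids [src_len, batch]
tgt = torch.randint(0, 10000, (15, 8))         # target token ids [tgt_len, batch]
out = model(src, tgt, src_mask, tgt_mask)      # decoder output [15, 8, d_model]
logits = model.generator(out)                  # forward() does not apply the generator, so call it explicitly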