9import copy
10
11import torch.nn as nn
12
13from labml.configs import BaseConfigs, option, calculate, aggregate
14from labml_helpers.module import Module
15from .feed_forward import FeedForward
16from .mha import MultiHeadAttention
17from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
18 Encoder, Decoder, Generator, EncoderDecoder
21class FeedForwardConfigs(BaseConfigs):
位置ごとのフィードフォワード層
31 ffn: FeedForward
埋め込みに含まれる機能の数
33 d_model: int
隠れレイヤーに含まれるフィーチャの数
35 d_ff: int = 2048
脱落確率
37 dropout: float = 0.1
位置単位フィードフォワード層での活性化
39 activation: nn.Module = 'ReLU'
FFN レイヤーをゲートすべきかどうか
41 is_gated: bool = False
最初の完全接続層に学習可能なバイアスを付けるべきかどうか
43 bias1: bool = True
2 番目の完全接続層に学習可能なバイアスを設定すべきかどうか
45 bias2: bool = True
ゲートの全接続層に学習可能なバイアスを設けるべきかどうか
47 bias_gate: bool = False
定義済みの GLU バリアント
49 glu_variant: str = 'none'
52@option(FeedForwardConfigs.activation, 'ReLU')
53def _ffn_activation_relu():
59 return nn.ReLU()
62@option(FeedForwardConfigs.activation, 'GELU')
63def _ffn_activation_gelu():
71 return nn.GELU()
74@option(FeedForwardConfigs.ffn, 'default')
75def _feed_forward(c: FeedForwardConfigs):
79 return FeedForward(c.d_model, c.d_ff,
80 dropout=c.dropout,
81 activation=c.activation,
82 is_gated=c.is_gated,
83 bias1=c.bias1,
84 bias2=c.bias2,
85 bias_gate=c.bias_gate)
95aggregate(FeedForwardConfigs.glu_variant, 'GLU',
96 (FeedForwardConfigs.is_gated, True),
97 (FeedForwardConfigs.bias1, False),
98 (FeedForwardConfigs.bias2, False),
99 (FeedForwardConfigs.bias_gate, False),
100 (FeedForwardConfigs.activation, nn.Sigmoid()))
105aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
106 (FeedForwardConfigs.is_gated, True),
107 (FeedForwardConfigs.bias1, False),
108 (FeedForwardConfigs.bias2, False),
109 (FeedForwardConfigs.bias_gate, False),
110 (FeedForwardConfigs.activation, nn.Identity()))
115aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
116 (FeedForwardConfigs.is_gated, True),
117 (FeedForwardConfigs.bias1, False),
118 (FeedForwardConfigs.bias2, False),
119 (FeedForwardConfigs.bias_gate, False),
120 (FeedForwardConfigs.activation, nn.ReLU()))
125aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
126 (FeedForwardConfigs.is_gated, True),
127 (FeedForwardConfigs.bias1, False),
128 (FeedForwardConfigs.bias2, False),
129 (FeedForwardConfigs.bias_gate, False),
130 (FeedForwardConfigs.activation, nn.GELU()))
136aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
137 (FeedForwardConfigs.is_gated, True),
138 (FeedForwardConfigs.bias1, False),
139 (FeedForwardConfigs.bias2, False),
140 (FeedForwardConfigs.bias_gate, False),
141 (FeedForwardConfigs.activation, nn.SiLU()))
144class TransformerConfigs(BaseConfigs):
アテンションヘッドの数
156 n_heads: int = 8
変圧器埋め込みサイズ
158 d_model: int = 512
レイヤー数
160 n_layers: int = 6
脱落確率
162 dropout: float = 0.1
ソースボキャブラリーのトークン数 (トークンの埋め込み用)
164 n_src_vocab: int
ターゲットボキャブラリ内のトークンの数 (予測用のロジットを生成するため)
166 n_tgt_vocab: int
エンコーダのセルフアテンション
169 encoder_attn: MultiHeadAttention = 'mha'
デコーダーのセルフアテンション
171 decoder_attn: MultiHeadAttention = 'mha'
デコーダメモリアテンション
173 decoder_mem_attn: MultiHeadAttention = 'mha'
設定可能なフィードフォワード層
176 ffn: FeedForwardConfigs
エンコーダ層
179 encoder_layer: TransformerLayer = 'default'
デコーダー層
181 decoder_layer: TransformerLayer = 'default'
複数のエンコーダー層で構成されるエンコーダー
184 encoder: Encoder = 'default'
複数のデコーダー層で構成されるエンコーダー
186 decoder: Decoder = 'default'
ソースの埋め込みレイヤー
189 src_embed: Module = 'fixed_pos'
ターゲット用埋め込みレイヤー (デコーダー用)
191 tgt_embed: Module = 'fixed_pos'
予測用ロジット・ジェネレーター
194 generator: Generator = 'default'
エンコーダ/デコーダ
197 encoder_decoder: EncoderDecoder
201def _mha(c: TransformerConfigs):
202 return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
203
204
205calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
206calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
207calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)
211def _relative_mha(c: TransformerConfigs):
212 from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
213 return RelativeMultiHeadAttention(c.n_heads, c.d_model)
214
215
216calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
217calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
218calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)
フィードフォワード層構成の作成
221@option(TransformerConfigs.ffn, 'default')
222def _feed_forward(c: TransformerConfigs):
226 conf = FeedForwardConfigs()
227 conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
228 conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
229 return conf
エンコーダ層
232@option(TransformerConfigs.encoder_layer, 'default')
233def _encoder_layer(c: TransformerConfigs):
237 return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
238 src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
239 dropout_prob=c.dropout)
デコーダー層
242@option(TransformerConfigs.decoder_layer, 'default')
243def _decoder_layer(c: TransformerConfigs):
247 return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
248 src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
249 dropout_prob=c.dropout)
エンコーダー
252@option(TransformerConfigs.encoder, 'default')
253def _encoder(c: TransformerConfigs):
257 return Encoder(c.encoder_layer, c.n_layers)
デコーダー
260@option(TransformerConfigs.decoder, 'default')
261def _decoder(c: TransformerConfigs):
265 return Decoder(c.decoder_layer, c.n_layers)
ロジット・ジェネレーター
268@option(TransformerConfigs.generator, 'default')
269def _generator(c: TransformerConfigs):
273 return Generator(c.n_tgt_vocab, c.d_model)
277@option(TransformerConfigs.src_embed, 'fixed_pos')
278def _src_embed_with_positional(c: TransformerConfigs):
282 return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)
固定位置エンコーディングによるターゲット埋め込み
285@option(TransformerConfigs.tgt_embed, 'fixed_pos')
286def _tgt_embed_with_positional(c: TransformerConfigs):
290 return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)
294@option(TransformerConfigs.src_embed, 'learned_pos')
295def _src_embed_with_learned_positional(c: TransformerConfigs):
299 return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)
学習した位置エンコーディングによるターゲット埋め込み
302@option(TransformerConfigs.tgt_embed, 'learned_pos')
303def _tgt_embed_with_learned_positional(c: TransformerConfigs):
307 return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)
311@option(TransformerConfigs.src_embed, 'no_pos')
312def _src_embed_without_positional(c: TransformerConfigs):
316 return nn.Embedding(c.n_src_vocab, c.d_model)
319@option(TransformerConfigs.tgt_embed, 'no_pos')
320def _tgt_embed_without_positional(c: TransformerConfigs):
321 return nn.Embedding(c.n_tgt_vocab, c.d_model)
322
323
324@option(TransformerConfigs.encoder_decoder, 'default')
325def _encoder_decoder(c: TransformerConfigs):
326 return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)