設定可能な変圧器コンポーネント

9import copy
10
11import torch.nn as nn
12
13from labml.configs import BaseConfigs, option, calculate, aggregate
14from labml_helpers.module import Module
15from .feed_forward import FeedForward
16from .mha import MultiHeadAttention
17from .models import EmbeddingsWithPositionalEncoding, EmbeddingsWithLearnedPositionalEncoding, TransformerLayer, \
18    Encoder, Decoder, Generator, EncoderDecoder

FFN コンフィギュレーション

で定義されている位置単位のフィードフォワードネットワークを作成します。feed_forward.py

21class FeedForwardConfigs(BaseConfigs):

位置ごとのフィードフォワード層

31    ffn: FeedForward

埋め込みに含まれる機能の数

33    d_model: int

隠れレイヤーに含まれるフィーチャの数

35    d_ff: int = 2048

脱落確率

37    dropout: float = 0.1

位置単位フィードフォワード層での活性化

39    activation: nn.Module = 'ReLU'

FFN レイヤーをゲートすべきかどうか

41    is_gated: bool = False

最初の完全接続層に学習可能なバイアスを付けるべきかどうか

43    bias1: bool = True

2 番目の完全接続層に学習可能なバイアスを設定すべきかどうか

45    bias2: bool = True

ゲートの全接続層に学習可能なバイアスを設けるべきかどうか

47    bias_gate: bool = False

定義済みの GLU バリアント

49    glu_variant: str = 'none'

ReLU アクティベーション

52@option(FeedForwardConfigs.activation, 'ReLU')
53def _ffn_activation_relu():
59    return nn.ReLU()

GELU アクティベーション

どこ

ガウス誤差線形単位の論文で紹介されました

62@option(FeedForwardConfigs.activation, 'GELU')
63def _ffn_activation_gelu():
71    return nn.GELU()
74@option(FeedForwardConfigs.ffn, 'default')
75def _feed_forward(c: FeedForwardConfigs):
79    return FeedForward(c.d_model, c.d_ff,
80                       dropout=c.dropout,
81                       activation=c.activation,
82                       is_gated=c.is_gated,
83                       bias1=c.bias1,
84                       bias2=c.bias2,
85                       bias_gate=c.bias_gate)

GLU バリアント

これらは、紙のGLUバリアント改良トランスフォーマーで紹介されているように、FFN用のゲート隠れ層を備えたバリアントです。論文で明記されているバイアス用語は省略しています

ゲート付きリニアユニット付きFFN

95aggregate(FeedForwardConfigs.glu_variant, 'GLU',
96          (FeedForwardConfigs.is_gated, True),
97          (FeedForwardConfigs.bias1, False),
98          (FeedForwardConfigs.bias2, False),
99          (FeedForwardConfigs.bias_gate, False),
100          (FeedForwardConfigs.activation, nn.Sigmoid()))

バイリニア隠れ層付きFFN

105aggregate(FeedForwardConfigs.glu_variant, 'Bilinear',
106          (FeedForwardConfigs.is_gated, True),
107          (FeedForwardConfigs.bias1, False),
108          (FeedForwardConfigs.bias2, False),
109          (FeedForwardConfigs.bias_gate, False),
110          (FeedForwardConfigs.activation, nn.Identity()))

RelU ゲート付き FN

115aggregate(FeedForwardConfigs.glu_variant, 'ReGLU',
116          (FeedForwardConfigs.is_gated, True),
117          (FeedForwardConfigs.bias1, False),
118          (FeedForwardConfigs.bias2, False),
119          (FeedForwardConfigs.bias_gate, False),
120          (FeedForwardConfigs.activation, nn.ReLU()))

GELU ゲート付きFFN

125aggregate(FeedForwardConfigs.glu_variant, 'GEGLU',
126          (FeedForwardConfigs.is_gated, True),
127          (FeedForwardConfigs.bias1, False),
128          (FeedForwardConfigs.bias2, False),
129          (FeedForwardConfigs.bias_gate, False),
130          (FeedForwardConfigs.activation, nn.GELU()))

FFN(スウィッシュゲート付き)

どこ

136aggregate(FeedForwardConfigs.glu_variant, 'SwiGLU',
137          (FeedForwardConfigs.is_gated, True),
138          (FeedForwardConfigs.bias1, False),
139          (FeedForwardConfigs.bias2, False),
140          (FeedForwardConfigs.bias_gate, False),
141          (FeedForwardConfigs.activation, nn.SiLU()))

変圧器構成

これは変圧器の構成を定義します。構成はオプション関数を使用して計算されます。これらは遅延ロードされるため、必要なモジュールだけが計算されます

144class TransformerConfigs(BaseConfigs):

アテンションヘッドの数

156    n_heads: int = 8

変圧器埋め込みサイズ

158    d_model: int = 512

レイヤー数

160    n_layers: int = 6

脱落確率

162    dropout: float = 0.1

ソースボキャブラリーのトークン数 (トークンの埋め込み用)

164    n_src_vocab: int

ターゲットボキャブラリ内のトークンの数 (予測用のロジットを生成するため)

166    n_tgt_vocab: int

エンコーダのセルフアテンション

169    encoder_attn: MultiHeadAttention = 'mha'

デコーダーのセルフアテンション

171    decoder_attn: MultiHeadAttention = 'mha'

デコーダメモリアテンション

173    decoder_mem_attn: MultiHeadAttention = 'mha'

設定可能なフィードフォワード層

176    ffn: FeedForwardConfigs

エンコーダ層

179    encoder_layer: TransformerLayer = 'default'

デコーダー層

181    decoder_layer: TransformerLayer = 'default'

複数のエンコーダー層で構成されるエンコーダー

184    encoder: Encoder = 'default'

複数のデコーダー層で構成されるエンコーダー

186    decoder: Decoder = 'default'

ソースの埋め込みレイヤー

189    src_embed: Module = 'fixed_pos'

ターゲット用埋め込みレイヤー (デコーダー用)

191    tgt_embed: Module = 'fixed_pos'

予測用ロジット・ジェネレーター

194    generator: Generator = 'default'

エンコーダ/デコーダ

197    encoder_decoder: EncoderDecoder

マルチヘッド・アテンション

201def _mha(c: TransformerConfigs):
202    return MultiHeadAttention(c.n_heads, c.d_model, dropout_prob=c.dropout)
203
204
205calculate(TransformerConfigs.encoder_attn, 'mha', _mha)
206calculate(TransformerConfigs.decoder_attn, 'mha', _mha)
207calculate(TransformerConfigs.decoder_mem_attn, 'mha', _mha)

相対的なマルチヘッド・アテンション

211def _relative_mha(c: TransformerConfigs):
212    from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
213    return RelativeMultiHeadAttention(c.n_heads, c.d_model)
214
215
216calculate(TransformerConfigs.encoder_attn, 'relative', _relative_mha)
217calculate(TransformerConfigs.decoder_attn, 'relative', _relative_mha)
218calculate(TransformerConfigs.decoder_mem_attn, 'relative', _relative_mha)

フィードフォワード層構成の作成

221@option(TransformerConfigs.ffn, 'default')
222def _feed_forward(c: TransformerConfigs):
226    conf = FeedForwardConfigs()
227    conf.set_default(FeedForwardConfigs.d_model, func=lambda: c.d_model)
228    conf.set_default(FeedForwardConfigs.dropout, func=lambda: c.dropout)
229    return conf

エンコーダ層

232@option(TransformerConfigs.encoder_layer, 'default')
233def _encoder_layer(c: TransformerConfigs):
237    return TransformerLayer(d_model=c.d_model, self_attn=c.encoder_attn,
238                            src_attn=None, feed_forward=copy.deepcopy(c.ffn.ffn),
239                            dropout_prob=c.dropout)

デコーダー層

242@option(TransformerConfigs.decoder_layer, 'default')
243def _decoder_layer(c: TransformerConfigs):
247    return TransformerLayer(d_model=c.d_model, self_attn=c.decoder_attn,
248                            src_attn=c.decoder_mem_attn, feed_forward=copy.deepcopy(c.ffn.ffn),
249                            dropout_prob=c.dropout)

エンコーダー

252@option(TransformerConfigs.encoder, 'default')
253def _encoder(c: TransformerConfigs):
257    return Encoder(c.encoder_layer, c.n_layers)

デコーダー

260@option(TransformerConfigs.decoder, 'default')
261def _decoder(c: TransformerConfigs):
265    return Decoder(c.decoder_layer, c.n_layers)

ロジット・ジェネレーター

268@option(TransformerConfigs.generator, 'default')
269def _generator(c: TransformerConfigs):
273    return Generator(c.n_tgt_vocab, c.d_model)

固定位置埋め込み

固定位置エンコーディングによるソース埋め込み

277@option(TransformerConfigs.src_embed, 'fixed_pos')
278def _src_embed_with_positional(c: TransformerConfigs):
282    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_src_vocab)

固定位置エンコーディングによるターゲット埋め込み

285@option(TransformerConfigs.tgt_embed, 'fixed_pos')
286def _tgt_embed_with_positional(c: TransformerConfigs):
290    return EmbeddingsWithPositionalEncoding(c.d_model, c.n_tgt_vocab)

位置埋め込みを学んだ

学習した位置エンコーディングによるソース埋め込み

294@option(TransformerConfigs.src_embed, 'learned_pos')
295def _src_embed_with_learned_positional(c: TransformerConfigs):
299    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_src_vocab)

学習した位置エンコーディングによるターゲット埋め込み

302@option(TransformerConfigs.tgt_embed, 'learned_pos')
303def _tgt_embed_with_learned_positional(c: TransformerConfigs):
307    return EmbeddingsWithLearnedPositionalEncoding(c.d_model, c.n_tgt_vocab)

位置指定埋め込みなし

位置エンコーディングなしのソース埋め込み

311@option(TransformerConfigs.src_embed, 'no_pos')
312def _src_embed_without_positional(c: TransformerConfigs):
316    return nn.Embedding(c.n_src_vocab, c.d_model)
319@option(TransformerConfigs.tgt_embed, 'no_pos')
320def _tgt_embed_without_positional(c: TransformerConfigs):
321    return nn.Embedding(c.n_tgt_vocab, c.d_model)
322
323
324@option(TransformerConfigs.encoder_decoder, 'default')
325def _encoder_decoder(c: TransformerConfigs):
326    return EncoderDecoder(c.encoder, c.decoder, c.src_embed, c.tgt_embed, c.generator)