15import torch
16from torch import nn
17
18from labml import experiment
19from labml.configs import option
20from labml_helpers.module import Module
21from labml_nn.experiments.nlp_classification import NLPClassificationConfigs
22from labml_nn.transformers import Encoder
23from labml_nn.transformers import TransformerConfigs26class TransformerClassifier(nn.Module):30 def __init__(self, encoder: Encoder, src_embed: Module, generator: nn.Linear):37 super().__init__()
38 self.src_embed = src_embed
39 self.encoder = encoder
40 self.generator = generator42 def forward(self, x: torch.Tensor):位置エンコーディングによるトークンの埋め込みを取得
44 x = self.src_embed(x)トランスエンコーダー
46 x = self.encoder(x, None)分類用のロジットを取得します。
[CLS]
シーケンスの最後の位置にトークンを設定します。これは、x[-1]
x
形状がどこにあるかによって抽出されます [seq_len, batch_size, d_model]
52 x = self.generator(x[-1])結果を返します(トレーナーはRNNでも使用されるため、2番目の値は状態用です)
56 return x, None59class Configs(NLPClassificationConfigs):分類モデル
68 model: TransformerClassifier変圧器
70 transformer: TransformerConfigs73@option(Configs.transformer)
74def _transformer_configs(c: Configs):81 conf = TransformerConfigs()埋め込みやロジットの生成に使用するボキャブラリーサイズを設定
83 conf.n_src_vocab = c.n_tokens
84 conf.n_tgt_vocab = c.n_tokens87 return conf90@option(TransformerConfigs.encoder_attn)
91def fnet_mix():97 from labml_nn.transformers.fnet import FNetMix
98 return FNetMix()分類モデルの作成
101@option(Configs.model)
102def _model(c: Configs):106 m = TransformerClassifier(c.transformer.encoder,
107 c.transformer.src_embed,
108 nn.Linear(c.d_model, c.n_classes)).to(c.device)
109
110 return m113def main():実験を作成
115 experiment.create(name="fnet")コンフィグの作成
117 conf = Configs()オーバーライド設定
119 experiment.configs(conf, {ワールドレベルのトークナイザーを使う
121 'tokenizer': 'basic_english',時代に合わせた列車
124 'epochs': 32,エポックごとにトレーニングと検証を切り替える
127 'inner_iterations': 10,変圧器構成 (デフォルトと同じ)
130 'transformer.d_model': 512,
131 'transformer.ffn.d_ff': 2048,
132 'transformer.n_heads': 8,
133 'transformer.n_layers': 6,137 'transformer.encoder_attn': 'fnet_mix',保存および読み込み用のモデルを設定する
145 experiment.add_pytorch_models({'model': conf.model})実験を始める
148 with experiment.start():トレーニングを実行
150 conf.run()154if __name__ == '__main__':
155 main()