これにより、自動回帰用の単純なトランスフォーマーモデルがトレーニングされます。位置ごとのフィードフォワードネットワークにはさまざまなバリエーションを試します
。labml.configs
これはモジュールを使わないより単純な実装です。慣れていない読者にもわかりやすいように、よりシンプルな実装を作成することにしました
19import dataclasses
20
21import torch
22from labml_helpers.module import Module
23from torch import nn
24from torch.utils.data import Dataset, DataLoader
25
26from labml import experiment, lab, tracker, monit, logger
27from labml.logger import Text
28from labml.utils.download import download_file
29from labml_nn.experiments.nlp_autoregression import transpose_batch
30from labml_nn.optimizers.noam import Noam
31from labml_nn.transformers import Encoder, MultiHeadAttention
32from labml_nn.transformers.feed_forward import FeedForward
33from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
34from labml_nn.transformers.utils import subsequent_mask37class AutoregressiveModel(Module):42    def __init__(self, src_embed: Module, encoder: Encoder, generator: Module):
43        super().__init__()トークン埋め込みモジュール
45        self.src_embed = src_embedトランスベースのエンコーダ
47        self.encoder = encoder次のトークン生成レイヤー。これにより、次のトークンのロジットが得られます
50        self.generator = generatorこれは最初の呼び出しで初期化されます。
52        self.src_mask = None54    def forward(self, src: torch.Tensor):次のマスクを作成して、トランスフォーマーが過去のトークンにしか注目できないようにします。
56        if self.src_mask is None or self.src_mask.size(0) != len(src):
57            self.src_mask = subsequent_mask(len(src)).to(src.device)トークン (src
) を埋め込み、トランスフォーマーに通します
59        res = self.encoder(self.src_embed(src), self.src_mask)次のトークンのロジットを生成
61        return self.generator(res)64@dataclasses.dataclass
65class Configs:69    d_model: int = 512
70    seq_len: int = 128
71    batch_size: int = 32
72    n_layers: int = 6
73    n_heads: int = 8
74    dropout: float = 0.1
75    d_ff: int = 2048
76    glu_variant: str = 'GLU'
77    epochs: int = 5
78    grad_norm_clip: float = 0.581class TinyShakespeareDataset(Dataset):86    def __init__(self, seq_len: int):テキストファイルの場所
88        path = lab.get_data_path() / 'tiny_shakespeare.txt'ファイルをダウンロードする
90        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)ダウンロードしたファイルを読み込む
92        with open(str(path), 'r') as f:
93            text = f.read()文字を抽出
96        chars = list(set(text))文字を ID (整数) にマッピング
98        self.stoi = {c: i for i, c in enumerate(chars)}ID を文字マップに変換
100        self.itos = {i: c for i, c in enumerate(chars)}トレーニングサンプルの長さ
102        self.seq_len = seq_lenid のテンソルの形式のデータ
104        self.data = self.text_to_i(text)テキストを id のテンソルに変換します
106    def text_to_i(self, text: str):110        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)112    def __len__(self):118        return len(self.data) - self.seq_len - 1サンプルを返す
120    def __getitem__(self, idx):124        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]127class Trainer:132    def __init__(self, configs: Configs):デバイスを入手
134        self.device = torch.device('cpu')
135        if torch.cuda.is_available():
136            self.device = torch.device('cuda:0')データセットの初期化
138        self.dataset = TinyShakespeareDataset(configs.seq_len)データローダーの初期化
140        self.dataloader = DataLoader(self.dataset,
141                                     batch_size=configs.batch_size,
142                                     collate_fn=transpose_batch,
143                                     shuffle=True)ゲート付きリニアユニット付きFFN
147        if configs.glu_variant == 'GLU':
148            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)バイリニア隠れ層付きFFN
151        elif configs.glu_variant == 'Bilinear':
152            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)RelU ゲート付き FN
155        elif configs.glu_variant == 'ReGLU':
156            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)GELU ゲート付きFFN
159        elif configs.glu_variant == 'GEGLU':
160            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)スウィッシュゲート付きのFFNどこ
164        elif configs.glu_variant == 'SwiGLU':
165            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)RelU アクティベーションを使用した FFN
168        elif configs.glu_variant == 'ReLU':
169            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())RelU アクティベーションを使用した FFN
172        elif configs.glu_variant == 'GELU':
173            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
174        else:
175            raise ValueError(f'Unknown variant {configs.glu_variant}')異なる文字の数
178        n_chars = len(self.dataset.stoi)181        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)183        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
184                                             feed_forward=ffn, dropout_prob=configs.dropout)モデルを埋め込み層 (固定位置エンコーディング) トランスエンコーダーと線形層で初期化し、ロジットを生成します。
190        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
191                                         Encoder(transformer_layer, configs.n_layers),
192                                         nn.Linear(configs.d_model, n_chars))モデルを現在のデバイスに移動
195        self.model.to(self.device)198        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)クロスエントロピー損失
201        self.loss_func = nn.CrossEntropyLoss()トレーニングエポックの数。このデータセット定義では、seq_len
 1つのエポックでデータ回数が繰り返されることに注意してください
204        self.epochs = configs.epochsグラデーションクリッピングノルム
206        self.grad_norm_clip = configs.grad_norm_clipトラッカー構成を設定
209        tracker.set_scalar("loss.*", True)211    def sample(self):起動プロンプト
217        prompt = 'It is'印刷用の出力を収集
219        log = [(prompt, Text.subtle)]25トークンのサンプル
221        for i in monit.iterate('Sample', 25):プロンプトをトークン化
223            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
224            data = data.to(self.device)モデル出力を取得
226            output = self.model(data)モデル予測を取得 (欲張り)
228            output = output.argmax(dim=-1).squeeze()予測をプロンプトに追加
230            prompt += self.dataset.itos[output[-1].item()]ロギング用の予測を追加
232            log += [(self.dataset.itos[output[-1].item()], Text.value)]サンプル出力を印刷する
235        logger.log(log)237    def train(self):指定されたエポック数のループ
243        for _ in monit.loop(self.epochs):ミニバッチを反復処理
245            for i, batch in monit.enum('Train', self.dataloader):データをデバイスに移動
247                data, target = batch[0].to(self.device), batch[1].to(self.device)トラッカーステップをトレーニングしたキャラクターの数として設定
250                tracker.add_global_step(data.shape[0] * data.shape[1])モデル状態をトレーニングに設定
253                self.model.train()モデルの評価
255                output = self.model(data)損失の計算
258                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))損失を記録する
260                tracker.add("loss.train", loss)勾配の計算
263                loss.backward()クリップグラデーション
265                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)最適化の一歩を踏み出す
267                self.optimizer.step()モデルパラメーターと勾配をログに記録します
269                if (i + 1) % 100 == 0:
270                    tracker.add('model', self.model)グラデーションをクリア
272                self.optimizer.zero_grad()サンプルを生成
275                if (i + 1) % 100 == 0:
276                    self.model.eval()
277                    with torch.no_grad():
278                        self.sample()追跡したメトリクスを保存する
281                if (i + 1) % 10 == 0:
282                    tracker.save()モデルを保存する
285            experiment.save_checkpoint()288def main():実験を作成
290    experiment.create(name="glu_variants")コンフィグの作成
292    configs = Configs()構成をロード
294    experiment.configs(dataclasses.asdict(configs))トレーナーを作成
297    trainer = Trainer(configs)トレーニングとロード用のモデルの設定
299    experiment.add_pytorch_models({'model': trainer.model})実験を始める
302    with experiment.start():モデルのトレーニング
304        trainer.train()
305
306
307if __name__ == '__main__':
308    main()