レトロトレーニング

これはRETROのトレーニングコードです

14import torch
15from torch import nn
16from torch.utils.data import DataLoader, RandomSampler
17
18from labml import monit, lab, tracker, experiment, logger
19from labml.logger import Text
20from labml_helpers.datasets.text import TextFileDataset
21from labml_nn.optimizers.noam import Noam
22from labml_nn.transformers.retro import model as retro
23from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
24from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder

サンプラー

このクラスはモデルから貪欲にサンプリングします。

27class Sampler:
  • device モデルのデバイスです
  • model レトロモードです
  • tds テキストデータセット (隣接チャンクの取得に使用)
  • chunk_len チャンクの長さです
34    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
41        self.chunk_len = chunk_len
42        self.tds = tds
43        self.model = model
44        self.device = device
47        self.index = RetroIndex()

与えられたチャンクの最も近い近傍のデータを取得

49    def retrieve_nearest_neighbours(self, chunk: str):

最も近い近傍のオフセットを取得

55        neighbor_offsets = self.index([chunk], None)

近傍を取得 (近傍の長さがと等しい) chunk_len * 2

58        text = self.tds.train
59        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

62        return neighbors

与えられたプロンプトのサンプルテキスト

64    def sample(self, prompt: str, sample_len: int):

最も近い近傍を文字列として保存するには

70        neighbors_str = []

サンプルテキスト

73        sampled = ''

sample_len サンプルトークン

76        for i in range(sample_len):

すでに取得したチャンクの数よりもサンプリングされたチャンクの数が多い場合は、隣接データを取得する必要があります

79            while len(neighbors_str) < len(prompt) // self.chunk_len:

隣接データを取得していない最後のチャンクを取得

81                off = len(neighbors_str) * self.chunk_len
82                chunk = prompt[off: off + self.chunk_len]

最も近い近傍を検索する

84                neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

入力をトークン化

87            src = self.tds.text_to_i(prompt)

取得したネイバーをトークン化する

89            neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])

モデルと同じデバイスに移動します

92            src = src.to(self.device)
93            neighbors = neighbors.to(self.device)

モデル出力を取得

96            res = self.model(src[None, :], neighbors[None, :, :, :])

最後のトークンを欲張ってサンプリングする

99            token = res[0, -1, :].argmax(dim=-1)

サンプリングしたトークンのテキストをプロンプトとサンプルテキストに追加します

102            prompt += self.tds.itos[token.item()]
103            sampled += self.tds.itos[token.item()]

106        return sampled

レトロトレーナー

109class Trainer:
114    def __init__(self, device: torch.device, model: retro.RetroModel,
115                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
122        self.optimizer = optimizer
123        self.device = device
124        self.dataloader = dataloader
125        self.model = model
126        self.loss_func = nn.CrossEntropyLoss()

一時代を先取りしたモデルのトレーニング

128    def __call__(self):

トレーニングデータを繰り返し処理

134        for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):

データをデバイスに移動

136            src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)

フォワードパス

139            res = self.model(src, neighbors)

損失の計算

141            loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))

グラデーションをクリア

144            self.optimizer.zero_grad()

バックワードパス

146            loss.backward()

モデルを最適化

148            self.optimizer.step()

トレーニング統計を保存してグローバルステップカウンタを増やす

151            tracker.save({'loss.train': loss})
152            tracker.add_global_step(len(src))

小型モデルの作成とトレーニング

155def train():

テストを作成

161    experiment.create(name='retro_small')

GPU デバイス

164    device = torch.device('cuda:0')

タイニーシェイクスピアデータセットの読み込み

167    tds = TextFileDataset(
168        lab.get_data_path() / 'tiny_shakespeare.txt',
169        list,
170        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
173    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

データローダーの作成

176    train_dl = DataLoader(train_dataset,
177                          batch_size=4,
178                          sampler=RandomSampler(train_dataset, replacement=True))

ハイパーパラメータ

181    chunk_len = 16
182    d_model = 128
183    d_ff = 512
184    n_heads = 16
185    d_k = 16

最近傍エンコーダーの作成

188    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)

モデルを作成

190    model = RetroModel(tds.n_tokens, d_model, 6,
191                       {3, 5},
192                       chunk_len, n_heads, d_k, d_ff,
193                       encoder=nearest_neighbor_encoder)

モデルをデバイスに移動

195    model = model.to(device)

オプティマイザーの作成

197    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)

作成 Trainer

199    trainer = Trainer(device, model, train_dl, optimizer)

作成 Sampler

201    sampler = Sampler(device, model, tds, chunk_len)

203    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

保存および読み込み用のモデルを設定する

206    experiment.add_pytorch_models(model=model)

実験を始める

209    with experiment.start():

32 時代に合わせた列車

211        for epoch in monit.loop(32):

列車

213            trainer()

新しい行を印刷

215            tracker.new_line()

からのサンプル prompt

217            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
218                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])

モデルを保存する

220            experiment.save_checkpoint()

224if __name__ == '__main__':
225    train()