14import torch
15from torch import nn
16from torch.utils.data import DataLoader, RandomSampler
17
18from labml import monit, lab, tracker, experiment, logger
19from labml.logger import Text
20from labml_helpers.datasets.text import TextFileDataset
21from labml_nn.optimizers.noam import Noam
22from labml_nn.transformers.retro import model as retro
23from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
24from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder
27class Sampler:
34 def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
41 self.chunk_len = chunk_len
42 self.tds = tds
43 self.model = model
44 self.device = device
49 def retrieve_nearest_neighbours(self, chunk: str):
最も近い近傍のオフセットを取得
55 neighbor_offsets = self.index([chunk], None)
近傍を取得 (近傍の長さがと等しい) chunk_len * 2
58 text = self.tds.train
59 neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]
62 return neighbors
64 def sample(self, prompt: str, sample_len: int):
最も近い近傍を文字列として保存するには
70 neighbors_str = []
サンプルテキスト
73 sampled = ''
sample_len
サンプルトークン
76 for i in range(sample_len):
すでに取得したチャンクの数よりもサンプリングされたチャンクの数が多い場合は、隣接データを取得する必要があります
79 while len(neighbors_str) < len(prompt) // self.chunk_len:
隣接データを取得していない最後のチャンクを取得
81 off = len(neighbors_str) * self.chunk_len
82 chunk = prompt[off: off + self.chunk_len]
最も近い近傍を検索する
84 neighbors_str.append(self.retrieve_nearest_neighbours(chunk))
入力をトークン化
87 src = self.tds.text_to_i(prompt)
取得したネイバーをトークン化する
89 neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk]) for chunk in neighbors_str])
モデルと同じデバイスに移動します
92 src = src.to(self.device)
93 neighbors = neighbors.to(self.device)
モデル出力を取得
96 res = self.model(src[None, :], neighbors[None, :, :, :])
最後のトークンを欲張ってサンプリングする
99 token = res[0, -1, :].argmax(dim=-1)
サンプリングしたトークンのテキストをプロンプトとサンプルテキストに追加します
102 prompt += self.tds.itos[token.item()]
103 sampled += self.tds.itos[token.item()]
106 return sampled
109class Trainer:
device
モデルのデバイスですmodel
レトロモードですdataloader
は、事前に近傍が取得されたデータセットのデータローダーoptimizer
オプティマイザーです114 def __init__(self, device: torch.device, model: retro.RetroModel,
115 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
122 self.optimizer = optimizer
123 self.device = device
124 self.dataloader = dataloader
125 self.model = model
126 self.loss_func = nn.CrossEntropyLoss()
128 def __call__(self):
トレーニングデータを繰り返し処理
134 for i, (src, tgt, neighbors) in monit.enum('Train', self.dataloader):
データをデバイスに移動
136 src, tgt, neighbors = src.to(self.device), tgt.to(self.device), neighbors.to(self.device)
フォワードパス
139 res = self.model(src, neighbors)
損失の計算
141 loss = self.loss_func(res.view(-1, res.shape[-1]), tgt.view(-1))
グラデーションをクリア
144 self.optimizer.zero_grad()
バックワードパス
146 loss.backward()
モデルを最適化
148 self.optimizer.step()
トレーニング統計を保存してグローバルステップカウンタを増やす
151 tracker.save({'loss.train': loss})
152 tracker.add_global_step(len(src))
155def train():
テストを作成
161 experiment.create(name='retro_small')
GPU デバイス
164 device = torch.device('cuda:0')
タイニーシェイクスピアデータセットの読み込み
167 tds = TextFileDataset(
168 lab.get_data_path() / 'tiny_shakespeare.txt',
169 list,
170 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
173 train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)
データローダーの作成
176 train_dl = DataLoader(train_dataset,
177 batch_size=4,
178 sampler=RandomSampler(train_dataset, replacement=True))
ハイパーパラメータ
181 chunk_len = 16
182 d_model = 128
183 d_ff = 512
184 n_heads = 16
185 d_k = 16
最近傍エンコーダーの作成
188 nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
モデルを作成
190 model = RetroModel(tds.n_tokens, d_model, 6,
191 {3, 5},
192 chunk_len, n_heads, d_k, d_ff,
193 encoder=nearest_neighbor_encoder)
モデルをデバイスに移動
195 model = model.to(device)
オプティマイザーの作成
197 optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
作成 Trainer
199 trainer = Trainer(device, model, train_dl, optimizer)
作成 Sampler
201 sampler = Sampler(device, model, tds, chunk_len)
203 prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''
保存および読み込み用のモデルを設定する
206 experiment.add_pytorch_models(model=model)
実験を始める
209 with experiment.start():
32
時代に合わせた列車
211 for epoch in monit.loop(32):
列車
213 trainer()
新しい行を印刷
215 tracker.new_line()
からのサンプル prompt
217 logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
218 (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
モデルを保存する
220 experiment.save_checkpoint()
224if __name__ == '__main__':
225 train()