15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndex
chunk_len
はチャンクの長さですchunks_per_sample
はトレーニングサンプルあたりのチャンク数ですskip_range
2 つのサンプル間でスキップする最大文字数です。27def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
テキストファイルを読み込む
39 dataset = TextFileDataset(
40 lab.get_data_path() / 'tiny_shakespeare.txt',
41 list,
42 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
そのトレーニング部分
45 text = dataset.train
近傍検索用のインデックスを読み込む
48 index = RetroIndex()
入力サンプルオフセット
51 sample_offsets = []
テキスト用のカーソル
53 i = 0
54 while i < len(text):
数文字飛ばして、隣の文字と揃わないようにしてください
56 skip = np.random.randint(skip_range)
57 i += skip
テキストの終わりに達したら止めてください
60 if i + chunks_per_sample * chunk_len > len(text):
61 break
オフセットの収集
64 sample_offsets.append(i)
カーソルをインクリメントしてください
67 i += chunks_per_sample * chunk_len
サンプル用
70 samples = []
サンプルオフセットを反復処理
72 for i in monit.iterate('Gather Neighbors', sample_offsets):
追加文字を含むサンプルを取得 (予測用)
74 sample = text[i: i + chunks_per_sample * chunk_len + 1]
インプット
76 src = sample[:-1]
それをチャンクに分割してください
78 chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
チャンクオフセット
80 chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]
最も近い近傍を検索する
83 neighbor_offsets = index(chunks, chunk_offsets)
86 neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]
サンプルリストに追加
89 samples.append((sample[:-1], sample[1:], neighbors))
サンプルを JSON で保存します。データセットは小さいため、複雑なデータセットの保存メカニズムを使用したり、事前にトークン化したりする必要はありません
。94 with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
95 f.write(json.dumps(samples))
98class Dataset(PyTorchDataset):
105 def __init__(self, file_path: Path, tds: TextDataset):
111 self.tds = tds
サンプルを読み込む
113 with open(str(file_path), 'r') as f:
114 self.samples = json.loads(f.read())
サンプル数
116 def __len__(self):
120 return len(self.samples)
サンプルを入手
122 def __getitem__(self, idx: int):
サンプルを入手
127 s = self.samples[idx]
トークン化
129 src = self.tds.text_to_i(s[0])
130 tgt = self.tds.text_to_i(s[1])
131 neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])
133 return src, tgt, neighbors
136if __name__ == '__main__':
137 build_dataset()