15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndexchunk_len
はチャンクの長さですchunks_per_sample
はトレーニングサンプルあたりのチャンク数ですskip_range
2 つのサンプル間でスキップする最大文字数です。27def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):テキストファイルを読み込む
39    dataset = TextFileDataset(
40        lab.get_data_path() / 'tiny_shakespeare.txt',
41        list,
42        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')そのトレーニング部分
45    text = dataset.train近傍検索用のインデックスを読み込む
48    index = RetroIndex()入力サンプルオフセット
51    sample_offsets = []テキスト用のカーソル
53    i = 0
54    while i < len(text):数文字飛ばして、隣の文字と揃わないようにしてください
56        skip = np.random.randint(skip_range)
57        i += skipテキストの終わりに達したら止めてください
60        if i + chunks_per_sample * chunk_len > len(text):
61            breakオフセットの収集
64        sample_offsets.append(i)カーソルをインクリメントしてください
67        i += chunks_per_sample * chunk_lenサンプル用
70    samples = []サンプルオフセットを反復処理
72    for i in monit.iterate('Gather Neighbors', sample_offsets):追加文字を含むサンプルを取得 (予測用)
74        sample = text[i: i + chunks_per_sample * chunk_len + 1]インプット
76        src = sample[:-1]それをチャンクに分割してください
78        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]チャンクオフセット
80        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]最も近い近傍を検索する
83        neighbor_offsets = index(chunks, chunk_offsets)86        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]サンプルリストに追加
89        samples.append((sample[:-1], sample[1:], neighbors))サンプルを JSON で保存します。データセットは小さいため、複雑なデータセットの保存メカニズムを使用したり、事前にトークン化したりする必要はありません
。94    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
95        f.write(json.dumps(samples))98class Dataset(PyTorchDataset):105    def __init__(self, file_path: Path, tds: TextDataset):111        self.tds = tdsサンプルを読み込む
113        with open(str(file_path), 'r') as f:
114            self.samples = json.loads(f.read())サンプル数
116    def __len__(self):120        return len(self.samples)サンプルを入手
122    def __getitem__(self, idx: int):サンプルを入手
127        s = self.samples[idx]トークン化
129        src = self.tds.text_to_i(s[0])
130        tgt = self.tds.text_to_i(s[1])
131        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])133        return src, tgt, neighbors136if __name__ == '__main__':
137    build_dataset()