15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndex

データセットの構築

  • chunk_len はチャンクの長さです
  • chunks_per_sample はトレーニングサンプルあたりのチャンク数です
  • skip_range 2 つのサンプル間でスキップする最大文字数です。
サンプルがデータベース内のチャンクと完全に一致していないことを確認するために、サンプル間を数文字スキップしています。
27def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):

テキストファイルを読み込む

39    dataset = TextFileDataset(
40        lab.get_data_path() / 'tiny_shakespeare.txt',
41        list,
42        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

そのトレーニング部分

45    text = dataset.train

近傍検索用のインデックスを読み込む

48    index = RetroIndex()

入力サンプルオフセット

51    sample_offsets = []

テキスト用のカーソル

53    i = 0
54    while i < len(text):

数文字飛ばして、隣の文字と揃わないようにしてください

56        skip = np.random.randint(skip_range)
57        i += skip

テキストの終わりに達したら止めてください

60        if i + chunks_per_sample * chunk_len > len(text):
61            break

オフセットの収集

64        sample_offsets.append(i)

カーソルをインクリメントしてください

67        i += chunks_per_sample * chunk_len

サンプル用

70    samples = []

サンプルオフセットを反復処理

72    for i in monit.iterate('Gather Neighbors', sample_offsets):

追加文字を含むサンプルを取得 (予測用)

74        sample = text[i: i + chunks_per_sample * chunk_len + 1]

インプット

76        src = sample[:-1]

それをチャンクに分割してください

78        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]

チャンクオフセット

80        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

最も近い近傍を検索する

83        neighbor_offsets = index(chunks, chunk_offsets)

近所の人のテキストを取得.近傍の長さは 2 倍です

chunk_len
86        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

サンプルリストに追加

89        samples.append((sample[:-1], sample[1:], neighbors))

サンプルを JSON で保存します。データセットは小さいため、複雑なデータセットの保存メカニズムを使用したり、事前にトークン化したりする必要はありません

94    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
95        f.write(json.dumps(samples))

データセット

これは、によって作成されたデータセットをロードする PyTorch データセットです。build_dataset

98class Dataset(PyTorchDataset):
  • file_path 保存された JSON ファイルのパスです
  • tdsTextDataset
  • 105    def __init__(self, file_path: Path, tds: TextDataset):
    111        self.tds = tds

    サンプルを読み込む

    113        with open(str(file_path), 'r') as f:
    114            self.samples = json.loads(f.read())

    サンプル数

    116    def __len__(self):
    120        return len(self.samples)

    サンプルを入手

    122    def __getitem__(self, idx: int):

    サンプルを入手

    127        s = self.samples[idx]

    トークン化

    129        src = self.tds.text_to_i(s[0])
    130        tgt = self.tds.text_to_i(s[1])
    131        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])

    133        return src, tgt, neighbors

    136if __name__ == '__main__':
    137    build_dataset()