GPT-Neox 用テキストデータセット

10from pathlib import PurePath, Path
11from typing import Optional, List
12
13import torch
14import torch.utils.data
15from labml import lab
16from labml import monit
17from labml.logger import inspect
18from labml.utils.download import download_file
19
20from labml_nn.neox.tokenizer import get_tokenizer

テキストファイルの読み込み

  • path テキストファイルの場所です
  • url ファイルをダウンロードする URL
  • filter_subset フィルタリングする文字の数です。大規模なデータセットを試すときのテスト時にこれを使用してください
  • テキストコンテンツを返します

23def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
34    path = Path(path)

存在しない場合はダウンロード

37    if not path.exists():
38        if not url:
39            raise FileNotFoundError(str(path))
40        else:
41            download_file(url, path)
42
43    with monit.section("Load data"):

データを読み込む

45        with open(str(path), 'r') as f:
46            text = f.read()

フィルタ

48        if filter_subset:
49            text = text[:filter_subset]

52    return text

GPT-Neox を微調整するためのデータセット

これは非常に大きなデータセットには最適化されていません。

55class NeoXDataset(torch.utils.data.Dataset):
  • tokens トークン ID のリストです
  • seq_len は 1 つのトレーニングサンプルのシーケンス長です
62    def __init__(self, tokens: List[int], seq_len: int):
68        self.seq_len = seq_len

サンプル数

70        n_samples = len(tokens) // seq_len
71        self.n_samples = n_samples

切り捨て

73        tokens = tokens[:n_samples * seq_len + 1]

PyTorch テンソルの作成

75        self.tokens = torch.tensor(tokens)
77    def __len__(self):
78        return self.n_samples

サンプルを入手

  • idx サンプルのインデックスです
  • 入力とターゲットを返します

80    def __getitem__(self, idx: int):
87        offset = idx * self.seq_len
88        return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]
89
90
91DATASETS = {
92    'tiny_shakespeare': {
93        'file': 'tiny_shakespeare.txt',
94        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
95    }
96}

データセットの読み込み

  • seq_len は 1 つのトレーニングサンプルのシーケンス長です
  • dataset_name はデータセットの名前
  • データセットを返します

99def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
108    ds = DATASETS[dataset_name]

コンテンツを読み込む

110    text = load_text(lab.get_data_path() / ds['file'], ds['url'])

トークン化

112    tokenizer = get_tokenizer()
113    tokens = tokenizer.encode_batch([text])[0]
114
115    if truncate > 0:
116        token_ids = tokens.ids[:truncate * seq_len]
117    else:
118        token_ids = tokens.ids

121    return NeoXDataset(token_ids, seq_len)
124def _test():
125    dataset = get_training_data()
126
127    inspect(tokens=len(dataset.tokens))

131if __name__ == '__main__':
132    _test()