10from pathlib import PurePath, Path
11from typing import Optional, List
12
13import torch
14import torch.utils.data
15from labml import lab
16from labml import monit
17from labml.logger import inspect
18from labml.utils.download import download_file
19
20from labml_nn.neox.tokenizer import get_tokenizer
path
テキストファイルの場所ですurl
ファイルをダウンロードする URLfilter_subset
フィルタリングする文字の数です。大規模なデータセットを試すときのテスト時にこれを使用してくださいテキストコンテンツを返します
23def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
34 path = Path(path)
存在しない場合はダウンロード
37 if not path.exists():
38 if not url:
39 raise FileNotFoundError(str(path))
40 else:
41 download_file(url, path)
42
43 with monit.section("Load data"):
データを読み込む
45 with open(str(path), 'r') as f:
46 text = f.read()
フィルタ
48 if filter_subset:
49 text = text[:filter_subset]
52 return text
55class NeoXDataset(torch.utils.data.Dataset):
tokens
トークン ID のリストですseq_len
は 1 つのトレーニングサンプルのシーケンス長です62 def __init__(self, tokens: List[int], seq_len: int):
68 self.seq_len = seq_len
サンプル数
70 n_samples = len(tokens) // seq_len
71 self.n_samples = n_samples
切り捨て
73 tokens = tokens[:n_samples * seq_len + 1]
PyTorch テンソルの作成
75 self.tokens = torch.tensor(tokens)
77 def __len__(self):
78 return self.n_samples
80 def __getitem__(self, idx: int):
87 offset = idx * self.seq_len
88 return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]
89
90
91DATASETS = {
92 'tiny_shakespeare': {
93 'file': 'tiny_shakespeare.txt',
94 'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
95 }
96}
99def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
108 ds = DATASETS[dataset_name]
コンテンツを読み込む
110 text = load_text(lab.get_data_path() / ds['file'], ds['url'])
トークン化
112 tokenizer = get_tokenizer()
113 tokens = tokenizer.encode_batch([text])[0]
114
115 if truncate > 0:
116 token_ids = tokens.ids[:truncate * seq_len]
117 else:
118 token_ids = tokens.ids
121 return NeoXDataset(token_ids, seq_len)
124def _test():
125 dataset = get_training_data()
126
127 inspect(tokens=len(dataset.tokens))
131if __name__ == '__main__':
132 _test()