from pathlib import PurePath, Path
from typing import Optional, List

import torch
import torch.utils.data
from labml import lab
from labml import monit
from labml.logger import inspect
from labml.utils.download import download_file

from labml_nn.neox.tokenizer import get_tokenizer
def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
    """
    Load a text file, downloading it first if it does not exist locally.

    :param path: location of the text file
    :param url: URL to download the file from when ``path`` does not exist
    :param filter_subset: if set (non-zero), keep only the first ``filter_subset`` characters
    :return: the file contents as a string
    :raises FileNotFoundError: if the file is missing and no ``url`` is given
    """
    path = Path(path)

    # Download the file if it does not exist
    if not path.exists():
        if not url:
            raise FileNotFoundError(str(path))
        download_file(url, path)

    with monit.section("Load data"):
        # Load data; decode explicitly as UTF-8 so the result does not depend
        # on the platform's default locale encoding
        text = path.read_text(encoding='utf-8')
        # Filter to a prefix of the text if requested
        if filter_subset:
            text = text[:filter_subset]

    return text
class NeoXDataset(torch.utils.data.Dataset):
    """
    Dataset of fixed-length next-token-prediction samples cut from a token stream.

    Sample ``i`` is the pair
    ``(tokens[i*seq_len : (i+1)*seq_len], tokens[i*seq_len + 1 : (i+1)*seq_len + 1])``,
    i.e. the target is the input shifted by one token.
    """

    def __init__(self, tokens: List[int], seq_len: int):
        """
        :param tokens: list of token ids
        :param seq_len: sequence length of a single training sample
        :raises ValueError: if ``seq_len`` is not positive
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        self.seq_len = seq_len
        # Number of samples. Each sample needs ``seq_len + 1`` consecutive tokens
        # (the inputs plus one extra for the shifted target), so count samples over
        # ``len(tokens) - 1`` tokens; otherwise the last target could be one token short.
        n_samples = max(0, (len(tokens) - 1) // seq_len)
        self.n_samples = n_samples
        # Truncate: drop tokens beyond the last sample's target
        tokens = tokens[:n_samples * seq_len + 1]
        # Create a PyTorch tensor; the explicit integer dtype keeps an empty
        # token list from becoming a float tensor
        self.tokens = torch.tensor(tokens, dtype=torch.long)

    def __len__(self):
        """Number of complete ``(input, target)`` samples."""
        return self.n_samples

    def __getitem__(self, idx: int):
        """
        Return the ``(input, target)`` pair for sample ``idx``.

        Negative indices count from the end. Out-of-range indices raise
        ``IndexError``, which also lets plain ``for`` iteration terminate.
        """
        if idx < 0:
            idx += self.n_samples
        if not 0 <= idx < self.n_samples:
            raise IndexError(f"sample index out of range for dataset of {self.n_samples} samples")
        offset = idx * self.seq_len
        return (self.tokens[offset:offset + self.seq_len],
                self.tokens[offset + 1:offset + 1 + self.seq_len])
89
90
# Registry of downloadable text datasets, keyed by dataset name.
# ``file`` is the local file name (resolved against ``lab.get_data_path()``
# in ``get_training_data``) and ``url`` is where the text is downloaded from.
DATASETS = dict(
    tiny_shakespeare=dict(
        file='tiny_shakespeare.txt',
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt',
    ),
)
def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
    """
    Load, tokenize and wrap a text dataset for training.

    :param seq_len: sequence length of each training sample
    :param dataset_name: key into ``DATASETS``
    :param truncate: if positive, keep only enough tokens for ``truncate`` samples;
        otherwise use the whole dataset
    :return: a ``NeoXDataset`` over the tokenized text
    :raises KeyError: if ``dataset_name`` is not a key of ``DATASETS``
    """
    try:
        ds = DATASETS[dataset_name]
    except KeyError:
        raise KeyError(f"Unknown dataset {dataset_name!r}; available: {sorted(DATASETS)}") from None

    # Load the text, downloading it into the lab data directory if needed
    text = load_text(lab.get_data_path() / ds['file'], ds['url'])

    # Tokenize
    tokenizer = get_tokenizer()
    tokens = tokenizer.encode_batch([text])[0]
    token_ids = tokens.ids

    if truncate > 0:
        # Each sample's target is its input shifted by one token, so one token
        # beyond ``truncate * seq_len`` is needed for the final target.
        token_ids = token_ids[:truncate * seq_len + 1]

    return NeoXDataset(token_ids, seq_len)
def _test():
    """Smoke test: build the default training dataset and report its token count via ``inspect``."""
    n_tokens = len(get_training_data().tokens)
    inspect(tokens=n_tokens)
# Run the smoke test only when this module is executed as a script
if __name__ == '__main__':
    _test()