from pathlib import PurePath, Path
from typing import Optional, List

import torch
import torch.utils.data
from labml import lab
from labml import monit
from labml.logger import inspect
from labml.utils.download import download_file

from labml_nn.neox.tokenizer import get_tokenizer


def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
    """
    :param path: is the location of the text file
    :param url: is the URL to download the file from
    :param filter_subset: is the number of characters to filter.
        Use this during testing when trying large datasets
    :return: the text content
    """
    path = Path(path)

    # Download if it doesn't exist
    if not path.exists():
        if not url:
            raise FileNotFoundError(str(path))
        else:
            download_file(url, path)

    with monit.section("Load data"):
        # Load data
        with open(str(path), 'r') as f:
            text = f.read()
        # Filter
        if filter_subset:
            text = text[:filter_subset]

    return text
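
# A minimal usage sketch (illustrative, not part of the module): `filter_subset`
# keeps only the first N characters, which is handy when experimenting with a
# large corpus. The path and character count below are placeholders.
#
#     text = load_text(lab.get_data_path() / 'tiny_shakespeare.txt',
#                      url=DATASETS['tiny_shakespeare']['url'],
#                      filter_subset=10_000)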


class NeoXDataset(torch.utils.data.Dataset):
    def __init__(self, tokens: List[int], seq_len: int):
        """
        :param tokens: is the list of token ids
        :param seq_len: is the sequence length of a single training sample
        """
        self.seq_len = seq_len
        # Number of samples
        n_samples = len(tokens) // seq_len
        self.n_samples = n_samples
        # Truncate, keeping one extra token so that targets can be shifted by one
        tokens = tokens[:n_samples * seq_len + 1]
        # Create a PyTorch tensor
        self.tokens = torch.tensor(tokens)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx: int):
        """
        :param idx: is the index of the sample
        :return: the input token ids and the target token ids (the input shifted by one)
        """
        offset = idx * self.seq_len
        return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]
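
    # A worked example (illustrative, not part of the original code): with
    # `tokens = [10, 11, 12, 13, 14, 15, 16, 17, 18]` and `seq_len = 4`,
    # `dataset[0]` returns `(tensor([10, 11, 12, 13]), tensor([11, 12, 13, 14]))`.
    # The target is the input shifted right by one token, which is what a
    # next-token prediction loss expects.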


DATASETS = {
    'tiny_shakespeare': {
        'file': 'tiny_shakespeare.txt',
        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    }
}
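
# Additional corpora could be registered the same way; the entry below is a
# hypothetical sketch (file name and URL are placeholders, not from the original code):
#
#     DATASETS['my_corpus'] = {
#         'file': 'my_corpus.txt',
#         'url': 'https://example.com/my_corpus.txt',
#     }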


def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
    """
    :param seq_len: is the sequence length of a single training sample
    :param dataset_name: is the name of the dataset
    :param truncate: is the number of samples to truncate the token ids to; `-1` means no truncation
    :return: the dataset
    """
    ds = DATASETS[dataset_name]
    # Load the content
    text = load_text(lab.get_data_path() / ds['file'], ds['url'])
    # Tokenize
    tokenizer = get_tokenizer()
    tokens = tokenizer.encode_batch([text])[0]

    # Truncate the token ids for quick experiments,
    # keeping one extra token so the final sample has a full shifted target
    if truncate > 0:
        token_ids = tokens.ids[:truncate * seq_len + 1]
    else:
        token_ids = tokens.ids

    return NeoXDataset(token_ids, seq_len)


def _test():
    dataset = get_training_data()

    inspect(tokens=len(dataset.tokens))
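
    # A small sketch (not in the original test): wrap the dataset in a standard
    # `torch.utils.data.DataLoader` and inspect one batch of (input, target) pairs.
    # The batch size of 4 is arbitrary and only for illustration.
    dl = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    x, y = next(iter(dl))
    inspect(x=x.shape, y=y.shape)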


if __name__ == '__main__':
    _test()