Text Dataset for GPT-NeoX

10from pathlib import PurePath, Path
11from typing import Optional, List
13import torch
14import torch.utils.data
15from labml import lab
16from labml import monit
17from labml.logger import inspect
18from labml.utils.download import download_file
20from labml_nn.neox.tokenizer import get_tokenizer

Load text file

  • path is the location of the text file
  • url is the URL to download the file from
  • filter_subset, if set, keeps only the first filter_subset characters of the text. Use this during testing when working with large datasets
  • Returns the text content

def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None) -> str:
    """
    Load a text file, downloading it first if it does not exist locally.

    :param path: location of the text file
    :param url: URL to download the file from when ``path`` does not exist
    :param filter_subset: if set (and non-zero), keep only the first ``filter_subset`` characters;
        useful when testing with large datasets
    :return: the text content
    :raises FileNotFoundError: if ``path`` does not exist and no ``url`` is given
    """
    path = Path(path)

    # Download if it doesn't exist
    if not path.exists():
        if not url:
            raise FileNotFoundError(str(path))
        download_file(url, path)

    with monit.section("Load data"):
        # Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding
        with open(str(path), 'r', encoding='utf-8') as f:
            text = f.read()

        # Keep only a prefix of the text when requested
        if filter_subset:
            text = text[:filter_subset]

    return text

Dataset for fine-tuning GPT-NeoX

This is not optimized for very large datasets.

class NeoXDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning GPT-NeoX.

    This is not optimized for very large datasets: all token ids are held in a single in-memory tensor.
    Sample ``i`` uses the ``seq_len`` tokens starting at ``i * seq_len`` as input, and the same window
    shifted forward by one token as the target.
    """

    def __init__(self, tokens: List[int], seq_len: int):
        """
        :param tokens: list of token ids
        :param seq_len: sequence length of a single training sample
        :raises ValueError: if ``seq_len`` is not positive
        """
        if seq_len <= 0:
            raise ValueError(f'seq_len must be positive, got {seq_len}')
        self.seq_len = seq_len

        # Number of samples. Each sample needs `seq_len + 1` consecutive tokens because the target
        # is the input shifted by one, so the final token can only ever serve as a target.
        n_samples = max(0, (len(tokens) - 1) // seq_len)
        self.n_samples = n_samples

        # Drop tokens that don't fit into a full sample, keeping one extra for the last target
        tokens = tokens[:n_samples * seq_len + 1]

        # Create a PyTorch tensor (explicit dtype so an empty list is not turned into floats)
        self.tokens = torch.tensor(tokens, dtype=torch.long)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx: int):
        """
        Get a sample.

        :param idx: index of the sample; negative indices count from the end
        :return: ``(input, target)``, two tensors of ``seq_len`` token ids with the target shifted by one
        :raises IndexError: if ``idx`` is out of range
        """
        if idx < 0:
            idx += self.n_samples
        # Raising IndexError is what terminates `for x in dataset` (sequence iteration protocol);
        # without it, out-of-range slices would silently return empty tensors forever.
        if not 0 <= idx < self.n_samples:
            raise IndexError(f'sample index out of range for dataset of {self.n_samples} samples')
        offset = idx * self.seq_len
        return (self.tokens[offset:offset + self.seq_len],
                self.tokens[offset + 1:offset + 1 + self.seq_len])
# Datasets known to `get_training_data`: the file name under the labml data path and the URL to download it from
DATASETS = {
    'tiny_shakespeare': {
        'file': 'tiny_shakespeare.txt',
        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    }
}

Load Dataset

  • seq_len is the sequence length of a single training sample
  • dataset_name is the name of the dataset
  • truncate, if positive, limits the dataset to truncate samples
  • Returns the dataset

def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
    """
    Load a dataset, tokenize it and wrap it in a `NeoXDataset`.

    :param seq_len: sequence length of a single training sample
    :param dataset_name: name of the dataset (a key of `DATASETS`)
    :param truncate: if positive, keep only enough tokens for `truncate` samples
    :return: the dataset
    """
    ds = DATASETS[dataset_name]

    # Load the content
    text = load_text(lab.get_data_path() / ds['file'], ds['url'])

    # Tokenize
    tokenizer = get_tokenizer()
    tokens = tokenizer.encode_batch([text])[0]

    if truncate > 0:
        # Each sample's target is its input shifted by one token, so `truncate` full samples
        # need one token beyond `truncate * seq_len`
        token_ids = tokens.ids[:truncate * seq_len + 1]
    else:
        token_ids = tokens.ids

    return NeoXDataset(token_ids, seq_len)
def _test():
    """Load the default training dataset and report how many tokens it holds."""
    token_count = len(get_training_data().tokens)
    inspect(tokens=token_count)


if __name__ == '__main__':
    _test()