GPT-NEOX 的文本数据集

10from pathlib import PurePath, Path
11from typing import Optional, List
12
13import torch
14import torch.utils.data
15from labml import lab
16from labml import monit
17from labml.logger import inspect
18from labml.utils.download import download_file
19
20from labml_nn.neox.tokenizer import get_tokenizer

加载文本文件

  • path 是文本文件的位置
  • url 是从中下载文件的 URL
  • filter_subset 是要筛选的字符数。在测试期间尝试大型数据集时使用此
  • 选项

    返回文本内容

23def load_text(path: PurePath, url: Optional[str] = None, *, filter_subset: Optional[int] = None):
34    path = Path(path)

如果不存在,请下载

37    if not path.exists():
38        if not url:
39            raise FileNotFoundError(str(path))
40        else:
41            download_file(url, path)
42
43    with monit.section("Load data"):

加载数据

45        with open(str(path), 'r') as f:
46            text = f.read()

筛选

48        if filter_subset:
49            text = text[:filter_subset]

52    return text

用于微调 GPT-NEOX 的数据集

这并未针对非常大的数据集进行优化。

55class NeoXDataset(torch.utils.data.Dataset):
  • tokens 是令牌 ID 的列表
  • seq_len 是单个训练样本的序列长度
62    def __init__(self, tokens: List[int], seq_len: int):
68        self.seq_len = seq_len

样本数量

70        n_samples = len(tokens) // seq_len
71        self.n_samples = n_samples

截断

73        tokens = tokens[:n_samples * seq_len + 1]

创建一个 pyTorch 张量

75        self.tokens = torch.tensor(tokens)
77    def __len__(self):
78        return self.n_samples

获取样品

  • idx 是样本的索引
  • 返回输入和目标

80    def __getitem__(self, idx: int):
87        offset = idx * self.seq_len
88        return self.tokens[offset:offset + self.seq_len], self.tokens[offset + 1:offset + 1 + self.seq_len]
89
90
91DATASETS = {
92    'tiny_shakespeare': {
93        'file': 'tiny_shakespeare.txt',
94        'url': 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
95    }
96}

加载数据集

  • seq_len 是单个训练样本的序列长度
  • dataset_name 是数据集的名称
  • 返回数据集

99def get_training_data(seq_len: int = 32, dataset_name: str = 'tiny_shakespeare', truncate: int = -1):
108    ds = DATASETS[dataset_name]

加载内容

110    text = load_text(lab.get_data_path() / ds['file'], ds['url'])

Tokenize

112    tokenizer = get_tokenizer()
113    tokens = tokenizer.encode_batch([text])[0]
114
115    if truncate > 0:
116        token_ids = tokens.ids[:truncate * seq_len]
117    else:
118        token_ids = tokens.ids

121    return NeoXDataset(token_ids, seq_len)
124def _test():
125    dataset = get_training_data()
126
127    inspect(tokens=len(dataset.tokens))

131if __name__ == '__main__':
132    _test()