15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndexchunk_len
def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
    """
    Build the RETRO training dataset and save it as JSON.

    The training text is cut into samples of `chunks_per_sample * chunk_len`
    input characters. For every chunk of every sample the neighbor offsets are
    looked up with `RetroIndex`, and `(src, tgt, neighbors)` triples are written
    to `retro_train_dataset.json` under the lab data path.

    * `chunk_len` is the chunk length
    * `chunks_per_sample` is the number of chunks per training sample
    * `skip_range` is the exclusive upper bound on the number of characters
      randomly skipped between two samples. A few characters are skipped
      between samples so that the samples are not perfectly aligned with the
      chunks in the database. A value of `0` (or less) disables skipping.
    """

    # Load the text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Training portion of it
    text = dataset.train

    # Load the index for retrieving neighbors
    index = RetroIndex()

    # Number of input characters in a single sample
    sample_len = chunks_per_sample * chunk_len

    # The input sample offsets
    sample_offsets = []
    # Cursor for the text
    i = 0
    while i < len(text):
        # Skip a few characters to make sure it is not aligned with the neighbors.
        # `np.random.randint(0)` raises `ValueError`, so a non-positive range means "no skip".
        skip = np.random.randint(skip_range) if skip_range > 0 else 0
        i += skip

        # Stop if we have reached the end of the text. Each sample needs one extra
        # character past the input (the last prediction target), so a full
        # `sample_len + 1` characters must be available from the cursor.
        if i + sample_len + 1 > len(text):
            break

        # Collect the offset
        sample_offsets.append(i)

        # Advance the cursor
        i += sample_len

    # For the samples
    samples = []
    # Iterate through the sample offsets
    for i in monit.iterate('Gather Neighbors', sample_offsets):
        # Get the sample, including one extra character (for the prediction target)
        sample = text[i: i + sample_len + 1]
        # The input
        src = sample[:-1]
        # Break it into chunks
        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
        # The chunk offsets within the full text
        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

        # Retrieve the nearest neighbors
        neighbor_offsets = index(chunks, chunk_offsets)

        # Get the neighbor texts. The neighbor length is twice the `chunk_len`
        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

        # Add to the list of samples: input, target (input shifted by one), neighbors
        samples.append((src, sample[1:], neighbors))

    # Save the samples as JSON.
    # We do not need a complex dataset storage mechanism or pre-tokenization
    # since our dataset is small.
    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w', encoding='utf-8') as f:
        json.dump(samples, f)


class Dataset(PyTorchDataset):
    """
    PyTorch dataset that serves samples loaded from a JSON file.

    Each stored sample is indexed as `(src_text, tgt_text, neighbors)`, which is
    the layout written by `build_dataset`.
    """

    def __init__(self, file_path: Path, tds: 'TextDataset'):
        """
        * `file_path` is the path of the saved JSON file
        * `tds` is the `TextDataset`; its `text_to_i` method is used to tokenize text
        """
        self.tds = tds
        # Load the samples
        with open(str(file_path), 'r', encoding='utf-8') as f:
            self.samples = json.load(f)

    def __len__(self):
        """
        Number of samples
        """
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Get the sample at `idx` as `(src, tgt, neighbors)`, tokenized with `tds.text_to_i`
        """
        # Get the sample
        s = self.samples[idx]
        # Tokenize the input and the target
        src = self.tds.text_to_i(s[0])
        tgt = self.tds.text_to_i(s[1])
        # Tokenize every neighbor and stack them: first the neighbors of one chunk,
        # then across the chunks of the sample
        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])

        return src, tgt, neighbors


if __name__ == '__main__':
    build_dataset()