15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndex
chunk_len
是区块长度chunks_per_sample
是每个训练样本的块数skip_range
是在两个样本之间跳过的最大字符数。我们在样本之间跳过几个字符,以确保样本与数据库中的块不完全对齐27def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
加载文本文件
39 dataset = TextFileDataset(
40 lab.get_data_path() / 'tiny_shakespeare.txt',
41 list,
42 url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
其中的一部分训练
45 text = dataset.train
加载索引以检索邻居
48 index = RetroIndex()
输入样本偏移
51 sample_offsets = []
光标指向文本
53 i = 0
54 while i < len(text):
跳过几个角色以确保它不与邻居对齐
56 skip = np.random.randint(skip_range)
57 i += skip
如果我们到达了文字的末尾,就停下来
60 if i + chunks_per_sample * chunk_len > len(text):
61 break
收集偏移量
64 sample_offsets.append(i)
增加光标
67 i += chunks_per_sample * chunk_len
对于样品
70 samples = []
遍历样本偏移量
72 for i in monit.iterate('Gather Neighbors', sample_offsets):
获取包含额外字符的样本(用于预测)
74 sample = text[i: i + chunks_per_sample * chunk_len + 1]
输入
76 src = sample[:-1]
把它分成大块
78 chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
区块偏移量
80 chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]
检索最近的邻居
83 neighbor_offsets = index(chunks, chunk_offsets)
获取邻居短信。邻居长度是两倍chunk_len
86 neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]
添加到样品清单
89 samples.append((sample[:-1], sample[1:], neighbors))
以 JSON 格式保存示例。我们不需要使用复杂的数据集存储机制或预标记化,因为我们的数据集很小。
94 with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
95 f.write(json.dumps(samples))
98class Dataset(PyTorchDataset):
105 def __init__(self, file_path: Path, tds: TextDataset):
111 self.tds = tds
加载样品
113 with open(str(file_path), 'r') as f:
114 self.samples = json.loads(f.read())
样本数量
116 def __len__(self):
120 return len(self.samples)
获取样品
122 def __getitem__(self, idx: int):
获取样品
127 s = self.samples[idx]
Tokenize
129 src = self.tds.text_to_i(s[0])
130 tgt = self.tds.text_to_i(s[1])
131 neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])
133 return src, tgt, neighbors
136if __name__ == '__main__':
137 build_dataset()