RETRO 训练数据集

我们从键值数据库中预先检索最近的邻域,并创建数据集来训练 RETRO 模型

15import json
16from pathlib import Path
17
18import numpy as np
19import torch
20from torch.utils.data import Dataset as PyTorchDataset
21
22from labml import lab, monit
23from labml_helpers.datasets.text import TextFileDataset, TextDataset
24from labml_nn.transformers.retro.database import RetroIndex

构建数据集

  • chunk_len 是区块长度
  • chunks_per_sample 是每个训练样本的块数
  • skip_range 是在两个样本之间跳过的最大字符数。我们在样本之间跳过几个字符,以确保样本与数据库中的块不完全对齐
27def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):

加载文本文件

39    dataset = TextFileDataset(
40        lab.get_data_path() / 'tiny_shakespeare.txt',
41        list,
42        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

其中的一部分训练

45    text = dataset.train

加载索引以检索邻居

48    index = RetroIndex()

输入样本偏移

51    sample_offsets = []

光标指向文本

53    i = 0
54    while i < len(text):

跳过几个角色以确保它不与邻居对齐

56        skip = np.random.randint(skip_range)
57        i += skip

如果我们到达了文字的末尾,就停下来

60        if i + chunks_per_sample * chunk_len > len(text):
61            break

收集偏移量

64        sample_offsets.append(i)

增加光标

67        i += chunks_per_sample * chunk_len

对于样品

70    samples = []

遍历样本偏移量

72    for i in monit.iterate('Gather Neighbors', sample_offsets):

获取包含额外字符的样本(用于预测)

74        sample = text[i: i + chunks_per_sample * chunk_len + 1]

输入

76        src = sample[:-1]

把它分成大块

78        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]

区块偏移量

80        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

检索最近的邻居

83        neighbor_offsets = index(chunks, chunk_offsets)

获取邻居短信。邻居长度是两倍chunk_len

86        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

添加到样品清单

89        samples.append((sample[:-1], sample[1:], neighbors))

以 JSON 格式保存示例。我们不需要使用复杂的数据集存储机制或预标记化,因为我们的数据集很小。

94    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
95        f.write(json.dumps(samples))

数据集

这是 PyTorch 数据集,用于加载由创建的数据集build_dataset

98class Dataset(PyTorchDataset):
  • file_path 是保存的 JSON 文件的路径
  • tdsTextDataset
  • 105    def __init__(self, file_path: Path, tds: TextDataset):
    111        self.tds = tds

    加载样品

    113        with open(str(file_path), 'r') as f:
    114            self.samples = json.loads(f.read())

    样本数量

    116    def __len__(self):
    120        return len(self.samples)

    获取样品

    122    def __getitem__(self, idx: int):

    获取样品

    127        s = self.samples[idx]

    Tokenize

    129        src = self.tds.text_to_i(s[0])
    130        tgt = self.tds.text_to_i(s[1])
    131        neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunks]) for chunks in s[2]])

    133        return src, tgt, neighbors

    136if __name__ == '__main__':
    137    build_dataset()