奇偶校验任务

这将从论文《循环神经网络的自适应计算时间》中为奇偶校验任务创建数据。

奇偶校验任务的输入是一个带有's 和's 的向量。输出是's 的奇偶校验——如果有,则为 1是的奇数,否则为零。输入是通过使矢量中的随机数量的元素为或而生成的。

19from typing import Tuple
20
21import torch
22from torch.utils.data import Dataset

奇偶校验数据

25class ParityDataset(Dataset):
  • n_samples 是样本的数量
  • n_elems 是输入向量中的元素数
30    def __init__(self, n_samples: int, n_elems: int = 64):
35        self.n_samples = n_samples
36        self.n_elems = n_elems

数据集的大小

38    def __len__(self):
42        return self.n_samples

生成样本

44    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:

空向量

50        x = torch.zeros((self.n_elems,))

非零元素的数量-介于和元素总数之间的随机数

52        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()

“和” 填充非零元素

54        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1

随机排列元素

56        x = x[torch.randperm(self.n_elems)]

奇偶校验

59        y = (x == 1.).sum() % 2

62        return x, y