パリティタスク

これにより、論文「リカレントニューラルネットワークの適応的計算時間」からパリティタスクのデータが作成されます。

パリティタスクの入力は、と s の付いたベクトルで、出力は s のパリティです。s の数が奇数の場合は 1、それ以外の場合は 0 です。入力は、ベクトル内のランダムな数の要素をまたはのいずれかにすることによって生成されます。

19from typing import Tuple
20
21import torch
22from torch.utils.data import Dataset

パリティデータセット

25class ParityDataset(Dataset):
  • n_samples はサンプル数
  • n_elems は入力ベクトルの要素数です
30    def __init__(self, n_samples: int, n_elems: int = 64):
35        self.n_samples = n_samples
36        self.n_elems = n_elems

データセットのサイズ

38    def __len__(self):
42        return self.n_samples

サンプルを生成

44    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:

空のベクトル

50        x = torch.zeros((self.n_elems,))

ゼロ以外の要素の数- 要素の数と要素数の間のランダムな数

52        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()

0 以外の要素を「」と 「」で埋める

54        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1

要素をランダムに並べ替える

56        x = x[torch.randperm(self.n_elems)]

パリティ

59        y = (x == 1.).sum() % 2

62        return x, y