Parity Task

This creates data for the Parity Task from the paper Adaptive Computation Time for Recurrent Neural Networks.

The input of the parity task is a vector of 0's, 1's and -1's. The output is the parity of the 1's: one if there is an odd number of 1's and zero otherwise. The input is generated by setting a random number of elements in the vector to either 1 or -1 and leaving the rest at 0.
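As a small illustration of the target computation, the sketch below works through one hand-written input vector. The vector `x` and the names `n_ones` and `y` are made up for this example; only the rule (count the 1's, take the count modulo 2) comes from the description above.

```python
import torch

# Hand-written example input: 0 marks an unused position,
# 1 and -1 are the two possible non-zero values.
x = torch.tensor([0., 1., -1., 1., 0., 1.])

# Count the elements equal to 1; the -1's and 0's do not contribute.
n_ones = int((x == 1.).sum())

# Parity target: 1 if the count of 1's is odd, 0 otherwise.
y = n_ones % 2

print(n_ones, y)  # 3 1
```

Here three elements equal 1, so the count is odd and the target is 1.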

from typing import Tuple

import torch
from torch.utils.data import Dataset

Parity dataset

class ParityDataset(Dataset):
  • n_samples is the number of samples
  • n_elems is the number of elements in the input vector
    def __init__(self, n_samples: int, n_elems: int = 64):
        self.n_samples = n_samples
        self.n_elems = n_elems

Size of the dataset

    def __len__(self):
        return self.n_samples

Generate a sample

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:

Empty vector

        x = torch.zeros((self.n_elems,))

Number of non-zero elements: a random number between 1 and the total number of elements (inclusive)

        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()

Fill non-zero elements with 1's and -1's

        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1

Randomly permute the elements

        x = x[torch.randperm(self.n_elems)]

The parity

        y = (x == 1.).sum() % 2

        return x, y
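
A minimal usage sketch, assuming the dataset is consumed through a standard PyTorch `DataLoader`. The class is repeated in compact form only so the snippet runs on its own; the sizes (`n_samples=32`, `n_elems=8`, `batch_size=4`) are arbitrary choices for illustration, not values from the paper.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ParityDataset(Dataset):
    # Compact copy of the class above, repeated so this snippet is self-contained
    def __init__(self, n_samples: int, n_elems: int = 64):
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx: int):
        x = torch.zeros((self.n_elems,))
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        x = x[torch.randperm(self.n_elems)]
        y = (x == 1.).sum() % 2
        return x, y


# Illustrative sizes only: 32 samples of 8 elements, batched 4 at a time
loader = DataLoader(ParityDataset(n_samples=32, n_elems=8), batch_size=4)

# The default collate function stacks samples along a new batch dimension
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)
```

Note that `__getitem__` ignores `idx` and draws a fresh random sample on every call, so requesting the same index twice will generally return different vectors unless the random seed is fixed.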