Sketch RNN

This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.

Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural networks. It learns to reconstruct simple stroke-based drawings by predicting a sequence of strokes. The decoder predicts each stroke as a mixture of bivariate Gaussians.

Getting data

Download data from the Quick, Draw! Dataset. There is a link to download the npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset. You can change this in the configurations.
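If you want to check the downloaded file before training, a minimal sketch like the following should show what it contains (the path data/sketch/bicycle.npz is an assumption; adjust it to where you placed the file):

import numpy as np

# Hypothetical path; adjust to wherever you placed the downloaded file
dataset = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)

# The code below uses the 'train' and 'valid' splits of this file
print(dataset.files)            # available splits
print(dataset['train'][0][:5])  # first few (dx, dy, pen) triplets of one drawing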

Acknowledgements

This implementation took help from the PyTorch Sketch RNN project by Alexis David Jacq.

32import math
33from typing import Optional, Tuple, Any
34
35import einops
36import numpy as np
37from matplotlib import pyplot as plt
38
39import torch
40import torch.nn as nn
41from labml import lab, experiment, tracker, monit
42from labml_nn.helpers.device import DeviceConfigs
43from labml_nn.helpers.optimizer import OptimizerConfigs
44from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
45from torch import optim
46from torch.utils.data import Dataset, DataLoader

Dataset

This class loads and pre-processes the data.

49class StrokesDataset(Dataset):

dataset is a list of numpy arrays of shape [seq_len, 3]. Each array is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along the x and y axes (Δx, Δy), and the last integer is the pen state: 1 if the pen is lifted from the paper after this point and 0 otherwise.

56    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
66        data = []

We iterate through each of the sequences and filter

68        for seq in dataset:

Keep the sequence only if the length of the stroke sequence is within our range

70            if 10 < len(seq) <= max_seq_length:

Clamp Δx, Δy to the range [-1000, 1000]

72                seq = np.minimum(seq, 1000)
73                seq = np.maximum(seq, -1000)

Convert to a floating point array and add to data

75                seq = np.array(seq, dtype=np.float32)
76                data.append(seq)

We then calculate the scaling factor, which is the standard deviation of (Δx, Δy) combined. The paper notes that the mean is not adjusted for simplicity, since the mean is anyway close to 0.

82        if scale is None:
83            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
84        self.scale = scale

Get the longest sequence length among all sequences

87        longest_seq_len = max([len(seq) for seq in data])

We initialize the PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector (Δx, Δy, p1, p2, p3). Only one of p1, p2, p3 is 1 and the others are 0. They represent pen down, pen up and end-of-sequence in that order: p1 is 1 if the pen touches the paper in the next step, p2 is 1 if the pen doesn't touch the paper in the next step, and p3 is 1 if it is the end of the drawing.

97        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

The mask array needs only one extra-step since it is for the outputs of the decoder, which takes in data[:-1] and predicts next step.

100        self.mask = torch.zeros(len(data), longest_seq_len + 1)
101
102        for i, seq in enumerate(data):
103            seq = torch.from_numpy(seq)
104            len_seq = len(seq)

Scale and set Δx, Δy. The following lines set p1 = 1 - seq[:, 2] (pen touching the paper), p2 = seq[:, 2] (pen lifted), and p3 = 1 for the padded steps after the end of the sequence.

106            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

108            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

110            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

112            self.data[i, len_seq + 1:, 4] = 1

Mask is on until end of sequence

114            self.mask[i, :len_seq + 1] = 1

Start-of-sequence is (0, 0, 1, 0, 0)

117        self.data[:, 0, 2] = 1

Size of the dataset

119    def __len__(self):
121        return len(self.data)

Get a sample

123    def __getitem__(self, idx: int):
125        return self.data[idx], self.mask[idx]
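As a quick sanity check, here is a minimal sketch of the stroke-3 to stroke-5 conversion this class performs, using a hypothetical toy sequence (the real data comes from the npz file):

rng = np.random.default_rng(0)
# One toy drawing with 12 strokes of (dx, dy, pen) values
toy_seq = np.concatenate([rng.integers(-50, 50, (12, 2)), rng.integers(0, 2, (12, 1))], axis=1)
ds = StrokesDataset([toy_seq], max_seq_length=200)
data, mask = ds[0]
print(data.shape, mask.shape)  # torch.Size([14, 5]) torch.Size([13])
print(data[0])                 # start-of-sequence step: tensor([0., 0., 1., 0., 0.])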

Bi-variate Gaussian mixture

The mixture is represented by Π and N(μ_{x,y}, σ_{x,y}, ρ_{xy}). This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.

128class BivariateGaussianMixture:
138    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
139                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
140        self.pi_logits = pi_logits
141        self.mu_x = mu_x
142        self.mu_y = mu_y
143        self.sigma_x = sigma_x
144        self.sigma_y = sigma_y
145        self.rho_xy = rho_xy

Number of distributions in the mixture, M

147    @property
148    def n_distributions(self):
150        return self.pi_logits.shape[-1]

Adjust by temperature τ: the mixture logits are divided by τ and σx, σy are scaled by √τ

152    def set_temperature(self, temperature: float):

157        self.pi_logits /= temperature

159        self.sigma_x *= math.sqrt(temperature)

161        self.sigma_y *= math.sqrt(temperature)
163    def get_distribution(self):

Clamp σx, σy and ρxy to avoid getting NaNs

165        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
166        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
167        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

Get means

170        mean = torch.stack([self.mu_x, self.mu_y], -1)

Get covariance matrix

172        cov = torch.stack([
173            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
174            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
175        ], -1)
176        cov = cov.view(*sigma_y.shape, 2, 2)

Create bi-variate normal distribution.

📝 It would be more efficient to use a scale_tril matrix [[a, 0], [b, c]] where a = σx, b = ρxy σy and c = σy √(1 - ρxy²). But for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bivariate distributions, their covariance matrix, and probability density function.

187        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

Create the categorical distribution Π from the logits

190        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

193        return cat_dist, multi_dist
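The following is a minimal sketch of the scale_tril construction mentioned in the note above. It is not used by this implementation and the helper name is ours; it builds the Cholesky factor of the covariance matrix directly instead of the full matrix.

def bivariate_normal_from_scale_tril(mu_x, mu_y, sigma_x, sigma_y, rho_xy):
    # Mean vector, same as in get_distribution
    mean = torch.stack([mu_x, mu_y], -1)
    # Lower-triangular Cholesky factor [[a, 0], [b, c]] with
    # a = sigma_x, b = rho * sigma_y, c = sigma_y * sqrt(1 - rho^2)
    zeros = torch.zeros_like(sigma_x)
    scale_tril = torch.stack([
        sigma_x, zeros,
        rho_xy * sigma_y, sigma_y * torch.sqrt(1 - rho_xy ** 2)
    ], -1).view(*sigma_x.shape, 2, 2)
    return torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)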

Encoder module

This consists of a bidirectional LSTM

196class EncoderRNN(nn.Module):
203    def __init__(self, d_z: int, enc_hidden_size: int):
204        super().__init__()

Create a bidirectional LSTM taking a sequence of (Δx, Δy, p1, p2, p3) as input.

207        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

Head to get μ

209        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

Head to get σ̂ (sigma_hat)

211        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
213    def forward(self, inputs: torch.Tensor, state=None):

The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and first token in the reverse direction, which is what we want.

220        _, (hidden, cell) = self.lstm(inputs.float(), state)

The state has shape [2, batch_size, hidden_size], where the first dimension is the direction. We rearrange it to get h = [h_forward; h_backward]

224        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

227        mu = self.mu_head(hidden)

229        sigma_hat = self.sigma_head(hidden)

231        sigma = torch.exp(sigma_hat / 2.)

Sample z = μ + σ ⊙ ε, where ε ∼ N(0, I)

234        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

237        return z, mu, sigma_hat
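A quick shape check of the encoder with hypothetical sizes (d_z = 128, enc_hidden_size = 256, a batch of 4 sequences of 30 steps):

enc = EncoderRNN(d_z=128, enc_hidden_size=256)
dummy = torch.zeros(30, 4, 5)              # [seq_len, batch_size, 5]
z, mu, sigma_hat = enc(dummy)
print(z.shape, mu.shape, sigma_hat.shape)  # each is torch.Size([4, 128])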

Decoder module

This consists of an LSTM

240class DecoderRNN(nn.Module):
247    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
248        super().__init__()

The LSTM takes [(Δx, Δy, p1, p2, p3); z] as input

250        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

The initial state of the LSTM is [h0; c0] = tanh(W_z z + b_z). init_state is the linear transformation for this

254        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

This layer produces outputs for each of the n_distributions. Each distribution needs six parameters (Π̂, μx, μy, σ̂x, σ̂y, ρ̂xy)

259        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

This head is for the pen-state logits (q̂1, q̂2, q̂3)

262        self.q_head = nn.Linear(dec_hidden_size, 3)

This is to calculate log q, where q = softmax(q̂) are the pen-state probabilities

265        self.q_log_softmax = nn.LogSoftmax(-1)

These parameters are stored for future reference

268        self.n_distributions = n_distributions
269        self.dec_hidden_size = dec_hidden_size
271    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

Calculate the initial state

273        if state is None:

275            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

h and c have shapes [batch_size, lstm_size] . We want to shape them to [1, batch_size, lstm_size] because that's the shape used in LSTM.

278            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

Run the LSTM

281        outputs, state = self.lstm(x, state)

Get the pen-state log-probabilities log q

284        q_logits = self.q_log_softmax(self.q_head(outputs))

Get (Π̂, μx, μy, σ̂x, σ̂y, ρ̂xy). torch.split splits the output into 6 tensors of size self.n_distributions across dimension 2.

290        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
291            torch.split(self.mixtures(outputs), self.n_distributions, 2)

Create a bi-variate Gaussian mixture Π and N(μ_{x,y}, σ_{x,y}, ρ_{xy}), where σx = exp(σ̂x), σy = exp(σ̂y) and ρxy = tanh(ρ̂xy), and Π is the categorical probabilities of choosing the distribution out of the mixture, with Π_j ∝ exp(Π̂_j).

304        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
305                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

308        return dist, q_logits, state
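Similarly, a quick shape check of the decoder with the same hypothetical sizes (d_z = 128, dec_hidden_size = 512, n_distributions = 20):

dec = DecoderRNN(d_z=128, dec_hidden_size=512, n_distributions=20)
z = torch.zeros(4, 128)
x = torch.zeros(30, 4, 128 + 5)      # [seq_len, batch_size, d_z + 5]
dist, q_logits, state = dec(x, z, None)
print(q_logits.shape)                # torch.Size([30, 4, 3])
print(dist.pi_logits.shape)          # torch.Size([30, 4, 20])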

Reconstruction Loss

311class ReconstructionLoss(nn.Module):
316    def forward(self, mask: torch.Tensor, target: torch.Tensor,
317                dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

Get Π and N(μ_{x,y}, σ_{x,y}, ρ_{xy})

319        pi, mix = dist.get_distribution()

target has shape [seq_len, batch_size, 5], where the last dimension is the features (Δx, Δy, p1, p2, p3). We want to get Δx, Δy and calculate the probabilities from each of the distributions in the mixture N(μ_{x,y}, σ_{x,y}, ρ_{xy}).

xy will have shape [seq_len, batch_size, n_distributions, 2]

326        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

Calculate the probabilities p(Δx, Δy) = Σ_{j=1}^{M} Π_j N(Δx, Δy | μ_{x,j}, μ_{y,j}, σ_{x,j}, σ_{y,j}, ρ_{xy,j})

332        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

Although probs has N_max (longest_seq_len) elements, the sum is only taken up to N_s because the rest is masked out.

It might feel like we should be taking the sum and dividing by N_s and not N_max, but that would give a higher weight to individual predictions in shorter sequences. We give equal weight to each prediction p(Δx, Δy) when we divide by N_max.

341        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

344        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

347        return loss_stroke + loss_pen
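For reference, a sketch of what these two terms compute, in the notation of the paper (the code additionally averages over the batch, since torch.mean runs over all elements):

L_s = -\frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \Big( \sum_{j=1}^{M} \Pi_{j,i} \, \mathcal{N}\big(\Delta x_i, \Delta y_i \,\big|\, \mu_{x,j,i}, \mu_{y,j,i}, \sigma_{x,j,i}, \sigma_{y,j,i}, \rho_{xy,j,i}\big) \Big)

L_p = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log q_{k,i}

and the reconstruction loss is L_R = L_s + L_p.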

KL-Divergence loss

This calculates the KL divergence between a given normal distribution and N(0, I)

350class KLDivLoss(nn.Module):
357    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

359        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
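This is the standard closed form for the KL divergence between N(μ, σ²) (with σ̂ = log σ²) and a unit Gaussian; roughly, in the paper's notation (the code also averages over the batch):

L_{KL} = -\frac{1}{2 N_z} \sum_{i=1}^{N_z} \big(1 + \hat{\sigma}_i - \mu_i^2 - \exp(\hat{\sigma}_i)\big)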

Sampler

This samples a sketch from the decoder and plots it

362class Sampler:
369    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
370        self.decoder = decoder
371        self.encoder = encoder
373    def sample(self, data: torch.Tensor, temperature: float):

375        longest_seq_len = len(data)

Get z from the encoder

378        z, _, _ = self.encoder(data)

Start-of-sequence stroke is (0, 0, 1, 0, 0)

381        s = data.new_tensor([0, 0, 1, 0, 0])
382        seq = [s]

The initial decoder state is None. The decoder will initialize it to [h0; c0] = tanh(W_z z + b_z)

385        state = None

We don't need gradients

388        with torch.no_grad():

Sample strokes

390            for i in range(longest_seq_len):

[(Δx, Δy, p1, p2, p3); z] is the input to the decoder

392                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

Get the mixture Π and N(μ_{x,y}, σ_{x,y}, ρ_{xy}), the pen-state logits q̂, and the next state from the decoder

395                dist, q_logits, state = self.decoder(data, z, state)

Sample a stroke

397                s = self._sample_step(dist, q_logits, temperature)

Add the new stroke to the sequence of strokes

399                seq.append(s)

Stop sampling if p3 = 1. This indicates that sketching has stopped.

401                if s[4] == 1:
402                    break

Create a PyTorch tensor of the sequence of strokes

405        seq = torch.stack(seq)

Plot the sequence of strokes

408        self.plot(seq)
410    @staticmethod
411    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

Set temperature for sampling. This is implemented in class BivariateGaussianMixture .

413        dist.set_temperature(temperature)

Get the temperature-adjusted Π and N(μ_{x,y}, σ_{x,y}, ρ_{xy})

415        pi, mix = dist.get_distribution()

Sample from Π the index of the distribution to use from the mixture

417        idx = pi.sample()[0, 0]

Create a categorical distribution q with log-probabilities q_logits (q̂), adjusted by the temperature

420        q = torch.distributions.Categorical(logits=q_logits / temperature)

Sample from q

422        q_idx = q.sample()[0, 0]

Sample from the normal distributions in the mixture and pick the one indexed by idx

425        xy = mix.sample()[0, 0, idx]

Create an empty stroke

428        stroke = q_logits.new_zeros(5)

Set Δx, Δy

430        stroke[:2] = xy

Set the sampled pen state (one-hot at index q_idx)

432        stroke[q_idx + 2] = 1

434        return stroke
436    @staticmethod
437    def plot(seq: torch.Tensor):

Take the cumulative sums of (Δx, Δy) to get (x, y)

439        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

Create a new numpy array of the form (x, y, p2)

441        seq[:, 2] = seq[:, 3]
442        seq = seq[:, 0:3].detach().cpu().numpy()

Split the array at the points where p2 is 1, i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of stroke sequences.

447        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)

Plot each sequence of strokes

449        for s in strokes:
450            plt.plot(s[:, 0], -s[:, 1])

Don't show axes

452        plt.axis('off')

Show the plot

454        plt.show()
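A tiny illustration, with hypothetical values, of how the splitting above behaves: the sequence is cut right after every step where the pen was lifted.

seq = np.array([[0., 0., 0.],
                [1., 1., 1.],    # pen lifted after this step
                [2., 2., 0.],
                [3., 3., 0.]])
print(np.where(seq[:, 2] > 0)[0] + 1)              # [2]
print(np.split(seq, np.where(seq[:, 2] > 0)[0] + 1))
# -> two arrays: rows 0-1 (first stroke) and rows 2-3 (second stroke)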

Configurations

These are default configurations which can later be adjusted by passing a dict .

457class Configs(TrainValidConfigs):

Device configurations to pick the device to run the experiment

465    device: torch.device = DeviceConfigs()

467    encoder: EncoderRNN
468    decoder: DecoderRNN
469    optimizer: optim.Adam
470    sampler: Sampler
471
472    dataset_name: str
473    train_loader: DataLoader
474    valid_loader: DataLoader
475    train_dataset: StrokesDataset
476    valid_dataset: StrokesDataset

Encoder and decoder sizes

479    enc_hidden_size = 256
480    dec_hidden_size = 512

Batch size

483    batch_size = 100

Number of features in z

486    d_z = 128

Number of distributions in the mixture, M

488    n_distributions = 20

Weight of the KL divergence loss, w_KL

491    kl_div_loss_weight = 0.5

Gradient clipping

493    grad_clip = 1.

Temperature for sampling

495    temperature = 0.4

Filter out stroke sequences longer than max_seq_length (200 steps)

498    max_seq_length = 200
499
500    epochs = 100
501
502    kl_div_loss = KLDivLoss()
503    reconstruction_loss = ReconstructionLoss()
505    def init(self):

Initialize encoder & decoder

507        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
508        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

Set optimizer. Things like type of optimizer and learning rate are configurable

511        optimizer = OptimizerConfigs()
512        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
513        self.optimizer = optimizer

Create sampler

516        self.sampler = Sampler(self.encoder, self.decoder)

npz file path is data/sketch/[DATASET NAME].npz

519        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

Load the numpy file

521        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

Create training dataset

524        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

Create validation dataset

526        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

Create training data loader

529        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

Create validation data loader

531        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

Configure the tracker to print the total train/validation loss

534        tracker.set_scalar("loss.total.*", True)
535
536        self.state_modules = []
538    def step(self, batch: Any, batch_idx: BatchIndex):
539        self.encoder.train(self.mode.is_train)
540        self.decoder.train(self.mode.is_train)

Move data and mask to device and swap the sequence and batch dimensions. data will have shape [seq_len, batch_size, 5] and mask will have shape [seq_len, batch_size] .

545        data = batch[0].to(self.device).transpose(0, 1)
546        mask = batch[1].to(self.device).transpose(0, 1)

Increment step in training mode

549        if self.mode.is_train:
550            tracker.add_global_step(len(data))

Encode the sequence of strokes

553        with monit.section("encoder"):

Get z, μ, and σ̂

555            z, mu, sigma_hat = self.encoder(data)

Decode the mixture of distributions and the pen-state logits q̂

558        with monit.section("decoder"):

Concatenate [(Δx, Δy, p1, p2, p3); z]

560            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
561            inputs = torch.cat([data[:-1], z_stack], 2)

Get the mixture of distributions and q̂

563            dist, q_logits, _ = self.decoder(inputs, z, None)

Compute the total loss L = L_R + w_KL · L_KL

566        with monit.section('loss'):

568            kl_loss = self.kl_div_loss(sigma_hat, mu)

570            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

572            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

Track losses

575            tracker.add("loss.kl.", kl_loss)
576            tracker.add("loss.reconstruction.", reconstruction_loss)
577            tracker.add("loss.total.", loss)

Only if we are in training state

580        if self.mode.is_train:

Run optimizer

582            with monit.section('optimize'):

Set grad to zero

584                self.optimizer.zero_grad()

Compute gradients

586                loss.backward()

Log model parameters and gradients

588                if batch_idx.is_last:
589                    tracker.add(encoder=self.encoder, decoder=self.decoder)

Clip gradients

591                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
592                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

Optimize

594                self.optimizer.step()
595
596        tracker.save()
598    def sample(self):

Randomly pick a sample from the validation dataset to encode

600        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

Add batch dimension and move it to device

602        data = data.unsqueeze(1).to(self.device)

Sample

604        self.sampler.sample(data, self.temperature)
607def main():
608    configs = Configs()
609    experiment.create(name="sketch_rnn")

Pass a dictionary of configurations

612    experiment.configs(configs, {
613        'optimizer.optimizer': 'Adam',

We use a learning rate of 1e-3 because we can see results faster. The paper suggested 1e-4.

616        'optimizer.learning_rate': 1e-3,

Name of the dataset

618        'dataset_name': 'bicycle',

Number of inner iterations within an epoch to switch between training, validation and sampling.

620        'inner_iterations': 10
621    })
622
623    with experiment.start():

Run the experiment

625        configs.run()
626
627
628if __name__ == "__main__":
629    main()