Sketch RNN

This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.

Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and the decoder are recurrent neural network models. It learns to reconstruct simple stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke as a mixture of Gaussians.

Getting data

Download data from the Quick, Draw! Dataset. There is a link to download npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset. You can change this in the configurations.
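To sanity-check the download, a minimal sketch like the following (assuming the file was saved as data/sketch/bicycle.npz) lists what is inside the npz file; QuickDraw files typically contain train, valid and test splits.

import numpy as np

dataset = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)
print(dataset.files)            # typically ['train', 'valid', 'test']
print(len(dataset['train']))    # number of sketches in the training split
print(dataset['train'][0][:5])  # first few (dx, dy, pen-lift) rows of one sketch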

Acknowledgements

Took help from the PyTorch Sketch RNN project by Alexis David Jacq.

32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex

Dataset

This class loads and pre-processes the data.

50class StrokesDataset(Dataset):

dataset is a list of numpy arrays of shape [seq_len, 3]. It is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along x and y ($\Delta x$, $\Delta y$), and the last integer represents the state of the pen: $1$ if it's touching the paper and $0$ otherwise.
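For illustration only, a made-up two-stroke sequence in this raw format could look like this (the displacement values are hypothetical):

import numpy as np

# (dx, dy, pen-lift): pen-lift is 1 when the pen leaves the paper after the step
seq = np.array([[ 5,  0, 0],
                [ 3,  4, 1],   # end of the first stroke
                [-2,  6, 0],
                [ 0, -3, 1]])  # end of the second stroke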

57    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67        data = []

We iterate through each of the sequences and filter

69        for seq in dataset:

Keep the sequence only if the length of the sequence of strokes is within our range

71            if 10 < len(seq) <= max_seq_length:

Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$

73                seq = np.minimum(seq, 1000)
74                seq = np.maximum(seq, -1000)

Convert to a floating point array and add to data

76                seq = np.array(seq, dtype=np.float32)
77                data.append(seq)

We then calculate the scaling factor, which is the standard deviation of ($\Delta x$, $\Delta y$) combined. The paper notes that the mean is not adjusted for simplicity, since the mean is anyway close to $0$.

83        if scale is None:
84            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85        self.scale = scale

Get the longest sequence length among all sequences

88        longest_seq_len = max([len(seq) for seq in data])

We initialize the PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$. They represent pen down, pen up and end-of-sequence in that order. $p_1$ is $1$ if the pen touches the paper in the next step. $p_2$ is $1$ if the pen doesn't touch the paper in the next step. $p_3$ is $1$ if it is the end of the drawing.

98        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

The mask array needs only one extra step since it is for the outputs of the decoder, which takes in data[:-1] and predicts the next step.

101        self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103        for i, seq in enumerate(data):
104            seq = torch.from_numpy(seq)
105            len_seq = len(seq)

Scale and set $\Delta x, \Delta y$

107            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

109            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

111            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

113            self.data[i, len_seq + 1:, 4] = 1

Mask is on until end of sequence

115            self.mask[i, :len_seq + 1] = 1

Start-of-sequence is $(0, 0, 1, 0, 0)$

118        self.data[:, 0, 2] = 1

Size of the dataset

120    def __len__(self):
122        return len(self.data)

Get a sample

124    def __getitem__(self, idx: int):
126        return self.data[idx], self.mask[idx]
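A minimal usage sketch of this dataset class (assuming the same bicycle.npz file as above):

raw = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)
train_dataset = StrokesDataset(raw['train'], max_seq_length=200)
data, mask = train_dataset[0]
# data has shape [longest_seq_len + 2, 5], mask has shape [longest_seq_len + 1]
print(data.shape, mask.shape)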

Bi-variate Gaussian mixture

The mixture is represented by $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$. This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.

129class BivariateGaussianMixture:
139    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141        self.pi_logits = pi_logits
142        self.mu_x = mu_x
143        self.mu_y = mu_y
144        self.sigma_x = sigma_x
145        self.sigma_y = sigma_y
146        self.rho_xy = rho_xy

Number of distributions in the mixture, $M$

148    @property
149    def n_distributions(self):
151        return self.pi_logits.shape[-1]

Adjust by temperature $\tau$

153    def set_temperature(self, temperature: float):

158        self.pi_logits /= temperature

160        self.sigma_x *= math.sqrt(temperature)

162        self.sigma_y *= math.sqrt(temperature)
164    def get_distribution(self):

Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs

166        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

Get means

171        mean = torch.stack([self.mu_x, self.mu_y], -1)

Get covariance matrix

173        cov = torch.stack([
174            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176        ], -1)
177        cov = cov.view(*sigma_y.shape, 2, 2)

Create bi-variate normal distribution.

📝 It would be more efficient to use a scale_tril matrix of the form [[a, 0], [b, c]] where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho_{xy}^2}$. But for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bi-variate distributions, their covariance matrix, and probability density function.
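As a rough sketch of that alternative (not what the code below does), the lower-triangular Cholesky factor of the covariance could be built directly and passed as scale_tril:

# Hypothetical helper: for cov = [[sx^2, r*sx*sy], [r*sx*sy, sy^2]] the
# lower-triangular factor is [[sx, 0], [r*sy, sy*sqrt(1 - r^2)]]
def scale_tril_from_params(sigma_x, sigma_y, rho_xy):
    zeros = torch.zeros_like(sigma_x)
    tril = torch.stack([
        sigma_x, zeros,
        rho_xy * sigma_y, sigma_y * torch.sqrt(1 - rho_xy ** 2)
    ], -1)
    return tril.view(*sigma_x.shape, 2, 2)

# Passing this as torch.distributions.MultivariateNormal(mean, scale_tril=...)
# skips the internal Cholesky decomposition of the covariance matrix.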

188        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

Create categorical distribution from logits

191        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

194        return cat_dist, multi_dist
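Putting the class together, here is a minimal usage sketch with made-up parameter tensors (one step, one sequence, a mixture of 20 components):

n = 20
raw = [torch.randn(1, 1, n) for _ in range(6)]
gmm = BivariateGaussianMixture(raw[0], raw[1], raw[2],
                               torch.exp(raw[3]), torch.exp(raw[4]), torch.tanh(raw[5]))
gmm.set_temperature(0.4)
pi, normal = gmm.get_distribution()
idx = pi.sample()[0, 0]           # which component of the mixture to use
xy = normal.sample()[0, 0, idx]   # a (dx, dy) sample from that component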

Encoder module

This consists of a bidirectional LSTM

197class EncoderRNN(Module):
204    def __init__(self, d_z: int, enc_hidden_size: int):
205        super().__init__()

Create a bidirectional LSTM taking a sequence of $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.

208        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

Head to get $\mu$

210        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

Head to get $\hat{\sigma}$

212        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214    def forward(self, inputs: torch.Tensor, state=None):

The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and first token in the reverse direction, which is what we want.

221        _, (hidden, cell) = self.lstm(inputs.float(), state)

The state has shape [2, batch_size, hidden_size], where the first dimension is the direction. We rearrange it to get $h = [\overrightarrow{h}; \overleftarrow{h}]$

225        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
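A quick standalone check of what this rearrangement does, with a made-up state tensor:

h = torch.randn(2, 4, 8)  # [directions, batch_size, hidden_size]
a = einops.rearrange(h, 'fb b h -> b (fb h)')
b = torch.cat([h[0], h[1]], dim=-1)  # forward state followed by backward state
assert torch.equal(a, b) and a.shape == (4, 16)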

228        mu = self.mu_head(hidden)

230        sigma_hat = self.sigma_head(hidden)

232        sigma = torch.exp(sigma_hat / 2.)

Sample $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$

235        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
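In equation form, the sampling above is the standard reparameterization trick:

$$\sigma = \exp\Big(\frac{\hat{\sigma}}{2}\Big), \qquad z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$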

238        return z, mu, sigma_hat

Decoder module

This consists of an LSTM

241class DecoderRNN(Module):
248    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249        super().__init__()

The LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input

251        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

The initial state of the LSTM is $[h_0; c_0] = \tanh(W_z z + b_z)$. init_state is the linear transformation for this

255        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

This layer produces outputs for each of the n_distributions. Each distribution needs six parameters $(\hat{\Pi_i}, \mu_{x,i}, \mu_{y,i}, \hat{\sigma}_{x,i}, \hat{\sigma}_{y,i}, \hat{\rho}_{xy,i})$

260        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

This head is for the logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$

263        self.q_head = nn.Linear(dec_hidden_size, 3)

This is to calculate $\log q_k$ where $q = \operatorname{softmax}(\hat{q})$

266        self.q_log_softmax = nn.LogSoftmax(-1)

These parameters are stored for future reference

269        self.n_distributions = n_distributions
270        self.dec_hidden_size = dec_hidden_size
272    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

Calculate the initial state

274        if state is None:

276            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

h and c have shapes [batch_size, lstm_size]. We want to shape them to [1, batch_size, lstm_size] because that's the shape the LSTM expects.

279            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

Run the LSTM

282        outputs, state = self.lstm(x, state)

Get $\log q$

285        q_logits = self.q_log_softmax(self.q_head(outputs))

Get $(\hat{\Pi}, \mu_x, \mu_y, \hat{\sigma}_x, \hat{\sigma}_y, \hat{\rho}_{xy})$. torch.split splits the output into 6 tensors of size self.n_distributions across dimension 2.

291        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292            torch.split(self.mixtures(outputs), self.n_distributions, 2)

Create a bi-variate Gaussian mixture $\Pi$ and $\mathcal{N}(\mu_{x,i}, \mu_{y,i}, \sigma_{x,i}, \sigma_{y,i}, \rho_{xy,i})$ where $\sigma_{x,i} = \exp(\hat{\sigma}_{x,i})$, $\sigma_{y,i} = \exp(\hat{\sigma}_{y,i})$ and $\rho_{xy,i} = \tanh(\hat{\rho}_{xy,i})$, and

$\Pi_i = \operatorname{softmax}(\hat{\Pi})_i$ is the categorical probability of choosing the distribution $i$ out of the mixture $\Pi$.

305        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

309        return dist, q_logits, state

Reconstruction Loss

312class ReconstructionLoss(Module):
317    def forward(self, mask: torch.Tensor, target: torch.Tensor,
318                 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

Get $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$

320        pi, mix = dist.get_distribution()

target has shape [seq_len, batch_size, 5] where the last dimension is the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We want to get $\Delta x, \Delta y$ and compute the probabilities from each of the distributions in the mixture $\Pi$.

xy will have shape [seq_len, batch_size, n_distributions, 2]

327        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

Calculate the probabilities $p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi_j \, \mathcal{N}\big(\Delta x, \Delta y \mid \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}\big)$

333        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

Although probs has $N_{max}$ (longest_seq_len) elements, the sum is only taken up to $N_s$ because the rest is masked out.

It might feel like we should be taking the sum and dividing by $N_s$ and not $N_{max}$, but this would give higher weight to individual predictions in shorter sequences. We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{max}$.
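Written out in the notation of the paper, the two loss terms computed below are

$$L_s = -\frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \big(p(\Delta x_i, \Delta y_i)\big), \qquad L_p = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log q_{k,i}$$

and the reconstruction loss is $L_R = L_s + L_p$.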

342        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

345        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

348        return loss_stroke + loss_pen

KL-Divergence loss

This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, I)$

351class KLDivLoss(Module):
358    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

360        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
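This is the standard closed-form KL divergence between $\mathcal{N}(\mu, \exp(\hat{\sigma}))$ and the standard normal; in the paper's notation,

$$L_{KL} = -\frac{1}{2 N_z} \sum_{j=1}^{N_z} \Big( 1 + \hat{\sigma}_j - \mu_j^2 - \exp(\hat{\sigma}_j) \Big)$$

where the code above additionally averages over the batch.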

Sampler

This samples a sketch from the decoder and plots it

363class Sampler:
370    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371        self.decoder = decoder
372        self.encoder = encoder
374    def sample(self, data: torch.Tensor, temperature: float):

376        longest_seq_len = len(data)

Get $z$ from the encoder

379        z, _, _ = self.encoder(data)

Start-of-sequence stroke is $(0, 0, 1, 0, 0)$

382        s = data.new_tensor([0, 0, 1, 0, 0])
383        seq = [s]

The initial decoder state is None. The decoder will initialize it to $[h_0; c_0] = \tanh(W_z z + b_z)$

386        state = None

We don't need gradients

389        with torch.no_grad():

Sample strokes

391            for i in range(longest_seq_len):

The last stroke $s$ concatenated with $z$ is the input to the decoder

393                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

Get $\Pi$, $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$, $\hat{q}$ and the next state from the decoder

396                dist, q_logits, state = self.decoder(data, z, state)

Sample a stroke

398                s = self._sample_step(dist, q_logits, temperature)

Add the new stroke to the sequence of strokes

400                seq.append(s)

Stop sampling if $p_3 = 1$. This indicates that sketching has stopped

402                if s[4] == 1:
403                    break

Create a PyTorch tensor of the sequence of strokes

406        seq = torch.stack(seq)

Plot the sequence of strokes

409        self.plot(seq)
411    @staticmethod
412    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

Set temperature for sampling. This is implemented in class BivariateGaussianMixture .

414        dist.set_temperature(temperature)

Get temperature adjusted $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$

416        pi, mix = dist.get_distribution()

Sample from $\Pi$ the index of the distribution to use from the mixture

418        idx = pi.sample()[0, 0]

Create categorical distribution $q$ from the log-probabilities q_logits, or $\hat{q}$, adjusted by the temperature

421        q = torch.distributions.Categorical(logits=q_logits / temperature)

Sample the pen state from $q$

423        q_idx = q.sample()[0, 0]

Sample from the normal distributions in the mixture and pick the one indexed by idx

426        xy = mix.sample()[0, 0, idx]

Create an empty stroke

429        stroke = q_logits.new_zeros(5)

Set $\Delta x, \Delta y$

431        stroke[:2] = xy

Set $p_1, p_2, p_3$

433        stroke[q_idx + 2] = 1

435        return stroke
437    @staticmethod
438    def plot(seq: torch.Tensor):

Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$

440        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

Create a new numpy array of the form $(x, y, q_2)$

442        seq[:, 2] = seq[:, 3]
443        seq = seq[:, 0:3].detach().cpu().numpy()

Split the array at points where $q_2$ is $1$, i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of sequences of strokes.

448        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
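A standalone toy example (made-up values) of how this split works:

points = np.array([[0., 0., 0.],
                   [1., 1., 1.],   # pen lifted here: end of the first stroke
                   [2., 0., 0.],
                   [3., 1., 0.]])
parts = np.split(points, np.where(points[:, 2] > 0)[0] + 1)
# parts[0] is rows 0-1 and parts[1] is rows 2-3: two separate polylines to plot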

Plot each sequence of strokes

450        for s in strokes:
451            plt.plot(s[:, 0], -s[:, 1])

Don't show axes

453        plt.axis('off')

Show the plot

455        plt.show()

Configurations

These are default configurations which can later be adjusted by passing a dict .

458class Configs(TrainValidConfigs):

Device configurations to pick the device to run the experiment

466    device: torch.device = DeviceConfigs()

468    encoder: EncoderRNN
469    decoder: DecoderRNN
470    optimizer: optim.Adam
471    sampler: Sampler
472
473    dataset_name: str
474    train_loader: DataLoader
475    valid_loader: DataLoader
476    train_dataset: StrokesDataset
477    valid_dataset: StrokesDataset

Encoder and decoder sizes

480    enc_hidden_size = 256
481    dec_hidden_size = 512

Batch size

484    batch_size = 100

Number of features in $z$

487    d_z = 128

Number of distributions in the mixture, $M$

489    n_distributions = 20

Weight of KL divergence loss, $w_{KL}$

492    kl_div_loss_weight = 0.5

Gradient clipping

494    grad_clip = 1.

Temperature for sampling

496    temperature = 0.4

Filter out stroke sequences longer than $N_{max}$ in training

499    max_seq_length = 200
500
501    epochs = 100
502
503    kl_div_loss = KLDivLoss()
504    reconstruction_loss = ReconstructionLoss()
506    def init(self):

Initialize encoder & decoder

508        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

Set optimizer. Things like type of optimizer and learning rate are configurable

512        optimizer = OptimizerConfigs()
513        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514        self.optimizer = optimizer
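If you are not using labml's OptimizerConfigs, an equivalent plain-PyTorch setup would be a sketch along these lines (assuming Adam with the learning rate set in main() below):

self.optimizer = optim.Adam(
    list(self.encoder.parameters()) + list(self.decoder.parameters()),
    lr=1e-3,  # the paper suggests 1e-4; see the note in main()
)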

Create sampler

517        self.sampler = Sampler(self.encoder, self.decoder)

npz file path is data/sketch/[DATASET NAME].npz

520        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

Load the numpy file

522        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

Create training dataset

525        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

Create validation dataset

527        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

Create training data loader

530        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

Create validation data loader

532        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

Add hooks to monitor layer outputs on Tensorboard

535        hook_model_outputs(self.mode, self.encoder, 'encoder')
536        hook_model_outputs(self.mode, self.decoder, 'decoder')

Configure the tracker to print the total train/validation loss

539        tracker.set_scalar("loss.total.*", True)
540
541        self.state_modules = []
543    def step(self, batch: Any, batch_idx: BatchIndex):
544        self.encoder.train(self.mode.is_train)
545        self.decoder.train(self.mode.is_train)

Move data and mask to device and swap the sequence and batch dimensions. data will have shape [seq_len, batch_size, 5] and mask will have shape [seq_len, batch_size] .

550        data = batch[0].to(self.device).transpose(0, 1)
551        mask = batch[1].to(self.device).transpose(0, 1)

Increment step in training mode

554        if self.mode.is_train:
555            tracker.add_global_step(len(data))

Encode the sequence of strokes

558        with monit.section("encoder"):

Get $z$, $\mu$, and $\hat{\sigma}$

560            z, mu, sigma_hat = self.encoder(data)

Decode the mixture of distributions and $\hat{q}$

563        with monit.section("decoder"):

Concatenate $(\Delta x, \Delta y, p_1, p_2, p_3)$ with $z$

565            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566            inputs = torch.cat([data[:-1], z_stack], 2)

Get mixture of distributions and $\hat{q}$

568            dist, q_logits, _ = self.decoder(inputs, z, None)

Compute the loss

571        with monit.section('loss'):

573            kl_loss = self.kl_div_loss(sigma_hat, mu)

575            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

577            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
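In the paper's notation, this is

$$\mathrm{Loss} = L_R + w_{KL} \, L_{KL}$$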

Track losses

580            tracker.add("loss.kl.", kl_loss)
581            tracker.add("loss.reconstruction.", reconstruction_loss)
582            tracker.add("loss.total.", loss)

Only if we are in training state

585        if self.mode.is_train:

Run optimizer

587            with monit.section('optimize'):

Set grad to zero

589                self.optimizer.zero_grad()

Compute gradients

591                loss.backward()

Log model parameters and gradients

593                if batch_idx.is_last:
594                    tracker.add(encoder=self.encoder, decoder=self.decoder)

Clip gradients

596                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

Optimize

599                self.optimizer.step()
600
601        tracker.save()
603    def sample(self):

Randomly pick a sample from the validation dataset to encode

605        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

Add batch dimension and move it to device

607        data = data.unsqueeze(1).to(self.device)

Sample

609        self.sampler.sample(data, self.temperature)
612def main():
613    configs = Configs()
614    experiment.create(name="sketch_rnn")

Pass a dictionary of configurations

617    experiment.configs(configs, {
618        'optimizer.optimizer': 'Adam',

We use a learning rate of 1e-3 because we can see results faster. The paper suggested 1e-4.

621        'optimizer.learning_rate': 1e-3,

Name of the dataset

623        'dataset_name': 'bicycle',

Number of inner iterations within an epoch to switch between training, validation and sampling.

625        'inner_iterations': 10
626    })
627
628    with experiment.start():

Run the experiment

630        configs.run()
631
632
633if __name__ == "__main__":
634    main()