This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.
Sketch RNN is a sequence-to-sequence variational auto-encoder. Both the encoder and decoder are recurrent neural network models. It learns to reconstruct simple, stroke-based drawings by predicting a series of strokes. The decoder predicts each stroke as a mixture of Gaussians.
Download data from the Quick, Draw! Dataset. There is a link to download the `npz` files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded `npz` file(s) in the `data/sketch` folder. This code is configured to use the `bicycle` dataset. You can change this in the configurations.

This implementation took help from the PyTorch Sketch RNN project by Alexis David Jacq.
import math
from typing import Optional, Tuple, Any

import einops
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from labml import lab, experiment, tracker, monit
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from torch import optim
from torch.utils.data import Dataset, DataLoader
class StrokesDataset(Dataset):
`dataset` is a list of numpy arrays of shape `[seq_len, 3]`. It is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along x and y ($\Delta x$, $\Delta y$), and the last integer represents the state of the pen: $1$ if the pen is lifted off the paper after this point, and $0$ otherwise.
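For concreteness, here is an invented fragment in this format (illustrative only, not taken from the dataset):

```python
import numpy as np

# A made-up sketch fragment in stroke-3 format: (dx, dy, pen)
seq = np.array([[10, 0, 0],    # pen moves right while on the paper
                [0, 10, 0],    # pen moves down while on the paper
                [-10, 0, 1],   # pen moves left, then lifts off the paper
                [25, -5, 0]])  # pen comes down again for a new stroke
```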
    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
        data = []
We iterate through each of the sequences and filter them by length.

        for seq in dataset:

Keep the sequence only if its length is within our range

            if 10 < len(seq) <= max_seq_length:

Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$

                seq = np.minimum(seq, 1000)
                seq = np.maximum(seq, -1000)

Convert to a floating point array and add to `data`

                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
We then calculate the scaling factor, which is the standard deviation of ($\Delta x$, $\Delta y$) combined. The paper notes that the mean is not adjusted for simplicity, since it is anyway close to $0$.

        if scale is None:
            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
        self.scale = scale
Get the longest sequence length among all sequences
        longest_seq_len = max([len(seq) for seq in data])
We initialize the PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$. They represent pen down, pen up and end-of-sequence in that order. $p_1$ is $1$ if the pen touches the paper in the next step, $p_2$ is $1$ if the pen doesn't touch the paper in the next step, and $p_3$ is $1$ if it is the end of the drawing.

        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
The mask array needs only one extra step since it is for the outputs of the decoder, which takes in `data[:-1]` and predicts the next step.

        self.mask = torch.zeros(len(data), longest_seq_len + 1)

        for i, seq in enumerate(data):
            seq = torch.from_numpy(seq)
            len_seq = len(seq)
Scale and set $\Delta x, \Delta y$

            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

$p_1$ is $1$ when the pen touches the paper in the next step

            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

$p_2$ is $1$ when the pen is lifted in the next step

            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

$p_3$ is $1$ after the drawing ends

            self.data[i, len_seq + 1:, 4] = 1
The mask is on until the end of the sequence

            self.mask[i, :len_seq + 1] = 1
Start-of-sequence is $(0, 0, 1, 0, 0)$

        self.data[:, 0, 2] = 1
Size of the dataset
    def __len__(self):
        return len(self.data)
Get a sample
    def __getitem__(self, idx: int):
        return self.data[idx], self.mask[idx]
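A minimal usage sketch (hypothetical; it assumes a `bicycle.npz` file downloaded as described above) showing the shapes this dataset produces. `Configs.init` further below does the same thing with labml's path helpers.

```python
import numpy as np
from torch.utils.data import DataLoader

# Load a QuickDraw npz file (path is an assumption for this example)
raw = np.load('data/sketch/bicycle.npz', encoding='latin1', allow_pickle=True)
train_ds = StrokesDataset(raw['train'], max_seq_length=200)
# Reuse the training scale so validation data is normalized consistently
valid_ds = StrokesDataset(raw['valid'], max_seq_length=200, scale=train_ds.scale)

loader = DataLoader(train_ds, batch_size=100, shuffle=True)
data, mask = next(iter(loader))
# data: [batch_size, longest_seq_len + 2, 5], mask: [batch_size, longest_seq_len + 1]
print(data.shape, mask.shape)
```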
The mixture is represented by $\Pi$, a categorical distribution over the components, and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$, a bivariate normal for each component. This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.
class BivariateGaussianMixture:
    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
        self.pi_logits = pi_logits
        self.mu_x = mu_x
        self.mu_y = mu_y
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.rho_xy = rho_xy
Number of distributions in the mixture, $M$
    @property
    def n_distributions(self):
        return self.pi_logits.shape[-1]
Adjust by temperature $\tau$

    def set_temperature(self, temperature: float):

$\hat{\pi_k} \leftarrow \frac{\hat{\pi_k}}{\tau}$

        self.pi_logits /= temperature

$\sigma^2_x \leftarrow \sigma^2_x \tau$

        self.sigma_x *= math.sqrt(temperature)

$\sigma^2_y \leftarrow \sigma^2_y \tau$

        self.sigma_y *= math.sqrt(temperature)
    def get_distribution(self):
Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs

        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
Get means
        mean = torch.stack([self.mu_x, self.mu_y], -1)
Get covariance matrix
        cov = torch.stack([
            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
        ], -1)
        cov = cov.view(*sigma_y.shape, 2, 2)
Create a bivariate normal distribution.

📝 It would be more efficient to use the `scale_tril` matrix `[[a, 0], [b, c]]` where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho^2_{xy}}$. But for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bivariate distributions, their covariance matrix, and the probability density function. A sketch of the `scale_tril` alternative follows this method.

        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
Create categorical distribution from logits
        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

        return cat_dist, multi_dist
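Here is a minimal sketch of that `scale_tril` alternative, written as a hypothetical standalone helper (the function name is invented; the arithmetic follows the note above, since `[[a, 0], [b, c]]` is the Cholesky factor of the covariance matrix):

```python
import torch

def mixture_normal_scale_tril(mean: torch.Tensor, sigma_x: torch.Tensor,
                              sigma_y: torch.Tensor, rho_xy: torch.Tensor):
    # Lower-triangular Cholesky factor [[a, 0], [b, c]] of the covariance
    a = sigma_x
    b = rho_xy * sigma_y
    c = sigma_y * torch.sqrt(1 - rho_xy ** 2)
    zeros = torch.zeros_like(a)
    scale_tril = torch.stack([a, zeros, b, c], -1).view(*sigma_y.shape, 2, 2)
    # scale_tril @ scale_tril.T equals the covariance matrix built above,
    # but MultivariateNormal skips a Cholesky decomposition this way
    return torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)
```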
class EncoderRNN(nn.Module):
    def __init__(self, d_z: int, enc_hidden_size: int):
        super().__init__()
Create a bidirectional LSTM taking a sequence of $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.

        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
Head to get $\mu$

        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

Head to get $\hat{\sigma}$

        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
    def forward(self, inputs: torch.Tensor, state=None):
The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and first token in the reverse direction, which is what we want.
        _, (hidden, cell) = self.lstm(inputs.float(), state)
The state has shape `[2, batch_size, hidden_size]`, where the first dimension is the direction. We rearrange it to get $h = [h_{\rightarrow}; h_{\leftarrow}]$

        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
$\mu$

        mu = self.mu_head(hidden)

$\hat{\sigma}$

        sigma_hat = self.sigma_head(hidden)

$\sigma = \exp\big(\frac{\hat{\sigma}}{2}\big)$

        sigma = torch.exp(sigma_hat / 2.)
Sample $z = \mu + \sigma \cdot \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$

        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

        return z, mu, sigma_hat
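A quick shape check, as a hypothetical standalone snippet (the sizes match the defaults in `Configs` below):

```python
import torch

encoder = EncoderRNN(d_z=128, enc_hidden_size=256)
inputs = torch.zeros(200, 4, 5)            # [seq_len, batch_size, 5]
z, mu, sigma_hat = encoder(inputs)
print(z.shape, mu.shape, sigma_hat.shape)  # each is [4, 128]
```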
class DecoderRNN(nn.Module):
    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
        super().__init__()
The LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input

        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
The initial state of the LSTM is $[h_0; c_0] = \tanh(W_z z + b_z)$. `init_state` is the linear transformation for this.

        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
This layer produces outputs for each of the `n_distributions`. Each distribution needs six parameters $(\hat{\pi_j}, \mu_{x,j}, \mu_{y,j}, \hat{\sigma_{x,j}}, \hat{\sigma_{y,j}}, \hat{\rho_{xy,j}})$.

        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
This head is for the pen-state logits $(\hat{q_1}, \hat{q_2}, \hat{q_3})$

        self.q_head = nn.Linear(dec_hidden_size, 3)
This is to calculate $\log(q_k)$, where $q_k = \operatorname{softmax}(\hat{q})_k$

        self.q_log_softmax = nn.LogSoftmax(-1)
These parameters are stored for future reference
        self.n_distributions = n_distributions
        self.dec_hidden_size = dec_hidden_size
    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
Calculate the initial state

        if state is None:

$[h_0; c_0] = \tanh(W_z z + b_z)$

            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

`h` and `c` have shapes `[batch_size, lstm_size]`. We want to shape them to `[1, batch_size, lstm_size]` because that's the shape the LSTM expects.

            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
Run the LSTM
        outputs, state = self.lstm(x, state)
Get $\log(q)$

        q_logits = self.q_log_softmax(self.q_head(outputs))
Get $(\hat{\pi_j}, \mu_{x,j}, \mu_{y,j}, \hat{\sigma_{x,j}}, \hat{\sigma_{y,j}}, \hat{\rho_{xy,j}})$. `torch.split` splits the output into 6 tensors of size `self.n_distributions` across dimension `2`.

        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
            torch.split(self.mixtures(outputs), self.n_distributions, 2)
Create a bivariate Gaussian mixture from $\Pi$ and $\mathcal{N}(\mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j})$, where $\sigma_{x,j} = \exp(\hat{\sigma_{x,j}})$, $\sigma_{y,j} = \exp(\hat{\sigma_{y,j}})$ and $\rho_{xy,j} = \tanh(\hat{\rho_{xy,j}})$, and where $\Pi$ is the categorical distribution for choosing a component from the mixture, with probabilities $\pi_j = \operatorname{softmax}(\hat{\pi_j})$.

        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

        return dist, q_logits, state
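Similarly, a hypothetical shape check for the decoder:

```python
import torch

decoder = DecoderRNN(d_z=128, dec_hidden_size=512, n_distributions=20)
z = torch.zeros(4, 128)              # one latent vector per batch element
x = torch.zeros(10, 4, 128 + 5)      # [seq_len, batch_size, d_z + 5]
dist, q_logits, state = decoder(x, z, None)
print(q_logits.shape)                # [10, 4, 3]
print(dist.pi_logits.shape)          # [10, 4, 20], one logit per mixture component
```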
class ReconstructionLoss(nn.Module):
    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
Get $\Pi$ and $\mathcal{N}(\mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j})$

        pi, mix = dist.get_distribution()
`target` has shape `[seq_len, batch_size, 5]`, where the last dimension is the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We want to get $\Delta x, \Delta y$ and evaluate the probabilities under each of the distributions in the mixture. `xy` will have shape `[seq_len, batch_size, n_distributions, 2]`.

        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
Calculate the probabilities

$$p(\Delta x, \Delta y) = \sum_{j=1}^M \Pi_j \, \mathcal{N}\big(\Delta x, \Delta y \mid \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}\big)$$

        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
Although `probs` has $N_{max}$ (`longest_seq_len`) elements, the sum is only taken up to $N_s$, the actual length of the sequence, because the rest is masked out. It might feel like we should be taking the sum and dividing by $N_s$ and not $N_{max}$, but that would give higher weight to individual predictions in shorter sequences. We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{max}$:

$$L_s = -\frac{1}{N_{max}} \sum_{i=1}^{N_s} \log \big(p(\Delta x_i, \Delta y_i)\big)$$

        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
The pen-state loss is the cross entropy between the one-hot targets $p_{k,i}$ and the predicted log-probabilities $\log(q_{k,i})$:

$$L_p = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log(q_{k,i})$$

        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

The total reconstruction loss is $L_R = L_s + L_p$

        return loss_stroke + loss_pen
class KLDivLoss(nn.Module):

This calculates the KL divergence between the distribution of the latent vector and $\mathcal{N}(0, I)$:

$$L_{KL} = -\frac{1}{2 N_z} \sum_{i=1}^{N_z} \big(1 + \hat{\sigma_i} - \mu_i^2 - \exp(\hat{\sigma_i})\big)$$

    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
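As a quick sanity check (hypothetical, not part of the original code): if the encoder outputs exactly $\mathcal{N}(0, I)$, i.e. $\mu = 0$ and $\hat{\sigma} = \log \sigma^2 = 0$, the divergence is zero:

```python
import torch

kl = KLDivLoss()
mu = torch.zeros(4, 128)         # mean 0
sigma_hat = torch.zeros(4, 128)  # log(sigma^2) = 0, i.e. sigma = 1
print(kl(sigma_hat, mu))         # prints a zero tensor
```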
class Sampler:

This samples a sketch from the decoder and plots it.

    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
        self.decoder = decoder
        self.encoder = encoder

    def sample(self, data: torch.Tensor, temperature: float):
        longest_seq_len = len(data)
Get $z$ from the encoder

        z, _, _ = self.encoder(data)
The start-of-sequence stroke is $(0, 0, 1, 0, 0)$

        s = data.new_tensor([0, 0, 1, 0, 0])
        seq = [s]
The initial decoder state is `None`. The decoder will initialize it to $[h_0; c_0] = \tanh(W_z z + b_z)$

        state = None
We don't need gradients
        with torch.no_grad():
Sample strokes
            for i in range(longest_seq_len):
$[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ is the input to the decoder

                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
Get $\Pi$, $\hat{q}$ and the next state from the decoder

                dist, q_logits, state = self.decoder(data, z, state)
Sample a stroke
                s = self._sample_step(dist, q_logits, temperature)
Add the new stroke to the sequence of strokes
                seq.append(s)
Stop sampling if $p_3 = 1$. This indicates that sketching has stopped.

                if s[4] == 1:
                    break
Create a PyTorch tensor of the sequence of strokes
        seq = torch.stack(seq)
Plot the sequence of strokes
        self.plot(seq)
    @staticmethod
    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
Set the temperature for sampling. This is implemented in the class `BivariateGaussianMixture`.

        dist.set_temperature(temperature)
Get the temperature-adjusted $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$

        pi, mix = dist.get_distribution()
Sample from $\Pi$ the index of the distribution to use from the mixture

        idx = pi.sample()[0, 0]
Create a categorical distribution $q$, adjusted by temperature, from the log-probabilities `q_logits` ($\hat{q}$)

        q = torch.distributions.Categorical(logits=q_logits / temperature)
Sample from $q$

        q_idx = q.sample()[0, 0]
Sample from the normal distributions in the mixture and pick the one indexed by idx
        xy = mix.sample()[0, 0, idx]
Create an empty stroke
        stroke = q_logits.new_zeros(5)
Set $\Delta x, \Delta y$

        stroke[:2] = xy
Set $p_1, p_2, p_3$

        stroke[q_idx + 2] = 1
        return stroke
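To make the effect of the temperature concrete, here is a small, invented illustration: dividing logits by $\tau < 1$ sharpens the categorical distribution (and `set_temperature` additionally scales the Gaussian variances by $\tau$):

```python
import torch

logits = torch.tensor([1.0, 2.0, 3.0])
for tau in (1.0, 0.4):
    # Lower temperature concentrates probability on the largest logit
    print(tau, torch.softmax(logits / tau, -1))
```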
    @staticmethod
    def plot(seq: torch.Tensor):
Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$

        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
Create a new numpy array of the form $(x, y, q_2)$

        seq[:, 2] = seq[:, 3]
        seq = seq[:, 0:3].detach().cpu().numpy()
Split the array at points where $q_2$ is $1$, i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of sequences of strokes.

        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
Plot each sequence of strokes
        for s in strokes:
            plt.plot(s[:, 0], -s[:, 1])
Don't show axes
        plt.axis('off')
Show the plot
        plt.show()
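A small, hypothetical illustration of the `np.split` call above; splitting just after each pen-lift point yields one array per stroke (a trailing empty array is harmless for plotting):

```python
import numpy as np

seq = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 1],   # first stroke ends here
                [5, 5, 0], [6, 5, 1]])             # second stroke
strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
print([len(s) for s in strokes])  # [3, 2, 0]
```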
class Configs(TrainValidConfigs):
Device configuration to pick the device to run the experiment on

    device: torch.device = DeviceConfigs()
    encoder: EncoderRNN
    decoder: DecoderRNN
    optimizer: optim.Adam
    sampler: Sampler

    dataset_name: str
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset
Encoder and decoder sizes
    enc_hidden_size = 256
    dec_hidden_size = 512
Batch size
    batch_size = 100
Number of features in $z$

    d_z = 128
Number of distributions in the mixture, $M$

    n_distributions = 20
Weight of the KL divergence loss, $w_{KL}$

    kl_div_loss_weight = 0.5
Gradient clipping
    grad_clip = 1.
Temperature $\tau$ for sampling

    temperature = 0.4
Filter out stroke sequences longer than $N_{max}$

    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()
    def init(self):
Initialize encoder & decoder
        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
Set up the optimizer. Things like the type of optimizer and the learning rate are configurable.

        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optimizer
Create sampler
        self.sampler = Sampler(self.encoder, self.decoder)
The `npz` file path is `data/sketch/[DATASET NAME].npz`

        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
Load the numpy file
        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
Create training dataset

        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

Create validation dataset

        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

Create training data loader

        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

Create validation data loader

        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
Configure the tracker to print the total train/validation loss
        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []
    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)
Move `data` and `mask` to the device and swap the sequence and batch dimensions. `data` will have shape `[seq_len, batch_size, 5]` and `mask` will have shape `[seq_len, batch_size]`.

        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Encode the sequence of strokes
        with monit.section("encoder"):

Get $z$, $\mu$, and $\hat{\sigma}$

            z, mu, sigma_hat = self.encoder(data)
Decode the mixture of distributions and $\hat{q}$

        with monit.section("decoder"):
Concatenate $z$ with the input sequence of strokes, so that the decoder input at every step is $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$

            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
Get the mixture of distributions and $\hat{q}$

            dist, q_logits, _ = self.decoder(inputs, z, None)
Compute the loss
        with monit.section('loss'):

$L_{KL}$

            kl_loss = self.kl_div_loss(sigma_hat, mu)

$L_R$

            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

$Loss = L_R + w_{KL} \cdot L_{KL}$

            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
Track losses
            tracker.add("loss.kl.", kl_loss)
            tracker.add("loss.reconstruction.", reconstruction_loss)
            tracker.add("loss.total.", loss)
Only if we are in training state
        if self.mode.is_train:
Run optimizer
            with monit.section('optimize'):
Set `grad` to zero

                self.optimizer.zero_grad()
Compute gradients
                loss.backward()
Log model parameters and gradients
                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)
Clip gradients
                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
Optimize
                self.optimizer.step()

        tracker.save()
    def sample(self):
Randomly pick a sample from the validation dataset to encode

        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
Add batch dimension and move it to device
        data = data.unsqueeze(1).to(self.device)
Sample
        self.sampler.sample(data, self.temperature)
def main():
    configs = Configs()
    experiment.create(name="sketch_rnn")
Pass a dictionary of configurations

    experiment.configs(configs, {
        'optimizer.optimizer': 'Adam',

We use a learning rate of `1e-3` because we can see results faster. The paper suggested `1e-4`.

        'optimizer.learning_rate': 1e-3,

Name of the dataset

        'dataset_name': 'bicycle',

Number of inner iterations within an epoch to switch between training, validation and sampling.

        'inner_iterations': 10
    })
    with experiment.start():

Run the experiment

        configs.run()


if __name__ == "__main__":
    main()