This is an annotated PyTorch implementation of the paper A Neural Representation of Sketch Drawings.
Sketch RNN is a sequence-to-sequence variational auto-encoder (VAE). Both the encoder and decoder are recurrent neural network models. It learns to reconstruct simple stroke-based drawings by predicting a sequence of strokes. The decoder predicts each stroke as a mixture of bi-variate Gaussians.
Download data from the Quick, Draw! Dataset. There is a link to download the npz files in the Sketch-RNN QuickDraw Dataset section of the readme. Place the downloaded npz file(s) in the data/sketch folder. This code is configured to use the bicycle dataset. You can change this in the configurations.
This implementation took help from the PyTorch Sketch RNN project by Alexis David Jacq.
import math
from typing import Optional, Tuple, Any

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch import optim
from torch.utils.data import Dataset, DataLoader

import einops
from labml import lab, experiment, tracker, monit
from labml_helpers.device import DeviceConfigs
from labml_helpers.module import Module
from labml_helpers.optimizer import OptimizerConfigs
from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
class StrokesDataset(Dataset):
dataset is a list of numpy arrays of shape [seq_len, 3]. It is a sequence of strokes, and each stroke is represented by 3 integers. The first two are the displacements along the x and y axes ($\Delta x$, $\Delta y$), and the last integer represents the state of the pen: it is $1$ if the pen is lifted from the paper after this point (ending the current stroke) and $0$ otherwise.
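For instance, a raw entry in dataset for a tiny two-stroke drawing could look like this (an illustrative, made-up example, not taken from the actual dataset):

np.array([[ 10,   0, 0],
          [  0,  10, 1],   # pen lifted: end of the first stroke
          [  5,  -5, 0],
          [-15,  -5, 1]])  # pen lifted: end of the drawing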
    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
        data = []
We iterate through each of the sequences and filter them
        for seq in dataset:
Keep the sequence only if the number of strokes is within our range
            if 10 < len(seq) <= max_seq_length:
Clamp $\Delta x$, $\Delta y$ to $[-1000, 1000]$
                seq = np.minimum(seq, 1000)
                seq = np.maximum(seq, -1000)
Convert to a floating point array and add to data
                seq = np.array(seq, dtype=np.float32)
                data.append(seq)
We then calculate the scaling factor, which is the standard deviation of ($\Delta x$, $\Delta y$) combined. The paper notes that the mean is not adjusted for simplicity, since the mean is anyway close to $0$.
        if scale is None:
            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
        self.scale = scale
Get the longest sequence length among all sequences
        longest_seq_len = max([len(seq) for seq in data])
We initialize the PyTorch data array with two extra steps for start-of-sequence (sos) and end-of-sequence (eos). Each step is a vector $(\Delta x, \Delta y, p_1, p_2, p_3)$. Only one of $p_1, p_2, p_3$ is $1$ and the others are $0$. They represent pen down, pen up and end-of-sequence in that order: $p_1$ is $1$ if the pen touches the paper in the next step, $p_2$ is $1$ if the pen doesn't touch the paper in the next step, and $p_3$ is $1$ if it is the end of the drawing.
        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
The mask array needs only one extra step since it is for the outputs of the decoder, which takes in data[:-1] and predicts the next step.
        self.mask = torch.zeros(len(data), longest_seq_len + 1)

        for i, seq in enumerate(data):
            seq = torch.from_numpy(seq)
            len_seq = len(seq)
Scale and set $\Delta x$, $\Delta y$; then set $p_1$ (pen down) from the inverted raw pen state, $p_2$ (pen up) from the raw pen state, and $p_3 = 1$ for every step after the sequence ends.
            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
            self.data[i, len_seq + 1:, 4] = 1
Mask is on until the end of the sequence
            self.mask[i, :len_seq + 1] = 1
Start-of-sequence is $(0, 0, 1, 0, 0)$
        self.data[:, 0, 2] = 1
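To make the conversion concrete, this is what self.data[i] and self.mask[i] would contain for the made-up 4-stroke sequence shown earlier, assuming the longest sequence in the dataset has 5 strokes and ignoring the division by scale:

# self.data[i], each row is (dx, dy, p1, p2, p3):
# [  0,   0, 1, 0, 0]   <- start-of-sequence
# [ 10,   0, 1, 0, 0]
# [  0,  10, 0, 1, 0]   <- pen lifted after this step
# [  5,  -5, 1, 0, 0]
# [-15,  -5, 0, 1, 0]
# [  0,   0, 0, 0, 1]   <- end-of-sequence
# [  0,   0, 0, 0, 1]   <- padding
#
# self.mask[i]:
# [1, 1, 1, 1, 1, 0]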
Size of the dataset
    def __len__(self):
        return len(self.data)
Get a sample
    def __getitem__(self, idx: int):
        return self.data[idx], self.mask[idx]
The mixture is represented by $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$; that is, a categorical distribution over the $M$ components and a bi-variate normal distribution for each component. This class adjusts temperatures and creates the categorical and Gaussian distributions from the parameters.
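For a single step, the density of an offset $(\Delta x, \Delta y)$ under such a mixture of $M$ bi-variate Gaussians is

$$p(\Delta x, \Delta y) = \sum_{j=1}^{M} \Pi_j \, \mathcal{N}\big(\Delta x, \Delta y \;\big|\; \mu_{x,j}, \mu_{y,j}, \sigma_{x,j}, \sigma_{y,j}, \rho_{xy,j}\big), \qquad \sum_{j=1}^{M} \Pi_j = 1$$

This is the density that the reconstruction loss below evaluates.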
class BivariateGaussianMixture:
    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
        self.pi_logits = pi_logits
        self.mu_x = mu_x
        self.mu_y = mu_y
        self.sigma_x = sigma_x
        self.sigma_y = sigma_y
        self.rho_xy = rho_xy
Number of distributions in the mixture, $M$
    @property
    def n_distributions(self):
        return self.pi_logits.shape[-1]
Adjust by temperature $\tau$
    def set_temperature(self, temperature: float):
        self.pi_logits /= temperature
        self.sigma_x *= math.sqrt(temperature)
        self.sigma_y *= math.sqrt(temperature)
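In equation form, the mixture logits $\hat{\Pi}$ are scaled and the standard deviations grow with the temperature $\tau$,

$$\hat{\Pi}_k \to \frac{\hat{\Pi}_k}{\tau}, \qquad \sigma_x \to \sigma_x \sqrt{\tau}, \qquad \sigma_y \to \sigma_y \sqrt{\tau}$$

so the variances scale by $\tau$: sampling becomes nearly deterministic as $\tau \to 0$ and the predicted distribution is unchanged at $\tau = 1$.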
    def get_distribution(self):
Clamp $\sigma_x$, $\sigma_y$ and $\rho_{xy}$ to avoid getting NaNs
        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
Get means
        mean = torch.stack([self.mu_x, self.mu_y], -1)
Get covariance matrix
        cov = torch.stack([
            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
        ], -1)
        cov = cov.view(*sigma_y.shape, 2, 2)
Create bi-variate normal distribution.
📝 It would be more efficient to use the scale_tril matrix [[a, 0], [b, c]] where $a = \sigma_x$, $b = \rho_{xy} \sigma_y$ and $c = \sigma_y \sqrt{1 - \rho_{xy}^2}$. But for simplicity we use the covariance matrix. This is a good resource if you want to read up more about bi-variate distributions, their covariance matrix, and probability density function.
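A minimal sketch of that lower-triangular (Cholesky) parameterization, using the clamped sigma_x, sigma_y, rho_xy and the mean from above; this is not what the implementation below does, so it is shown commented out:

# Hypothetical alternative: pass the Cholesky factor of the covariance as scale_tril
# a = sigma_x
# b = rho_xy * sigma_y
# c = sigma_y * torch.sqrt(1 - rho_xy ** 2)
# scale_tril = torch.stack([a, torch.zeros_like(a), b, c], -1).view(*sigma_y.shape, 2, 2)
# multi_dist = torch.distributions.MultivariateNormal(mean, scale_tril=scale_tril)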
        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
Create categorical distribution from logits
        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
        return cat_dist, multi_dist
class EncoderRNN(Module):
    def __init__(self, d_z: int, enc_hidden_size: int):
        super().__init__()
Create a bidirectional LSTM taking a sequence of $(\Delta x, \Delta y, p_1, p_2, p_3)$ as input.
        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
Head to get $\mu$
        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
Head to get $\hat{\sigma}$
        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)

    def forward(self, inputs: torch.Tensor, state=None):
The hidden state of the bidirectional LSTM is the concatenation of the output of the last token in the forward direction and the first token in the reverse direction, which is what we want.
        _, (hidden, cell) = self.lstm(inputs.float(), state)
The state has shape [2, batch_size, hidden_size], where the first dimension is the direction. We rearrange it to get the concatenation $h = [\overrightarrow{h}; \overleftarrow{h}]$ of shape [batch_size, 2 * hidden_size].
        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
        mu = self.mu_head(hidden)
        sigma_hat = self.sigma_head(hidden)
        sigma = torch.exp(sigma_hat / 2.)
Sample $z = \mu + \sigma \cdot \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$
        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
        return z, mu, sigma_hat
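Putting the last few lines together: the head output $\hat{\sigma}$ is treated as the log of the variance, and $z$ is drawn with the usual VAE reparameterization trick,

$$\sigma = \exp\Big(\frac{\hat{\sigma}}{2}\Big), \qquad z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

which keeps the sampling step differentiable with respect to $\mu$ and $\hat{\sigma}$.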
class DecoderRNN(Module):
    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
        super().__init__()
The LSTM takes $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ as input
        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
Initial state of the LSTM is $[h_0; c_0] = \tanh(W_z z + b_z)$. init_state is the linear transformation for this
        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
This layer produces outputs for each of the n_distributions. Each distribution needs six parameters: $(\hat{\Pi}, \mu_x, \mu_y, \hat{\sigma}_x, \hat{\sigma}_y, \hat{\rho}_{xy})$
        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
This head is for the pen-state logits $(\hat{q}_1, \hat{q}_2, \hat{q}_3)$
        self.q_head = nn.Linear(dec_hidden_size, 3)
This is to calculate $\log(q_k)$ where $q_k = \operatorname{softmax}(\hat{q}_k)$
        self.q_log_softmax = nn.LogSoftmax(-1)
These parameters are stored for future reference
        self.n_distributions = n_distributions
        self.dec_hidden_size = dec_hidden_size
    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
Calculate the initial state
        if state is None:
            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
h and c have shapes [batch_size, lstm_size]. We want to reshape them to [1, batch_size, lstm_size] because that's the shape the LSTM expects.
            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
Run the LSTM
        outputs, state = self.lstm(x, state)
Get $\log(q)$
        q_logits = self.q_log_softmax(self.q_head(outputs))
Get $(\hat{\Pi}, \mu_x, \mu_y, \hat{\sigma}_x, \hat{\sigma}_y, \hat{\rho}_{xy})$. torch.split splits the output into 6 tensors of size self.n_distributions across dimension 2.
        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
            torch.split(self.mixtures(outputs), self.n_distributions, 2)
Create a bi-variate Gaussian mixture $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$, where $\sigma_x = \exp(\hat{\sigma}_x)$, $\sigma_y = \exp(\hat{\sigma}_y)$, $\rho_{xy} = \tanh(\hat{\rho}_{xy})$, and $\Pi$ gives the categorical probabilities of choosing a distribution out of the mixture, $\Pi_j = \operatorname{softmax}(\hat{\Pi}_j)$.
        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
        return dist, q_logits, state
class ReconstructionLoss(Module):
    def forward(self, mask: torch.Tensor, target: torch.Tensor,
                dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
Get $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$
        pi, mix = dist.get_distribution()
target has shape [seq_len, batch_size, 5], where the last dimension is the features $(\Delta x, \Delta y, p_1, p_2, p_3)$. We want to get $(\Delta x, \Delta y)$ and get the probabilities from each of the distributions in the mixture. xy will have shape [seq_len, batch_size, n_distributions, 2].
        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
Calculate the probabilities $p(\Delta x, \Delta y)$
        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
Although probs has $N_{max}$ (longest_seq_len) elements, the sum is only taken up to $N_s$ (the actual sequence length) because the rest is masked out. It might feel like we should be taking the sum and dividing by $N_s$ and not $N_{max}$, but that would give a higher weight to individual predictions in shorter sequences. We give equal weight to each prediction $p(\Delta x, \Delta y)$ when we divide by $N_{max}$.
        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
        return loss_stroke + loss_pen
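In the notation of the paper, these two terms correspond (up to constant factors from averaging over the batch and, for the pen loss, over the three pen states) to

$$L_s = -\frac{1}{N_{max}} \sum_{i=1}^{N_s} \log p(\Delta x_i, \Delta y_i), \qquad L_p = -\frac{1}{N_{max}} \sum_{i=1}^{N_{max}} \sum_{k=1}^{3} p_{k,i} \log q_{k,i}$$

and the reconstruction loss returned here is $L_R = L_s + L_p$.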
class KLDivLoss(Module):
    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
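This is the KL divergence between the predicted distribution of $z$, $\mathcal{N}(\mu, \sigma^2)$, and a standard normal prior, averaged over the latent dimensions (and the batch):

$$L_{KL} = -\frac{1}{2 N_z} \sum_{i=1}^{N_z} \Big(1 + \hat{\sigma}_i - \mu_i^2 - \exp(\hat{\sigma}_i)\Big)$$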
class Sampler:
    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
        self.decoder = decoder
        self.encoder = encoder
    def sample(self, data: torch.Tensor, temperature: float):
        longest_seq_len = len(data)
Get $z$ from the encoder
        z, _, _ = self.encoder(data)
Start-of-sequence stroke is $(0, 0, 1, 0, 0)$
        s = data.new_tensor([0, 0, 1, 0, 0])
        seq = [s]
Initial decoder state is None. The decoder will initialize it to $[h_0; c_0] = \tanh(W_z z + b_z)$
        state = None
We don't need gradients
        with torch.no_grad():
Sample strokes
            for i in range(longest_seq_len):
The concatenation of the last stroke and $z$ is the input to the decoder
                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
Get the mixture of Gaussians, the pen-state logits $\hat{q}$, and the new state from the decoder
                dist, q_logits, state = self.decoder(data, z, state)
Sample a stroke
                s = self._sample_step(dist, q_logits, temperature)
Add the new stroke to the sequence of strokes
                seq.append(s)
Stop sampling if $p_3 = 1$. This indicates that sketching has stopped
                if s[4] == 1:
                    break
Create a PyTorch tensor of the sequence of strokes
        seq = torch.stack(seq)
Plot the sequence of strokes
        self.plot(seq)
    @staticmethod
    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
Set the temperature for sampling. This is implemented in class BivariateGaussianMixture.
        dist.set_temperature(temperature)
Get the temperature-adjusted $\Pi$ and $\mathcal{N}(\mu_x, \mu_y, \sigma_x, \sigma_y, \rho_{xy})$
        pi, mix = dist.get_distribution()
Sample from $\Pi$ the index of the distribution to use from the mixture
        idx = pi.sample()[0, 0]
Create a categorical distribution $q$ from the log-probabilities q_logits, adjusted by temperature
        q = torch.distributions.Categorical(logits=q_logits / temperature)
Sample the pen state from $q$
        q_idx = q.sample()[0, 0]
Sample from the normal distributions in the mixture and pick the one indexed by idx
        xy = mix.sample()[0, 0, idx]
Create an empty stroke $(\Delta x, \Delta y, p_1, p_2, p_3)$
        stroke = q_logits.new_zeros(5)
Set $\Delta x$, $\Delta y$
        stroke[:2] = xy
Set the pen state ($p_1$, $p_2$ or $p_3$)
        stroke[q_idx + 2] = 1
        return stroke
    @staticmethod
    def plot(seq: torch.Tensor):
Take the cumulative sums of $(\Delta x, \Delta y)$ to get $(x, y)$
        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
Create a new numpy array of the form $(x, y, p_2)$
        seq[:, 2] = seq[:, 3]
        seq = seq[:, 0:3].detach().cpu().numpy()
Split the array at points where $p_2$ is $1$; i.e. split the array of strokes at the points where the pen is lifted from the paper. This gives a list of sequences of strokes.
        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
Plot each sequence of strokes
        for s in strokes:
            plt.plot(s[:, 0], -s[:, 1])
Don't show axes
        plt.axis('off')
Show the plot
        plt.show()
class Configs(TrainValidConfigs):
Device configurations to pick the device to run the experiment
    device: torch.device = DeviceConfigs()

    encoder: EncoderRNN
    decoder: DecoderRNN
    optimizer: optim.Adam
    sampler: Sampler

    dataset_name: str
    train_loader: DataLoader
    valid_loader: DataLoader
    train_dataset: StrokesDataset
    valid_dataset: StrokesDataset
Encoder and decoder sizes
    enc_hidden_size = 256
    dec_hidden_size = 512
Batch size
    batch_size = 100
Number of features in $z$
    d_z = 128
Number of distributions in the mixture, $M$
    n_distributions = 20
Weight of KL divergence loss, $w_{KL}$
    kl_div_loss_weight = 0.5
Gradient clipping
    grad_clip = 1.
Temperature $\tau$ for sampling
    temperature = 0.4
Filter out stroke sequences longer than max_seq_length
    max_seq_length = 200

    epochs = 100

    kl_div_loss = KLDivLoss()
    reconstruction_loss = ReconstructionLoss()
    def init(self):
Initialize encoder & decoder
        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
Set optimizer. Things like type of optimizer and learning rate are configurable
        optimizer = OptimizerConfigs()
        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optimizer
Create sampler
        self.sampler = Sampler(self.encoder, self.decoder)
npz file path is data/sketch/[DATASET NAME].npz
        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
Load the numpy file
        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
Create training dataset
        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
Create validation dataset
        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
Create training data loader
        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
Create validation data loader
        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
Add hooks to monitor layer outputs on Tensorboard
        hook_model_outputs(self.mode, self.encoder, 'encoder')
        hook_model_outputs(self.mode, self.decoder, 'decoder')
Configure the tracker to print the total train/validation loss
        tracker.set_scalar("loss.total.*", True)

        self.state_modules = []
    def step(self, batch: Any, batch_idx: BatchIndex):
        self.encoder.train(self.mode.is_train)
        self.decoder.train(self.mode.is_train)
Move data and mask to the device and swap the sequence and batch dimensions. data will have shape [seq_len, batch_size, 5] and mask will have shape [seq_len, batch_size].
        data = batch[0].to(self.device).transpose(0, 1)
        mask = batch[1].to(self.device).transpose(0, 1)
Increment step in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Encode the sequence of strokes
        with monit.section("encoder"):
Get $z$, $\mu$, and $\hat{\sigma}$
            z, mu, sigma_hat = self.encoder(data)
Decode the mixture of distributions and the pen-state logits $\hat{q}$
        with monit.section("decoder"):
Concatenate $[(\Delta x, \Delta y, p_1, p_2, p_3); z]$ for every stroke
            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
            inputs = torch.cat([data[:-1], z_stack], 2)
Get the mixture of distributions and $\hat{q}$
            dist, q_logits, _ = self.decoder(inputs, z, None)
Compute the losses: the KL divergence $L_{KL}$, the reconstruction loss $L_R$, and the total loss $L_R + w_{KL} \cdot L_{KL}$
        with monit.section('loss'):
            kl_loss = self.kl_div_loss(sigma_hat, mu)
            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
Track losses
            tracker.add("loss.kl.", kl_loss)
            tracker.add("loss.reconstruction.", reconstruction_loss)
            tracker.add("loss.total.", loss)
Only if we are in training state
        if self.mode.is_train:
Run the optimizer
            with monit.section('optimize'):
Set grad to zero
                self.optimizer.zero_grad()
Compute gradients
                loss.backward()
Log model parameters and gradients on the last batch of each epoch
                if batch_idx.is_last:
                    tracker.add(encoder=self.encoder, decoder=self.decoder)
Clip gradients
                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
Optimize
                self.optimizer.step()

            tracker.save()
    def sample(self):
Randomly pick a sample from the validation dataset to encode
        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
Add a batch dimension and move it to the device
        data = data.unsqueeze(1).to(self.device)
Sample
        self.sampler.sample(data, self.temperature)
def main():
    configs = Configs()
    experiment.create(name="sketch_rnn")
Pass a dictionary of configurations
    experiment.configs(configs, {
        'optimizer.optimizer': 'Adam',
We use a learning rate of 1e-3 because we can see results faster. The paper suggested 1e-4.
        'optimizer.learning_rate': 1e-3,
Name of the dataset
        'dataset_name': 'bicycle',
Number of inner iterations within an epoch to switch between training, validation and sampling.
        'inner_iterations': 10
    })

    with experiment.start():
Run the experiment
        configs.run()


if __name__ == "__main__":
    main()