这是论文《素描绘画的神经表示》的带注释的 PyTorch 实现。
Sketch RNN 是一种序列到序列的变分自动编码器。编码器和解码器都是循环神经网络模型。它通过预测一系列笔画来学习重建基于笔触的简单绘画。解码器将每个笔划预测为高斯笔画的混合。
从 Quick,Draw! 下载数据数据集。自述npz
文件的 Sk etch-RNN QuickDraw 数据集部分有一个下载文件的链接。将下载的npz
文件放在data/sketch
文件夹中。此代码配置为使用bicycle
数据集。你可以在配置中更改此设置。
从 A lexis David Jacq 的 PyTorch Sketch RNN 项目
32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
50class StrokesDataset(Dataset):
dataset
是形状为 seq_len 的 numpy 数组的列表,3。它是一个笔画序列,每个笔画由 3 个整数表示。前两个是沿 x 和 y (,) 的位移,最后一个整数表示笔的状态,如果它接触纸张和否则。
57 def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67 data = []
我们遍历每个序列并进行过滤
69 for seq in dataset:
筛选笔画序列的长度是否在我们的范围内
71 if 10 < len(seq) <= max_seq_length:
Clamp,到
73 seq = np.minimum(seq, 1000)
74 seq = np.maximum(seq, -1000)
转换为浮点数组并添加到data
76 seq = np.array(seq, dtype=np.float32)
77 data.append(seq)
然后,我们计算缩放系数,即 (,) 组合的标准差。论文指出,为了简单起见,均值没有进行调整,因为均值无论如何都接近。
83 if scale is None:
84 scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85 self.scale = scale
获取所有序列中最长的序列长度
88 longest_seq_len = max([len(seq) for seq in data])
我们在初始化 PyTorch 数据数组时添加了两个额外的步骤,分别是序列开始 (sos) 和序列结束 (eos)。每一步都是一个向量。只有一个是,其他的是。它们按该顺序表示笔向下、向上和序列结尾。是如果笔在下一步中碰到纸张。如果在下一步中笔没有碰到纸张。如果是绘图的结尾。
98 self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
掩码数组只需要一个额外的步骤,因为它是用于解码器的输出,解码器接收data[:-1]
并预测下一步。
101 self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103 for i, seq in enumerate(data):
104 seq = torch.from_numpy(seq)
105 len_seq = len(seq)
缩放和设置
107 self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
109 self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
111 self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
113 self.data[i, len_seq + 1:, 4] = 1
遮罩开启直到序列结束
115 self.mask[i, :len_seq + 1] = 1
序列开头是
118 self.data[:, 0, 2] = 1
数据集的大小
120 def __len__(self):
122 return len(self.data)
获取样品
124 def __getitem__(self, idx: int):
126 return self.data[idx], self.mask[idx]
129class BivariateGaussianMixture:
139 def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141 self.pi_logits = pi_logits
142 self.mu_x = mu_x
143 self.mu_y = mu_y
144 self.sigma_x = sigma_x
145 self.sigma_y = sigma_y
146 self.rho_xy = rho_xy
混合物中的分布数,
148 @property
149 def n_distributions(self):
151 return self.pi_logits.shape[-1]
按温度调整
153 def set_temperature(self, temperature: float):
158 self.pi_logits /= temperature
160 self.sigma_x *= math.sqrt(temperature)
162 self.sigma_y *= math.sqrt(temperature)
164 def get_distribution(self):
Clamp,为了避免得到NaN
s
166 sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167 sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168 rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
获取手段
171 mean = torch.stack([self.mu_x, self.mu_y], -1)
获取协方差矩阵
173 cov = torch.stack([
174 sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175 rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176 ], -1)
177 cov = cov.view(*sigma_y.shape, 2, 2)
创建双变量正态分布。
📝 将scale_tril
矩阵作为[[a, 0], [b, c]]
何处进行矩阵会很有效。但为了简单起见,我们使用协方差矩阵。如果你想阅读更多关于双变量分布、它们的协方差矩阵和概率密度函数的信息,这是一个很好的资源。
188 multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
根据对数创建分类分布
191 cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
194 return cat_dist, multi_dist
197class EncoderRNN(Module):
204 def __init__(self, d_z: int, enc_hidden_size: int):
205 super().__init__()
创建双向 LSTM,将序列作为输入。
208 self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
去得到
210 self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
去得到
212 self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214 def forward(self, inputs: torch.Tensor, state=None):
双向LSTM的隐藏状态是向前方向的最后一个令牌的输出和相反方向的第一个令牌的输出串联,这正是我们想要的。
221 _, (hidden, cell) = self.lstm(inputs.float(), state)
状态具有形状[2, batch_size, hidden_size]
,其中第一个维度是方向。我们重新排列它来获得
225 hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
228 mu = self.mu_head(hidden)
230 sigma_hat = self.sigma_head(hidden)
232 sigma = torch.exp(sigma_hat / 2.)
样本
235 z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
238 return z, mu, sigma_hat
241class DecoderRNN(Module):
248 def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249 super().__init__()
LSTM 将作为输入
251 self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
LSTM 的初始状态为。init_state
是这个的线性变换
255 self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
该层为每个生成输出n_distributions
。每个分布需要六个参数
260 self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
这个头是为 logits 准备的
263 self.q_head = nn.Linear(dec_hidden_size, 3)
这是为了计算在哪里
266 self.q_log_softmax = nn.LogSoftmax(-1)
这些参数存储起来以备将来参考
269 self.n_distributions = n_distributions
270 self.dec_hidden_size = dec_hidden_size
272 def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
计算初始状态
274 if state is None:
276 h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
h
并c
有形状[batch_size, lstm_size]
。我们想要将它们塑造成形状,[1, batch_size, lstm_size]
因为这就是 LSTM 中使用的形状。
279 state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
运行 LSTM
282 outputs, state = self.lstm(x, state)
得到
285 q_logits = self.q_log_softmax(self.q_head(outputs))
得到。torch.split
将输出拆分为self.n_distribution
跨维度大小的 6 个张量2
。
291 pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292 torch.split(self.mixtures(outputs), self.n_distributions, 2)
305 dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306 torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
309 return dist, q_logits, state
312class ReconstructionLoss(Module):
317 def forward(self, mask: torch.Tensor, target: torch.Tensor,
318 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
获取和
320 pi, mix = dist.get_distribution()
target
具有形状,[seq_len, batch_size, 5]
其中最后一个维度是要素。我们想要得到 y 并从混合中的每个分布中获得概率。
xy
会有形状[seq_len, batch_size, n_distributions, 2]
327 xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
计算概率
333 probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
虽然probs
has (longest_seq_len
) 元素,但总和只被占用,因为其余的都是掩盖了。
可能感觉我们应该取总和然后除以而不是,但这将为较短序列中的单个预测提供更高的权重。当我们除以时,我们对每个预测赋予相等的权重
342 loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
345 loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
348 return loss_stroke + loss_pen
351class KLDivLoss(Module):
358 def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
360 return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
363class Sampler:
370 def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371 self.decoder = decoder
372 self.encoder = encoder
374 def sample(self, data: torch.Tensor, temperature: float):
376 longest_seq_len = len(data)
从编码器获取
379 z, _, _ = self.encoder(data)
序列起始行程为
382 s = data.new_tensor([0, 0, 1, 0, 0])
383 seq = [s]
初始解码器是None
。解码器会将其初始化为
386 state = None
我们不需要渐变
389 with torch.no_grad():
笔画样本
391 for i in range(longest_seq_len):
是解码器的输入
393 data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
从解码器获取、和下一个状态
396 dist, q_logits, state = self.decoder(data, z, state)
对笔画进行采样
398 s = self._sample_step(dist, q_logits, temperature)
将新笔划添加到笔画序列中
400 seq.append(s)
如果停止采样。这表示草绘已停止
402 if s[4] == 1:
403 break
创建笔画序列的 PyTorch 张量
406 seq = torch.stack(seq)
绘制笔画顺序
409 self.plot(seq)
411 @staticmethod
412 def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
设定采样温度。这是在课堂上实现的BivariateGaussianMixture
。
414 dist.set_temperature(temperature)
调整温度和
416 pi, mix = dist.get_distribution()
从混合物中使用的分布指数中的样品
418 idx = pi.sample()[0, 0]
使用对数概率创建类别分布q_logits
或
421 q = torch.distributions.Categorical(logits=q_logits / temperature)
样本来自
423 q_idx = q.sample()[0, 0]
从混合物中的正态分布中取样,然后选取索引的正态分布idx
426 xy = mix.sample()[0, 0, idx]
创建空描边
429 stroke = q_logits.new_zeros(5)
设置
431 stroke[:2] = xy
设置
433 stroke[q_idx + 2] = 1
435 return stroke
437 @staticmethod
438 def plot(seq: torch.Tensor):
取的累计总和得到
440 seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
创建一个格式为 numpy 的新 numpy 数组
442 seq[:, 2] = seq[:, 3]
443 seq = seq[:, 0:3].detach().cpu().numpy()
将数组拆分为 where is。即在笔从纸张上抬起的点分割笔划数组。这给出了笔画序列的列表。
448 strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
绘制每个笔画序列
450 for s in strokes:
451 plt.plot(s[:, 0], -s[:, 1])
不要显示轴
453 plt.axis('off')
显示剧情
455 plt.show()
458class Configs(TrainValidConfigs):
用于选择要运行实验的设备的设备配置
466 device: torch.device = DeviceConfigs()
468 encoder: EncoderRNN
469 decoder: DecoderRNN
470 optimizer: optim.Adam
471 sampler: Sampler
472
473 dataset_name: str
474 train_loader: DataLoader
475 valid_loader: DataLoader
476 train_dataset: StrokesDataset
477 valid_dataset: StrokesDataset
编码器和解码器尺寸
480 enc_hidden_size = 256
481 dec_hidden_size = 512
批量大小
484 batch_size = 100
中的要素数量
487 d_z = 128
混合物中的分布数,
489 n_distributions = 20
KL 背离损失的权重,
492 kl_div_loss_weight = 0.5
渐变剪切
494 grad_clip = 1.
采样温度
496 temperature = 0.4
筛选出长度大于
499 max_seq_length = 200
500
501 epochs = 100
502
503 kl_div_loss = KLDivLoss()
504 reconstruction_loss = ReconstructionLoss()
506 def init(self):
初始化编码器和解码器
508 self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509 self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
设置优化器。优化器类型和学习率等内容是可配置的
512 optimizer = OptimizerConfigs()
513 optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514 self.optimizer = optimizer
创建采样器
517 self.sampler = Sampler(self.encoder, self.decoder)
npz
文件路径是data/sketch/[DATASET NAME].npz
520 path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
加载那个 numpy 文件
522 dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
创建训练数据集
525 self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
创建验证数据集
527 self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
创建训练数据加载器
530 self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
创建验证数据加载器
532 self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
添加挂钩以监控 Tensorboard 上的图层输出
535 hook_model_outputs(self.mode, self.encoder, 'encoder')
536 hook_model_outputs(self.mode, self.decoder, 'decoder')
配置跟踪器以打印总训练/验证损失
539 tracker.set_scalar("loss.total.*", True)
540
541 self.state_modules = []
543 def step(self, batch: Any, batch_idx: BatchIndex):
544 self.encoder.train(self.mode.is_train)
545 self.decoder.train(self.mode.is_train)
将data
和移mask
至设备并交换序列和批次维度。data
将有形状[seq_len, batch_size, 5]
和mask
形状[seq_len, batch_size]
。
550 data = batch[0].to(self.device).transpose(0, 1)
551 mask = batch[1].to(self.device).transpose(0, 1)
在训练模式中增加步数
554 if self.mode.is_train:
555 tracker.add_global_step(len(data))
对笔画序列进行编码
558 with monit.section("encoder"):
获取、和
560 z, mu, sigma_hat = self.encoder(data)
解码分布和分布的混合
563 with monit.section("decoder"):
串联
565 z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566 inputs = torch.cat([data[:-1], z_stack], 2)
混合分布和
568 dist, q_logits, _ = self.decoder(inputs, z, None)
计算损失
571 with monit.section('loss'):
573 kl_loss = self.kl_div_loss(sigma_hat, mu)
575 reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
577 loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
追踪损失
580 tracker.add("loss.kl.", kl_loss)
581 tracker.add("loss.reconstruction.", reconstruction_loss)
582 tracker.add("loss.total.", loss)
只有当我们处于训练状态时
585 if self.mode.is_train:
运行优化器
587 with monit.section('optimize'):
设grad
为零
589 self.optimizer.zero_grad()
计算梯度
591 loss.backward()
记录模型参数和梯度
593 if batch_idx.is_last:
594 tracker.add(encoder=self.encoder, decoder=self.decoder)
剪辑渐变
596 nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
优化
599 self.optimizer.step()
600
601 tracker.save()
603 def sample(self):
从验证数据集中随机选择一个样本到编码器
605 data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
添加批量维度并将其移至设备
607 data = data.unsqueeze(1).to(self.device)
样本
609 self.sampler.sample(data, self.temperature)
612def main():
613 configs = Configs()
614 experiment.create(name="sketch_rnn")
传递配置字典
617 experiment.configs(configs, {
618 'optimizer.optimizer': 'Adam',
我们使用学习速率1e-3
是因为我们可以更快地看到结果。Paper 曾暗示过1e-4
。
621 'optimizer.learning_rate': 1e-3,
数据集的名称
623 'dataset_name': 'bicycle',
一个纪元内要在训练、验证和采样之间切换的内部迭代次数。
625 'inner_iterations': 10
626 })
627
628 with experiment.start():
运行实验
630 configs.run()
631
632
633if __name__ == "__main__":
634 main()