素描 RNN

这是论文《素描绘画的神经表示》的带注释的 PyTorch 实现。

Sketch RNN 是一种序列到序列的变分自动编码器。编码器和解码器都是循环神经网络模型。它通过预测一系列笔画来学习重建基于笔触的简单绘画。解码器将每个笔划预测为高斯笔画的混合。

获取数据

Quick,Draw! 下载数据数据集。自述npz 文件的 Sk etch-RNN QuickDraw 数据集部分有一个下载文件的链接。将下载的npz 文件放在data/sketch 文件夹中。此代码配置为使用bicycle 数据集。你可以在配置中更改此设置。

致谢

从 A lexis David JacqPyTorch Sketch RNN 项目

32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex

数据集

此类加载和预处理数据。

50class StrokesDataset(Dataset):

dataset 是形状为 seq_len 的 numpy 数组的列表,3。它是一个笔画序列,每个笔画由 3 个整数表示。前两个是沿 x 和 y (,) 的位移,最后一个整数表示笔的状态,如果它接触纸张和否则。

57    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67        data = []

我们遍历每个序列并进行过滤

69        for seq in dataset:

筛选笔画序列的长度是否在我们的范围内

71            if 10 < len(seq) <= max_seq_length:

Clamp

73                seq = np.minimum(seq, 1000)
74                seq = np.maximum(seq, -1000)

转换为浮点数组并添加到data

76                seq = np.array(seq, dtype=np.float32)
77                data.append(seq)

然后,我们计算缩放系数,即 (,) 组合的标准差。论文指出,为了简单起见,均值没有进行调整,因为均值无论如何都接近

83        if scale is None:
84            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85        self.scale = scale

获取所有序列中最长的序列长度

88        longest_seq_len = max([len(seq) for seq in data])

我们在初始化 PyTorch 数据数组时添加了两个额外的步骤,分别是序列开始 (sos) 和序列结束 (eos)。每一步都是一个向量。只有一个,其他的是。它们按该顺序表示笔向下、向序列结尾如果笔在下一步中碰到纸张。如果在下一步中笔没有碰到纸张。如果是绘图的结尾。

98        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

掩码数组只需要一个额外的步骤,因为它是用于解码器的输出,解码器接收data[:-1] 并预测下一步。

101        self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103        for i, seq in enumerate(data):
104            seq = torch.from_numpy(seq)
105            len_seq = len(seq)

缩放和设置

107            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

109            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

111            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

113            self.data[i, len_seq + 1:, 4] = 1

遮罩开启直到序列结束

115            self.mask[i, :len_seq + 1] = 1

序列开头是

118        self.data[:, 0, 2] = 1

数据集的大小

120    def __len__(self):
122        return len(self.data)

获取样品

124    def __getitem__(self, idx: int):
126        return self.data[idx], self.mask[idx]

双变量高斯混合

混合物用和表示。此类调整温度,并根据参数创建分类分布和高斯分布。

129class BivariateGaussianMixture:
139    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141        self.pi_logits = pi_logits
142        self.mu_x = mu_x
143        self.mu_y = mu_y
144        self.sigma_x = sigma_x
145        self.sigma_y = sigma_y
146        self.rho_xy = rho_xy

混合物中的分布数,

148    @property
149    def n_distributions(self):
151        return self.pi_logits.shape[-1]

按温度调整

153    def set_temperature(self, temperature: float):

158        self.pi_logits /= temperature

160        self.sigma_x *= math.sqrt(temperature)

162        self.sigma_y *= math.sqrt(temperature)
164    def get_distribution(self):

Clamp为了避免得到NaN s

166        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

获取手段

171        mean = torch.stack([self.mu_x, self.mu_y], -1)

获取协方差矩阵

173        cov = torch.stack([
174            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176        ], -1)
177        cov = cov.view(*sigma_y.shape, 2, 2)

创建双变量正态分布。

📝 将scale_tril 矩阵作为[[a, 0], [b, c]] 何处进行矩阵会很有效。但为了简单起见,我们使用协方差矩阵。如果你想阅读更多关于双变量分布、它们的协方差矩阵和概率密度函数的信息,这是一个很好的资源

188        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

根据对数创建分类分布

191        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

194        return cat_dist, multi_dist

编码器模块

这包括一个双向 LSTM

197class EncoderRNN(Module):
204    def __init__(self, d_z: int, enc_hidden_size: int):
205        super().__init__()

创建双向 LSTM,将序列作为输入。

208        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

去得到

210        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

去得到

212        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214    def forward(self, inputs: torch.Tensor, state=None):

双向LSTM的隐藏状态是向前方向的最后一个令牌的输出和相反方向的第一个令牌的输出串联,这正是我们想要的。

221        _, (hidden, cell) = self.lstm(inputs.float(), state)

状态具有形状[2, batch_size, hidden_size] ,其中第一个维度是方向。我们重新排列它来获得

225        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

228        mu = self.mu_head(hidden)

230        sigma_hat = self.sigma_head(hidden)

232        sigma = torch.exp(sigma_hat / 2.)

样本

235        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

238        return z, mu, sigma_hat

解码器模块

它由一个 LSTM 组成

241class DecoderRNN(Module):
248    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249        super().__init__()

LSTM 将作为输入

251        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

LSTM 的初始状态为init_state 是这个的线性变换

255        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

该层为每个生成输出n_distributions 。每个分布需要六个参数

260        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

这个头是为 logits 准备的

263        self.q_head = nn.Linear(dec_hidden_size, 3)

这是为了计算在哪里

266        self.q_log_softmax = nn.LogSoftmax(-1)

这些参数存储起来以备将来参考

269        self.n_distributions = n_distributions
270        self.dec_hidden_size = dec_hidden_size
272    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

计算初始状态

274        if state is None:

276            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

hc 有形状[batch_size, lstm_size] 。我们想要将它们塑造成形状,[1, batch_size, lstm_size] 因为这就是 LSTM 中使用的形状。

279            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

运行 LSTM

282        outputs, state = self.lstm(x, state)

得到

285        q_logits = self.q_log_softmax(self.q_head(outputs))

得到torch.split 将输出拆分为self.n_distribution 跨维度大小的 6 个张量2

291        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292            torch.split(self.mixtures(outputs), self.n_distributions, 2)

创建双变量高斯混合在哪里

是从混合物中选择分布的绝对概率

305        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

309        return dist, q_logits, state

重建损失

312class ReconstructionLoss(Module):
317    def forward(self, mask: torch.Tensor, target: torch.Tensor,
318                 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

获取

320        pi, mix = dist.get_distribution()

target 具有形状,[seq_len, batch_size, 5] 其中最后一个维度是要素。我们想要得到 y 并从混合中的每个分布中获得概率

xy 会有形状[seq_len, batch_size, n_distributions, 2]

327        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

计算概率

333        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

虽然probs has (longest_seq_len ) 元素,但总和只被占用,因为其余的都是掩盖了。

可能感觉我们应该取总和然后除以而不是,但这将为较短序列中的单个预测提供更高的权重。当我们除以时,我们对每个预测赋予相等的权重

342        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

345        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

348        return loss_stroke + loss_pen

KL-背离损失

这将计算给定正态分布与

351class KLDivLoss(Module):
358    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

360        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))

采样器

这将从解码器中采样草图并绘制出来

363class Sampler:
370    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371        self.decoder = decoder
372        self.encoder = encoder
374    def sample(self, data: torch.Tensor, temperature: float):

376        longest_seq_len = len(data)

从编码器获取

379        z, _, _ = self.encoder(data)

序列起始行程为

382        s = data.new_tensor([0, 0, 1, 0, 0])
383        seq = [s]

初始解码器是None 。解码器会将其初始化为

386        state = None

我们不需要渐变

389        with torch.no_grad():

笔画样本

391            for i in range(longest_seq_len):

是解码器的输入

393                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

从解码器获取和下一个状态

396                dist, q_logits, state = self.decoder(data, z, state)

对笔画进行采样

398                s = self._sample_step(dist, q_logits, temperature)

将新笔划添加到笔画序列中

400                seq.append(s)

如果停止采样。这表示草绘已停止

402                if s[4] == 1:
403                    break

创建笔画序列的 PyTorch 张量

406        seq = torch.stack(seq)

绘制笔画顺序

409        self.plot(seq)
411    @staticmethod
412    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

设定采样温度。这是在课堂上实现的BivariateGaussianMixture

414        dist.set_temperature(temperature)

调整温度

416        pi, mix = dist.get_distribution()

从混合物中使用的分布指数中的样品

418        idx = pi.sample()[0, 0]

使用对数概率创建类别分布q_logits

421        q = torch.distributions.Categorical(logits=q_logits / temperature)

样本来自

423        q_idx = q.sample()[0, 0]

从混合物中的正态分布中取样,然后选取索引的正态分布idx

426        xy = mix.sample()[0, 0, idx]

创建空描边

429        stroke = q_logits.new_zeros(5)

设置

431        stroke[:2] = xy

设置

433        stroke[q_idx + 2] = 1

435        return stroke
437    @staticmethod
438    def plot(seq: torch.Tensor):

取的累计总和得

440        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

创建一个格式为 numpy 的新 numpy 数组

442        seq[:, 2] = seq[:, 3]
443        seq = seq[:, 0:3].detach().cpu().numpy()

将数组拆分为 where is。即在笔从纸张上抬起的点分割笔划数组。这给出了笔画序列的列表。

448        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)

绘制每个笔画序列

450        for s in strokes:
451            plt.plot(s[:, 0], -s[:, 1])

不要显示轴

453        plt.axis('off')

显示剧情

455        plt.show()

配置

这些是默认配置,稍后可以通过传入 a 进行调整dict

458class Configs(TrainValidConfigs):

用于选择要运行实验的设备的设备配置

466    device: torch.device = DeviceConfigs()

468    encoder: EncoderRNN
469    decoder: DecoderRNN
470    optimizer: optim.Adam
471    sampler: Sampler
472
473    dataset_name: str
474    train_loader: DataLoader
475    valid_loader: DataLoader
476    train_dataset: StrokesDataset
477    valid_dataset: StrokesDataset

编码器和解码器尺寸

480    enc_hidden_size = 256
481    dec_hidden_size = 512

批量大小

484    batch_size = 100

中的要素数量

487    d_z = 128

混合物中的分布数,

489    n_distributions = 20

KL 背离损失的权重,

492    kl_div_loss_weight = 0.5

渐变剪切

494    grad_clip = 1.

采样温度

496    temperature = 0.4

筛选出长度大于

499    max_seq_length = 200
500
501    epochs = 100
502
503    kl_div_loss = KLDivLoss()
504    reconstruction_loss = ReconstructionLoss()
506    def init(self):

初始化编码器和解码器

508        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

设置优化器。优化器类型和学习率等内容是可配置的

512        optimizer = OptimizerConfigs()
513        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514        self.optimizer = optimizer

创建采样器

517        self.sampler = Sampler(self.encoder, self.decoder)

npz 文件路径是data/sketch/[DATASET NAME].npz

520        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

加载那个 numpy 文件

522        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

创建训练数据集

525        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

创建验证数据集

527        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

创建训练数据加载器

530        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

创建验证数据加载器

532        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

添加挂钩以监控 Tensorboard 上的图层输出

535        hook_model_outputs(self.mode, self.encoder, 'encoder')
536        hook_model_outputs(self.mode, self.decoder, 'decoder')

配置跟踪器以打印总训练/验证损失

539        tracker.set_scalar("loss.total.*", True)
540
541        self.state_modules = []
543    def step(self, batch: Any, batch_idx: BatchIndex):
544        self.encoder.train(self.mode.is_train)
545        self.decoder.train(self.mode.is_train)

data 和移mask 至设备并交换序列和批次维度。data 将有形状[seq_len, batch_size, 5]mask 形状[seq_len, batch_size]

550        data = batch[0].to(self.device).transpose(0, 1)
551        mask = batch[1].to(self.device).transpose(0, 1)

在训练模式中增加步数

554        if self.mode.is_train:
555            tracker.add_global_step(len(data))

对笔画序列进行编码

558        with monit.section("encoder"):

获取、和

560            z, mu, sigma_hat = self.encoder(data)

解码分布和分布的混合

563        with monit.section("decoder"):

串联

565            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566            inputs = torch.cat([data[:-1], z_stack], 2)

混合分布和

568            dist, q_logits, _ = self.decoder(inputs, z, None)

计算损失

571        with monit.section('loss'):

573            kl_loss = self.kl_div_loss(sigma_hat, mu)

575            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

577            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

追踪损失

580            tracker.add("loss.kl.", kl_loss)
581            tracker.add("loss.reconstruction.", reconstruction_loss)
582            tracker.add("loss.total.", loss)

只有当我们处于训练状态时

585        if self.mode.is_train:

运行优化器

587            with monit.section('optimize'):

grad 为零

589                self.optimizer.zero_grad()

计算梯度

591                loss.backward()

记录模型参数和梯度

593                if batch_idx.is_last:
594                    tracker.add(encoder=self.encoder, decoder=self.decoder)

剪辑渐变

596                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

优化

599                self.optimizer.step()
600
601        tracker.save()
603    def sample(self):

从验证数据集中随机选择一个样本到编码器

605        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

添加批量维度并将其移至设备

607        data = data.unsqueeze(1).to(self.device)

样本

609        self.sampler.sample(data, self.temperature)
612def main():
613    configs = Configs()
614    experiment.create(name="sketch_rnn")

传递配置字典

617    experiment.configs(configs, {
618        'optimizer.optimizer': 'Adam',

我们使用学习速率1e-3 是因为我们可以更快地看到结果。Paper 曾暗示过1e-4

621        'optimizer.learning_rate': 1e-3,

数据集的名称

623        'dataset_name': 'bicycle',

一个纪元内要在训练、验证和采样之间切换的内部迭代次数。

625        'inner_iterations': 10
626    })
627
628    with experiment.start():

运行实验

630        configs.run()
631
632
633if __name__ == "__main__":
634    main()