スケッチ RNN

これは、論文「スケッチ図面のニューラル表現」の注釈付きPyTorch実装です

Sketch RNN はシーケンス間の変分オートエンコーダーです。エンコーダーとデコーダーはどちらもリカレントニューラルネットワークモデルです。一連のストロークを予測することで、ストロークに基づいて簡単な図面を再構築する方法を学習します。デコーダーは、各ストロークをガウスの混合として予測します

データ取得

クイック、ドロー!からデータをダウンロードデータセットreadmeのSketch-RNN QuickDraw npz Datasetセクションにファイルをダウンロードするためのリンクがありますnpz data/sketch ダウンロードしたファイルをフォルダに配置します。bicycle このコードはデータセットを使用するように構成されています。これは設定で変更できます。

謝辞

アレクシス・デイヴィッド・ジャックによるPyTorch Sketch RNNNプロジェクトの協力を得ました

32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex

データセット

このクラスは、データをロードして前処理します。

50class StrokesDataset(Dataset):

dataset seq_len, 3 という形状のゴツゴツした配列のリストです。これは一連のストロークで、各ストロークは3つの整数で表されます。最初の 2 つは x と y に沿った変位 (,) で、最後の整数はペンの状態 (紙に触れている場合とそうでない場合) を表します

57    def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67        data = []

各シーケンスとフィルターを繰り返し処理します。

69        for seq in dataset:

ストロークのシーケンスの長さが範囲内にある場合はフィルタリングします

71            if 10 < len(seq) <= max_seq_length:

クランプ

73                seq = np.minimum(seq, 1000)
74                seq = np.maximum(seq, -1000)

浮動小数点配列に変換して追加 data

76                seq = np.array(seq, dtype=np.float32)
77                data.append(seq)

次に、(,) を組み合わせた標準偏差であるスケーリング係数を計算します。論文によると、平均はとにかくそれに近いので、簡単にするために平均を調整していないという

83        if scale is None:
84            scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85        self.scale = scale

すべてのシーケンスの中で一番長いシーケンス長を取得

88        longest_seq_len = max([len(seq) for seq in data])

PyTorchデータ配列を初期化するには、シーケンスの開始 (sos) とシーケンスの終了 (eos) の2つの追加ステップが必要です。各ステップはベクトルですそのうちの1つだけがそうで、他はそうです。ペンダウン、ペンアップシーケンスの終了の順序で表されます次のステップでペンが紙に触れた場合です。次のステップでペンが紙に触れない場合です。それがドローイングの終わりだとしたら

98        self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)

マスク配列は、次のステップを取り込んで予測するデコーダーの出力用なので、追加のステップは1つだけ必要です。data[:-1]

101        self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103        for i, seq in enumerate(data):
104            seq = torch.from_numpy(seq)
105            len_seq = len(seq)

スケールとセット

107            self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale

109            self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]

111            self.data[i, 1:len_seq + 1, 3] = seq[:, 2]

113            self.data[i, len_seq + 1:, 4] = 1

マスクはシーケンスの終わりまでオンです

115            self.mask[i, :len_seq + 1] = 1

シーケンスの開始は

118        self.data[:, 0, 2] = 1

データセットのサイズ

120    def __len__(self):
122        return len(self.data)

サンプルを入手

124    def __getitem__(self, idx: int):
126        return self.data[idx], self.mask[idx]

2 変量ガウス混合物

混合物はおよびで表されます。このクラスは温度を調整し、パラメーターからカテゴリ分布とガウス分布を作成します

129class BivariateGaussianMixture:
139    def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140                 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141        self.pi_logits = pi_logits
142        self.mu_x = mu_x
143        self.mu_y = mu_y
144        self.sigma_x = sigma_x
145        self.sigma_y = sigma_y
146        self.rho_xy = rho_xy

混合物中の分布の数、

148    @property
149    def n_distributions(self):
151        return self.pi_logits.shape[-1]

温度による調整

153    def set_temperature(self, temperature: float):

158        self.pi_logits /= temperature

160        self.sigma_x *= math.sqrt(temperature)

162        self.sigma_y *= math.sqrt(temperature)
164    def get_distribution(self):

クランプ NaN Sが入らないように

166        sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167        sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168        rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)

手段を取得

171        mean = torch.stack([self.mu_x, self.mu_y], -1)

共分散行列を取得

173        cov = torch.stack([
174            sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175            rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176        ], -1)
177        cov = cov.view(*sigma_y.shape, 2, 2)

2 変量正規分布を作成します。

📝 where scale_tril をマトリックス化すると効率的です[[a, 0], [b, c]] ただし、わかりやすくするために、共分散マトリックスを使用します。これは、二変量分布、それらの共分散行列、および確率密度関数について詳しく知りたい場合に役立つリソースです

188        multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)

ロジットからカテゴリ分布を作成

191        cat_dist = torch.distributions.Categorical(logits=self.pi_logits)

194        return cat_dist, multi_dist

エンコーダモジュール

これは双方向の LSTM で構成されています。

197class EncoderRNN(Module):
204    def __init__(self, d_z: int, enc_hidden_size: int):
205        super().__init__()

のシーケンスを入力として双方向 LSTM を作成します。

208        self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)

さっそくゲットしよう

210        self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)

さっそくゲットしよう

212        self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214    def forward(self, inputs: torch.Tensor, state=None):

双方向 LSTM の隠れ状態は、最後のトークンの出力を順方向に、最初のトークンの出力を逆方向に連結したものです。これが私たちが望むことです。

221        _, (hidden, cell) = self.lstm(inputs.float(), state)

状態には形があり[2, batch_size, hidden_size] 、最初の次元が方向です。それを再配置して

225        hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')

228        mu = self.mu_head(hidden)

230        sigma_hat = self.sigma_head(hidden)

232        sigma = torch.exp(sigma_hat / 2.)

[サンプル]

235        z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))

238        return z, mu, sigma_hat

デコーダモジュール

これはLSTMで構成されています

241class DecoderRNN(Module):
248    def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249        super().__init__()

LSTM は入力として取ります

251        self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)

LSTM の初期状態はです。init_state これの線形変換は

255        self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)

このレイヤーは、それぞれの出力を生成しますn_distributions 。各分布には 6 つのパラメータが必要です

260        self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)

このヘッドはロジット用です

263        self.q_head = nn.Linear(dec_hidden_size, 3)

これは場所を計算するためです

266        self.q_log_softmax = nn.LogSoftmax(-1)

これらのパラメータは、後で参照できるように保存されます

269        self.n_distributions = n_distributions
270        self.dec_hidden_size = dec_hidden_size
272    def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):

初期状態の計算

274        if state is None:

276            h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)

h c そして形があります[batch_size, lstm_size][1, batch_size, lstm_size] LSTMで使われている形なので、形を整えたいのです

279            state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())

LSTM を実行してください

282        outputs, state = self.lstm(x, state)

取得

285        q_logits = self.q_log_softmax(self.q_head(outputs))

取得torch.split self.n_distribution 出力を次元全体のサイズの6つのテンソルに分割します

2
291        pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292            torch.split(self.mixtures(outputs), self.n_distributions, 2)

2 変量ガウス混合物の作成と場所と

は、混合から分布を選択するカテゴリ確率です。

305        dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306                                        torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))

309        return dist, q_logits, state

復興損失

312class ReconstructionLoss(Module):
317    def forward(self, mask: torch.Tensor, target: torch.Tensor,
318                 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):

取得して

320        pi, mix = dist.get_distribution()

target [seq_len, batch_size, 5] 最後の次元がフィーチャであるような形をしていますy を取得して、混合内の各分布から確率を求めたいと思います

xy 形になります [seq_len, batch_size, n_distributions, 2]

327        xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)

確率の計算

333        probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)

probs 要素には (longest_seq_len ) がありますが、残りはマスクされているので、合計が計算されるだけです。

合計を取って除算しないで割る必要があるように感じるかもしれませんが、これによって、短いシーケンスで個々の予測をより重要視できるようになります。で割ると、各予測に同じ重みを与えます

342        loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))

345        loss_pen = -torch.mean(target[:, :, 2:] * q_logits)

348        return loss_stroke + loss_pen

KL-ダイバージェンスロス

これは、特定の正規分布との KL ダイバージェンスを計算します。

351class KLDivLoss(Module):
358    def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):

360        return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))

サンプラー

デコーダーからスケッチをサンプリングしてプロットします。

363class Sampler:
370    def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371        self.decoder = decoder
372        self.encoder = encoder
374    def sample(self, data: torch.Tensor, temperature: float):

376        longest_seq_len = len(data)

エンコーダから取得

379        z, _, _ = self.encoder(data)

シーケンスの開始ストロークは

382        s = data.new_tensor([0, 0, 1, 0, 0])
383        seq = [s]

初期デコーダーはNone .デコーダーはそれを次のように初期化します

386        state = None

グラデーションはいらない

389        with torch.no_grad():

サンプルストローク

391            for i in range(longest_seq_len):

デコーダーへの入力です

393                data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)

デコーダーから、、および次の状態を取得

396                dist, q_logits, state = self.decoder(data, z, state)

ストロークをサンプリングする

398                s = self._sample_step(dist, q_logits, temperature)

新しいストロークをストロークのシーケンスに追加します

400                seq.append(s)

もしそうなら、サンプリングを停止してください。これはスケッチが停止したことを示します

402                if s[4] == 1:
403                    break

一連のストロークの PyTorch テンソルを作成します。

406        seq = torch.stack(seq)

ストロークのシーケンスをプロット

409        self.plot(seq)
411    @staticmethod
412    def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):

サンプリングの温度を設定します。これはクラスで実装されていますBivariateGaussianMixture

414        dist.set_temperature(temperature)

温度調整して

416        pi, mix = dist.get_distribution()

分布の指標からサンプルを採取して、混合物から使用する

418        idx = pi.sample()[0, 0]

対数確率によるカテゴリ分布の作成または q_logits

421        q = torch.distributions.Categorical(logits=q_logits / temperature)

からのサンプル

423        q_idx = q.sample()[0, 0]

混合物の正規分布からサンプリングし、次の式でインデックス付けされた分布を選択します。idx

426        xy = mix.sample()[0, 0, idx]

空のストロークを作成

429        stroke = q_logits.new_zeros(5)

セット

431        stroke[:2] = xy

セット

433        stroke[q_idx + 2] = 1

435        return stroke
437    @staticmethod
438    def plot(seq: torch.Tensor):

の累積和を取ると、

440        seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)

次の形式の新しいnumpy配列を作成します

442        seq[:, 2] = seq[:, 3]
443        seq = seq[:, 0:3].detach().cpu().numpy()

配列をあるポイントで分割します。つまり、ペンを紙から持ち上げるポイントでストロークの配列を分割します。これにより、ストロークのシーケンスのリストが表示されます

448        strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)

ストロークの各シーケンスをプロット

450        for s in strokes:
451            plt.plot(s[:, 0], -s[:, 1])

座標軸を表示しないでください

453        plt.axis('off')

プロットを表示

455        plt.show()

コンフィギュレーション

これらはデフォルトの設定で、を渡すことで後で調整できますdict

458class Configs(TrainValidConfigs):

実験を実行するデバイスを選択するためのデバイス構成

466    device: torch.device = DeviceConfigs()

468    encoder: EncoderRNN
469    decoder: DecoderRNN
470    optimizer: optim.Adam
471    sampler: Sampler
472
473    dataset_name: str
474    train_loader: DataLoader
475    valid_loader: DataLoader
476    train_dataset: StrokesDataset
477    valid_dataset: StrokesDataset

エンコーダとデコーダのサイズ

480    enc_hidden_size = 256
481    dec_hidden_size = 512

バッチサイズ

484    batch_size = 100

のフィーチャ数

487    d_z = 128

混合物中の分布の数、

489    n_distributions = 20

KLダイバージェンスロスの重量、

492    kl_div_loss_weight = 0.5

グラデーションクリッピング

494    grad_clip = 1.

サンプリング温度

496    temperature = 0.4

より長いストロークシーケンスを除外

499    max_seq_length = 200
500
501    epochs = 100
502
503    kl_div_loss = KLDivLoss()
504    reconstruction_loss = ReconstructionLoss()
506    def init(self):

エンコーダとデコーダを初期化

508        self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509        self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)

オプティマイザを設定します。オプティマイザーのタイプや学習率などは設定可能です

512        optimizer = OptimizerConfigs()
513        optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514        self.optimizer = optimizer

サンプラーの作成

517        self.sampler = Sampler(self.encoder, self.decoder)

npz ファイルパスは data/sketch/[DATASET NAME].npz

520        path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'

numpy ファイルを読み込む

522        dataset = np.load(str(path), encoding='latin1', allow_pickle=True)

トレーニングデータセットの作成

525        self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)

検証データセットの作成

527        self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)

トレーニングデータローダーの作成

530        self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)

検証データローダーの作成

532        self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)

Tensorboardのレイヤー出力を監視するフックを追加

535        hook_model_outputs(self.mode, self.encoder, 'encoder')
536        hook_model_outputs(self.mode, self.decoder, 'decoder')

トレイン/検証ロスの合計を出力するようにトラッカーを設定

539        tracker.set_scalar("loss.total.*", True)
540
541        self.state_modules = []
543    def step(self, batch: Any, batch_idx: BatchIndex):
544        self.encoder.train(self.mode.is_train)
545        self.decoder.train(self.mode.is_train)

mask をデバイスに移動しdata 、シーケンスとバッチディメンションを入れ替えます。data [seq_len, batch_size, 5] 形があって、mask 形があるでしょう[seq_len, batch_size]

550        data = batch[0].to(self.device).transpose(0, 1)
551        mask = batch[1].to(self.device).transpose(0, 1)

トレーニングモードでのインクリメントステップ

554        if self.mode.is_train:
555            tracker.add_global_step(len(data))

ストロークのシーケンスをエンコード

558        with monit.section("encoder"):

取得、および

560            z, mu, sigma_hat = self.encoder(data)

複数のディストリビューションをデコードし、

563        with monit.section("decoder"):

連結

565            z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566            inputs = torch.cat([data[:-1], z_stack], 2)

複数のディストリビューションを組み合わせて

568            dist, q_logits, _ = self.decoder(inputs, z, None)

損失の計算

571        with monit.section('loss'):

573            kl_loss = self.kl_div_loss(sigma_hat, mu)

575            reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)

577            loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss

トラックロス

580            tracker.add("loss.kl.", kl_loss)
581            tracker.add("loss.reconstruction.", reconstruction_loss)
582            tracker.add("loss.total.", loss)

トレーニング状態の場合のみ

585        if self.mode.is_train:

オプティマイザを実行

587            with monit.section('optimize'):

0 grad に設定

589                self.optimizer.zero_grad()

勾配の計算

591                loss.backward()

モデルパラメーターと勾配をログに記録する

593                if batch_idx.is_last:
594                    tracker.add(encoder=self.encoder, decoder=self.decoder)

クリップグラデーション

596                nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597                nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)

最適化

599                self.optimizer.step()
600
601        tracker.save()
603    def sample(self):

検証データセットからエンコーダーにサンプルをランダムに選択

605        data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]

バッチディメンションを追加してデバイスに移動

607        data = data.unsqueeze(1).to(self.device)

[サンプル]

609        self.sampler.sample(data, self.temperature)
612def main():
613    configs = Configs()
614    experiment.create(name="sketch_rnn")

設定の辞書を渡す

617    experiment.configs(configs, {
618        'optimizer.optimizer': 'Adam',

学習率をに設定しているのは、1e-3 より早く結果を確認できるからです。論文は提案していた1e-4

621        'optimizer.learning_rate': 1e-3,

データセットの名前

623        'dataset_name': 'bicycle',

トレーニング、検証、サンプリングを切り替えるためのエポック内の内部反復回数。

625        'inner_iterations': 10
626    })
627
628    with experiment.start():

実験を実行する

630        configs.run()
631
632
633if __name__ == "__main__":
634    main()