これは、論文「スケッチ図面のニューラル表現」の注釈付きPyTorch実装です。
Sketch RNN はシーケンス間の変分オートエンコーダーです。エンコーダーとデコーダーはどちらもリカレントニューラルネットワークモデルです。一連のストロークを予測することで、ストロークに基づいて簡単な図面を再構築する方法を学習します。デコーダーは、各ストロークをガウスの混合として予測します
。クイック、ドロー!からデータをダウンロードデータセット。readmeのSketch-RNN QuickDraw npz
Datasetセクションにファイルをダウンロードするためのリンクがあります。npz
data/sketch
ダウンロードしたファイルをフォルダに配置します。bicycle
このコードはデータセットを使用するように構成されています。これは設定で変更できます。
32import math
33from typing import Optional, Tuple, Any
34
35import numpy as np
36import torch
37import torch.nn as nn
38from matplotlib import pyplot as plt
39from torch import optim
40from torch.utils.data import Dataset, DataLoader
41
42import einops
43from labml import lab, experiment, tracker, monit
44from labml_helpers.device import DeviceConfigs
45from labml_helpers.module import Module
46from labml_helpers.optimizer import OptimizerConfigs
47from labml_helpers.train_valid import TrainValidConfigs, hook_model_outputs, BatchIndex
50class StrokesDataset(Dataset):
dataset
seq_len, 3 という形状のゴツゴツした配列のリストです。これは一連のストロークで、各ストロークは3つの整数で表されます。最初の 2 つは x と y に沿った変位 (,) で、最後の整数はペンの状態 (紙に触れている場合とそうでない場合) を表します
57 def __init__(self, dataset: np.array, max_seq_length: int, scale: Optional[float] = None):
67 data = []
各シーケンスとフィルターを繰り返し処理します。
69 for seq in dataset:
ストロークのシーケンスの長さが範囲内にある場合はフィルタリングします
71 if 10 < len(seq) <= max_seq_length:
クランプ、
73 seq = np.minimum(seq, 1000)
74 seq = np.maximum(seq, -1000)
浮動小数点配列に変換して追加 data
76 seq = np.array(seq, dtype=np.float32)
77 data.append(seq)
次に、(,) を組み合わせた標準偏差であるスケーリング係数を計算します。論文によると、平均はとにかくそれに近いので、簡単にするために平均を調整していないという
。83 if scale is None:
84 scale = np.std(np.concatenate([np.ravel(s[:, 0:2]) for s in data]))
85 self.scale = scale
すべてのシーケンスの中で一番長いシーケンス長を取得
88 longest_seq_len = max([len(seq) for seq in data])
PyTorchデータ配列を初期化するには、シーケンスの開始 (sos) とシーケンスの終了 (eos) の2つの追加ステップが必要です。各ステップはベクトルです。そのうちの1つだけがそうで、他はそうです。ペンダウン、ペンアップ、シーケンスの終了の順序で表されます。次のステップでペンが紙に触れた場合です。次のステップでペンが紙に触れない場合です。それがドローイングの終わりだとしたら
98 self.data = torch.zeros(len(data), longest_seq_len + 2, 5, dtype=torch.float)
マスク配列は、次のステップを取り込んで予測するデコーダーの出力用なので、追加のステップは1つだけ必要です。data[:-1]
101 self.mask = torch.zeros(len(data), longest_seq_len + 1)
102
103 for i, seq in enumerate(data):
104 seq = torch.from_numpy(seq)
105 len_seq = len(seq)
スケールとセット
107 self.data[i, 1:len_seq + 1, :2] = seq[:, :2] / scale
109 self.data[i, 1:len_seq + 1, 2] = 1 - seq[:, 2]
111 self.data[i, 1:len_seq + 1, 3] = seq[:, 2]
113 self.data[i, len_seq + 1:, 4] = 1
マスクはシーケンスの終わりまでオンです
115 self.mask[i, :len_seq + 1] = 1
シーケンスの開始は
118 self.data[:, 0, 2] = 1
データセットのサイズ
120 def __len__(self):
122 return len(self.data)
サンプルを入手
124 def __getitem__(self, idx: int):
126 return self.data[idx], self.mask[idx]
129class BivariateGaussianMixture:
139 def __init__(self, pi_logits: torch.Tensor, mu_x: torch.Tensor, mu_y: torch.Tensor,
140 sigma_x: torch.Tensor, sigma_y: torch.Tensor, rho_xy: torch.Tensor):
141 self.pi_logits = pi_logits
142 self.mu_x = mu_x
143 self.mu_y = mu_y
144 self.sigma_x = sigma_x
145 self.sigma_y = sigma_y
146 self.rho_xy = rho_xy
混合物中の分布の数、
148 @property
149 def n_distributions(self):
151 return self.pi_logits.shape[-1]
温度による調整
153 def set_temperature(self, temperature: float):
158 self.pi_logits /= temperature
160 self.sigma_x *= math.sqrt(temperature)
162 self.sigma_y *= math.sqrt(temperature)
164 def get_distribution(self):
クランプ、 NaN
Sが入らないように
166 sigma_x = torch.clamp_min(self.sigma_x, 1e-5)
167 sigma_y = torch.clamp_min(self.sigma_y, 1e-5)
168 rho_xy = torch.clamp(self.rho_xy, -1 + 1e-5, 1 - 1e-5)
手段を取得
171 mean = torch.stack([self.mu_x, self.mu_y], -1)
共分散行列を取得
173 cov = torch.stack([
174 sigma_x * sigma_x, rho_xy * sigma_x * sigma_y,
175 rho_xy * sigma_x * sigma_y, sigma_y * sigma_y
176 ], -1)
177 cov = cov.view(*sigma_y.shape, 2, 2)
2 変量正規分布を作成します。
📝 where scale_tril
をマトリックス化すると効率的です。[[a, 0], [b, c]]
ただし、わかりやすくするために、共分散マトリックスを使用します。これは、二変量分布、それらの共分散行列、および確率密度関数について詳しく知りたい場合に役立つリソースです
188 multi_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
ロジットからカテゴリ分布を作成
191 cat_dist = torch.distributions.Categorical(logits=self.pi_logits)
194 return cat_dist, multi_dist
197class EncoderRNN(Module):
204 def __init__(self, d_z: int, enc_hidden_size: int):
205 super().__init__()
のシーケンスを入力として双方向 LSTM を作成します。
208 self.lstm = nn.LSTM(5, enc_hidden_size, bidirectional=True)
さっそくゲットしよう
210 self.mu_head = nn.Linear(2 * enc_hidden_size, d_z)
さっそくゲットしよう
212 self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
214 def forward(self, inputs: torch.Tensor, state=None):
双方向 LSTM の隠れ状態は、最後のトークンの出力を順方向に、最初のトークンの出力を逆方向に連結したものです。これが私たちが望むことです。
221 _, (hidden, cell) = self.lstm(inputs.float(), state)
状態には形があり[2, batch_size, hidden_size]
、最初の次元が方向です。それを再配置して
225 hidden = einops.rearrange(hidden, 'fb b h -> b (fb h)')
228 mu = self.mu_head(hidden)
230 sigma_hat = self.sigma_head(hidden)
232 sigma = torch.exp(sigma_hat / 2.)
[サンプル]
235 z = mu + sigma * torch.normal(mu.new_zeros(mu.shape), mu.new_ones(mu.shape))
238 return z, mu, sigma_hat
241class DecoderRNN(Module):
248 def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
249 super().__init__()
LSTM は入力として取ります
251 self.lstm = nn.LSTM(d_z + 5, dec_hidden_size)
LSTM の初期状態はです。init_state
これの線形変換は
255 self.init_state = nn.Linear(d_z, 2 * dec_hidden_size)
260 self.mixtures = nn.Linear(dec_hidden_size, 6 * n_distributions)
このヘッドはロジット用です
263 self.q_head = nn.Linear(dec_hidden_size, 3)
これは場所を計算するためです
266 self.q_log_softmax = nn.LogSoftmax(-1)
これらのパラメータは、後で参照できるように保存されます
269 self.n_distributions = n_distributions
270 self.dec_hidden_size = dec_hidden_size
272 def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
初期状態の計算
274 if state is None:
276 h, c = torch.split(torch.tanh(self.init_state(z)), self.dec_hidden_size, 1)
h
c
そして形があります[batch_size, lstm_size]
。[1, batch_size, lstm_size]
LSTMで使われている形なので、形を整えたいのです
279 state = (h.unsqueeze(0).contiguous(), c.unsqueeze(0).contiguous())
LSTM を実行してください
282 outputs, state = self.lstm(x, state)
取得
285 q_logits = self.q_log_softmax(self.q_head(outputs))
291 pi_logits, mu_x, mu_y, sigma_x, sigma_y, rho_xy = \
292 torch.split(self.mixtures(outputs), self.n_distributions, 2)
305 dist = BivariateGaussianMixture(pi_logits, mu_x, mu_y,
306 torch.exp(sigma_x), torch.exp(sigma_y), torch.tanh(rho_xy))
309 return dist, q_logits, state
312class ReconstructionLoss(Module):
317 def forward(self, mask: torch.Tensor, target: torch.Tensor,
318 dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
取得して
320 pi, mix = dist.get_distribution()
target
[seq_len, batch_size, 5]
最後の次元がフィーチャであるような形をしています。y を取得して、混合内の各分布から確率を求めたいと思います
xy
形になります [seq_len, batch_size, n_distributions, 2]
327 xy = target[:, :, 0:2].unsqueeze(-2).expand(-1, -1, dist.n_distributions, -1)
確率の計算
333 probs = torch.sum(pi.probs * torch.exp(mix.log_prob(xy)), 2)
probs
要素には (longest_seq_len
) がありますが、残りはマスクされているので、合計が計算されるだけです。
合計を取って除算しないで割る必要があるように感じるかもしれませんが、これによって、短いシーケンスで個々の予測をより重要視できるようになります。で割ると、各予測に同じ重みを与えます
。342 loss_stroke = -torch.mean(mask * torch.log(1e-5 + probs))
345 loss_pen = -torch.mean(target[:, :, 2:] * q_logits)
348 return loss_stroke + loss_pen
351class KLDivLoss(Module):
358 def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
360 return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
363class Sampler:
370 def __init__(self, encoder: EncoderRNN, decoder: DecoderRNN):
371 self.decoder = decoder
372 self.encoder = encoder
374 def sample(self, data: torch.Tensor, temperature: float):
376 longest_seq_len = len(data)
エンコーダから取得
379 z, _, _ = self.encoder(data)
シーケンスの開始ストロークは
382 s = data.new_tensor([0, 0, 1, 0, 0])
383 seq = [s]
初期デコーダーはNone
.デコーダーはそれを次のように初期化します
386 state = None
グラデーションはいらない
389 with torch.no_grad():
サンプルストローク
391 for i in range(longest_seq_len):
デコーダーへの入力です
393 data = torch.cat([s.view(1, 1, -1), z.unsqueeze(0)], 2)
デコーダーから、、および次の状態を取得
396 dist, q_logits, state = self.decoder(data, z, state)
ストロークをサンプリングする
398 s = self._sample_step(dist, q_logits, temperature)
新しいストロークをストロークのシーケンスに追加します
400 seq.append(s)
もしそうなら、サンプリングを停止してください。これはスケッチが停止したことを示します
402 if s[4] == 1:
403 break
一連のストロークの PyTorch テンソルを作成します。
406 seq = torch.stack(seq)
ストロークのシーケンスをプロット
409 self.plot(seq)
411 @staticmethod
412 def _sample_step(dist: 'BivariateGaussianMixture', q_logits: torch.Tensor, temperature: float):
サンプリングの温度を設定します。これはクラスで実装されていますBivariateGaussianMixture
。
414 dist.set_temperature(temperature)
温度調整して
416 pi, mix = dist.get_distribution()
分布の指標からサンプルを採取して、混合物から使用する
418 idx = pi.sample()[0, 0]
対数確率によるカテゴリ分布の作成または q_logits
421 q = torch.distributions.Categorical(logits=q_logits / temperature)
からのサンプル
423 q_idx = q.sample()[0, 0]
混合物の正規分布からサンプリングし、次の式でインデックス付けされた分布を選択します。idx
426 xy = mix.sample()[0, 0, idx]
空のストロークを作成
429 stroke = q_logits.new_zeros(5)
セット
431 stroke[:2] = xy
セット
433 stroke[q_idx + 2] = 1
435 return stroke
437 @staticmethod
438 def plot(seq: torch.Tensor):
の累積和を取ると、
440 seq[:, 0:2] = torch.cumsum(seq[:, 0:2], dim=0)
次の形式の新しいnumpy配列を作成します
442 seq[:, 2] = seq[:, 3]
443 seq = seq[:, 0:3].detach().cpu().numpy()
配列をあるポイントで分割します。つまり、ペンを紙から持ち上げるポイントでストロークの配列を分割します。これにより、ストロークのシーケンスのリストが表示されます
。448 strokes = np.split(seq, np.where(seq[:, 2] > 0)[0] + 1)
ストロークの各シーケンスをプロット
450 for s in strokes:
451 plt.plot(s[:, 0], -s[:, 1])
座標軸を表示しないでください
453 plt.axis('off')
プロットを表示
455 plt.show()
458class Configs(TrainValidConfigs):
実験を実行するデバイスを選択するためのデバイス構成
466 device: torch.device = DeviceConfigs()
468 encoder: EncoderRNN
469 decoder: DecoderRNN
470 optimizer: optim.Adam
471 sampler: Sampler
472
473 dataset_name: str
474 train_loader: DataLoader
475 valid_loader: DataLoader
476 train_dataset: StrokesDataset
477 valid_dataset: StrokesDataset
エンコーダとデコーダのサイズ
480 enc_hidden_size = 256
481 dec_hidden_size = 512
バッチサイズ
484 batch_size = 100
のフィーチャ数
487 d_z = 128
混合物中の分布の数、
489 n_distributions = 20
KLダイバージェンスロスの重量、
492 kl_div_loss_weight = 0.5
グラデーションクリッピング
494 grad_clip = 1.
サンプリング温度
496 temperature = 0.4
より長いストロークシーケンスを除外
499 max_seq_length = 200
500
501 epochs = 100
502
503 kl_div_loss = KLDivLoss()
504 reconstruction_loss = ReconstructionLoss()
506 def init(self):
エンコーダとデコーダを初期化
508 self.encoder = EncoderRNN(self.d_z, self.enc_hidden_size).to(self.device)
509 self.decoder = DecoderRNN(self.d_z, self.dec_hidden_size, self.n_distributions).to(self.device)
オプティマイザを設定します。オプティマイザーのタイプや学習率などは設定可能です
512 optimizer = OptimizerConfigs()
513 optimizer.parameters = list(self.encoder.parameters()) + list(self.decoder.parameters())
514 self.optimizer = optimizer
サンプラーの作成
517 self.sampler = Sampler(self.encoder, self.decoder)
npz
ファイルパスは data/sketch/[DATASET NAME].npz
520 path = lab.get_data_path() / 'sketch' / f'{self.dataset_name}.npz'
numpy ファイルを読み込む
522 dataset = np.load(str(path), encoding='latin1', allow_pickle=True)
トレーニングデータセットの作成
525 self.train_dataset = StrokesDataset(dataset['train'], self.max_seq_length)
検証データセットの作成
527 self.valid_dataset = StrokesDataset(dataset['valid'], self.max_seq_length, self.train_dataset.scale)
トレーニングデータローダーの作成
530 self.train_loader = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
検証データローダーの作成
532 self.valid_loader = DataLoader(self.valid_dataset, self.batch_size)
Tensorboardのレイヤー出力を監視するフックを追加
535 hook_model_outputs(self.mode, self.encoder, 'encoder')
536 hook_model_outputs(self.mode, self.decoder, 'decoder')
トレイン/検証ロスの合計を出力するようにトラッカーを設定
539 tracker.set_scalar("loss.total.*", True)
540
541 self.state_modules = []
543 def step(self, batch: Any, batch_idx: BatchIndex):
544 self.encoder.train(self.mode.is_train)
545 self.decoder.train(self.mode.is_train)
mask
をデバイスに移動しdata
、シーケンスとバッチディメンションを入れ替えます。data
[seq_len, batch_size, 5]
形があって、mask
形があるでしょう[seq_len, batch_size]
。
550 data = batch[0].to(self.device).transpose(0, 1)
551 mask = batch[1].to(self.device).transpose(0, 1)
トレーニングモードでのインクリメントステップ
554 if self.mode.is_train:
555 tracker.add_global_step(len(data))
ストロークのシーケンスをエンコード
558 with monit.section("encoder"):
取得、および
560 z, mu, sigma_hat = self.encoder(data)
複数のディストリビューションをデコードし、
563 with monit.section("decoder"):
連結
565 z_stack = z.unsqueeze(0).expand(data.shape[0] - 1, -1, -1)
566 inputs = torch.cat([data[:-1], z_stack], 2)
複数のディストリビューションを組み合わせて
568 dist, q_logits, _ = self.decoder(inputs, z, None)
損失の計算
571 with monit.section('loss'):
573 kl_loss = self.kl_div_loss(sigma_hat, mu)
575 reconstruction_loss = self.reconstruction_loss(mask, data[1:], dist, q_logits)
577 loss = reconstruction_loss + self.kl_div_loss_weight * kl_loss
トラックロス
580 tracker.add("loss.kl.", kl_loss)
581 tracker.add("loss.reconstruction.", reconstruction_loss)
582 tracker.add("loss.total.", loss)
トレーニング状態の場合のみ
585 if self.mode.is_train:
オプティマイザを実行
587 with monit.section('optimize'):
0 grad
に設定
589 self.optimizer.zero_grad()
勾配の計算
591 loss.backward()
モデルパラメーターと勾配をログに記録する
593 if batch_idx.is_last:
594 tracker.add(encoder=self.encoder, decoder=self.decoder)
クリップグラデーション
596 nn.utils.clip_grad_norm_(self.encoder.parameters(), self.grad_clip)
597 nn.utils.clip_grad_norm_(self.decoder.parameters(), self.grad_clip)
最適化
599 self.optimizer.step()
600
601 tracker.save()
603 def sample(self):
検証データセットからエンコーダーにサンプルをランダムに選択
605 data, *_ = self.valid_dataset[np.random.choice(len(self.valid_dataset))]
バッチディメンションを追加してデバイスに移動
607 data = data.unsqueeze(1).to(self.device)
[サンプル]
609 self.sampler.sample(data, self.temperature)
612def main():
613 configs = Configs()
614 experiment.create(name="sketch_rnn")
設定の辞書を渡す
617 experiment.configs(configs, {
618 'optimizer.optimizer': 'Adam',
学習率をに設定しているのは、1e-3
より早く結果を確認できるからです。論文は提案していた1e-4
。
621 'optimizer.learning_rate': 1e-3,
データセットの名前
623 'dataset_name': 'bicycle',
トレーニング、検証、サンプリングを切り替えるためのエポック内の内部反復回数。
625 'inner_iterations': 10
626 })
627
628 with experiment.start():
実験を実行する
630 configs.run()
631
632
633if __name__ == "__main__":
634 main()