16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
DiffusionSampler
これは基本クラスを拡張します。
DDPMは、以下を使用して段階的にサンプリングすることにより、ノイズを繰り返し除去することによって画像をサンプリングします。
ここで、はランダムノイズ、は長さのサブシーケンス、
26class DDIMSampler(DiffusionSampler):
52 model: LatentDiffusion
model
ノイズを予測するモデルです n_steps
は DDIM サンプリングステップの数、ddim_discretize
からの抽出方法を指定します。uniform
またはのどちらでもかまいませんquad
。ddim_eta
計算に使用されます。サンプリングプロセスを決定的にします54 def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
63 super().__init__(model)
ステップ数、
65 self.n_steps = model.n_steps
全体に均等に分散するように計算
68 if ddim_discretize == 'uniform':
69 c = self.n_steps // n_steps
70 self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1
2 次分布になるように計算する
72 elif ddim_discretize == 'quad':
73 self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
74 else:
75 raise NotImplementedError(ddim_discretize)
76
77 with torch.no_grad():
取得
79 alpha_bar = self.model.alpha_bar
82 self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
84 self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
86 self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])
91 self.ddim_sigma = (ddim_eta *
92 ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
93 (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)
96 self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5
shape
フォームで生成されたイメージの形状です [batch_size, channels, height, width]
cond
条件付き埋め込みです temperature
はノイズ温度 (ランダムノイズにこれを掛けます)x_last
です。指定しない場合は、ランダムノイズが使用されます。uncond_scale
無条件ガイダンススケールです。これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです skip_steps
スキップするタイムステップの数です。からサンプリングを開始します。そして、x_last
その時です。98 @torch.no_grad()
99 def sample(self,
100 shape: List[int],
101 cond: torch.Tensor,
102 repeat_noise: bool = False,
103 temperature: float = 1.,
104 x_last: Optional[torch.Tensor] = None,
105 uncond_scale: float = 1.,
106 uncond_cond: Optional[torch.Tensor] = None,
107 skip_steps: int = 0,
108 ):
デバイスとバッチサイズの取得
125 device = self.model.device
126 bs = shape[0]
取得
129 x = x_last if x_last is not None else torch.randn(shape, device=device)
サンプリングするタイムステップ
132 time_steps = np.flip(self.time_steps)[skip_steps:]
133
134 for i, step in monit.enum('Sample', time_steps):
リスト内のインデックス
136 index = len(time_steps) - i - 1
タイムステップ
138 ts = x.new_full((bs,), step, dtype=torch.long)
[サンプル]
141 x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
142 repeat_noise=repeat_noise,
143 temperature=temperature,
144 uncond_scale=uncond_scale,
145 uncond_cond=uncond_cond)
リターン
148 return x
x
形が合っている [batch_size, channels, height, width]
c
形状の条件付き埋め込みです [batch_size, emb_size]
t
形が合っている [batch_size]
step
はステップを整数で表したものですindex
リスト内のインデックスです repeat_noise
バッチ内のすべてのサンプルでノイズを同じにするかどうかを指定しましたtemperature
はノイズ温度 (ランダムノイズにこれを掛けます)uncond_scale
無条件ガイダンススケールです これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです 150 @torch.no_grad()
151 def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
152 repeat_noise: bool = False,
153 temperature: float = 1.,
154 uncond_scale: float = 1.,
155 uncond_cond: Optional[torch.Tensor] = None):
取得
172 e_t = self.get_eps(x, t, c,
173 uncond_scale=uncond_scale,
174 uncond_cond=uncond_cond)
計算と予測
177 x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
178 temperature=temperature,
179 repeat_noise=repeat_noise)
182 return x_prev, pred_x0, e_t
184 def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
185 temperature: float,
186 repeat_noise: bool):
192 alpha = self.ddim_alpha[index]
194 alpha_prev = self.ddim_alpha_prev[index]
196 sigma = self.ddim_sigma[index]
198 sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]
の現在の予測
202 pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
を指す方向
205 dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
次の場合、ノイズは追加されません
208 if sigma == 0.:
209 noise = 0.
バッチ内のすべてのサンプルに同じノイズが使用されている場合
211 elif repeat_noise:
212 noise = torch.randn((1, *x.shape[1:]), device=x.device)
サンプルごとに異なるノイズ
214 else:
215 noise = torch.randn(x.shape, device=x.device)
ノイズに温度を掛ける
218 noise = noise * temperature
227 x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
230 return x_prev, pred_x0
232 @torch.no_grad()
233 def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
ランダムノイズ (ノイズが指定されていない場合)
246 if noise is None:
247 noise = torch.randn_like(x0)
からのサンプル
252 return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise
x
形が合っている [batch_size, channels, height, width]
cond
条件付き埋め込みです t_start
開始するサンプリングステップです orig
現在ペイント中の潜在ページのオリジナル画像です。これが指定されていない場合は、画像から画像への変換になります。mask
元の画像を残すためのマスクです。orig_noise
元の画像に追加される固定ノイズです。uncond_scale
無条件ガイダンススケールです これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです 254 @torch.no_grad()
255 def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
256 orig: Optional[torch.Tensor] = None,
257 mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
258 uncond_scale: float = 1.,
259 uncond_cond: Optional[torch.Tensor] = None,
260 ):
バッチサイズを取得
276 bs = x.shape[0]
サンプリングするタイムステップ
279 time_steps = np.flip(self.time_steps[:t_start])
280
281 for i, step in monit.enum('Paint', time_steps):
リスト内のインデックス
283 index = len(time_steps) - i - 1
タイムステップ
285 ts = x.new_full((bs,), step, dtype=torch.long)
[サンプル]
288 x, _, _ = self.p_sample(x, cond, ts, step, index=index,
289 uncond_scale=uncond_scale,
290 uncond_cond=uncond_cond)
マスクされた領域を元の画像に置き換える
293 if orig is not None:
潜在空間のオリジナル画像を取得
295 orig_t = self.q_sample(orig, index, noise=orig_noise)
マスクされた領域を置き換える
297 x = orig_t * mask + x * (1 - mask)
300 return x