16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
26class DDPMSampler(DiffusionSampler):
49 model: LatentDiffusion
model
ノイズを予測するモデルです 51 def __init__(self, model: LatentDiffusion):
55 super().__init__(model)
サンプリングステップ
58 self.time_steps = np.asarray(list(range(self.n_steps)))
59
60 with torch.no_grad():
62 alpha_bar = self.model.alpha_bar
スケジュール
64 beta = self.model.beta
66 alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])
69 self.sqrt_alpha_bar = alpha_bar ** .5
71 self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5
73 self.sqrt_recip_alpha_bar = alpha_bar ** -.5
75 self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5
78 variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
クランピング・ログ・オン
80 self.log_var = torch.log(torch.clamp(variance, min=1e-20))
82 self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
84 self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)
shape
フォームで生成されたイメージの形状です [batch_size, channels, height, width]
cond
条件付き埋め込みです temperature
はノイズ温度 (ランダムノイズにこれを掛けます)x_last
です。指定しない場合は、ランダムノイズが使用されます。uncond_scale
無条件ガイダンススケールです。これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです skip_steps
スキップするタイムステップの数です。からサンプリングを開始します。そして、x_last
その時です。86 @torch.no_grad()
87 def sample(self,
88 shape: List[int],
89 cond: torch.Tensor,
90 repeat_noise: bool = False,
91 temperature: float = 1.,
92 x_last: Optional[torch.Tensor] = None,
93 uncond_scale: float = 1.,
94 uncond_cond: Optional[torch.Tensor] = None,
95 skip_steps: int = 0,
96 ):
デバイスとバッチサイズの取得
113 device = self.model.device
114 bs = shape[0]
取得
117 x = x_last if x_last is not None else torch.randn(shape, device=device)
サンプリングするタイムステップ
120 time_steps = np.flip(self.time_steps)[skip_steps:]
サンプリングループ
123 for step in monit.iterate('Sample', time_steps):
タイムステップ
125 ts = x.new_full((bs,), step, dtype=torch.long)
[サンプル]
128 x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
129 repeat_noise=repeat_noise,
130 temperature=temperature,
131 uncond_scale=uncond_scale,
132 uncond_cond=uncond_cond)
リターン
135 return x
x
形が合っている [batch_size, channels, height, width]
c
形状の条件付き埋め込みです [batch_size, emb_size]
t
形が合っている [batch_size]
step
はステップを整数で表したもの:repeat_noise: バッチ内のすべてのサンプルでノイズを同じにするかどうかを指定しますtemperature
はノイズ温度 (ランダムノイズにこれを掛けます)uncond_scale
無条件ガイダンススケールです これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです 137 @torch.no_grad()
138 def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
139 repeat_noise: bool = False,
140 temperature: float = 1.,
141 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
取得
157 e_t = self.get_eps(x, t, c,
158 uncond_scale=uncond_scale,
159 uncond_cond=uncond_cond)
バッチサイズを取得
162 bs = x.shape[0]
165 sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])
167 sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])
172 x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t
175 mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])
177 mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])
183 mean = mean_x0_coef * x0 + mean_xt_coef * x
185 log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])
(最終段階のサンプリング処理)時は、ノイズを加えないでください。step
0
その時であることに注意してください)
189 if step == 0:
190 noise = 0
バッチ内のすべてのサンプルに同じノイズが使用されている場合
192 elif repeat_noise:
193 noise = torch.randn((1, *x.shape[1:]))
サンプルごとに異なるノイズ
195 else:
196 noise = torch.randn(x.shape)
ノイズに温度を掛ける
199 noise = noise * temperature
204 x_prev = mean + (0.5 * log_var).exp() * noise
207 return x_prev, x0, e_t
209 @torch.no_grad()
210 def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
ランダムノイズ (ノイズが指定されていない場合)
222 if noise is None:
223 noise = torch.randn_like(x0)
からのサンプル
226 return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise