ノイズ除去拡散確率モデル (DDPM) サンプリング

よりシンプルな DDPM 実装については、当社の DDPM 実装を参照してください。スケジュールなどにも同じ表記を使います

16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

DPM サンプラー

DiffusionSampler これは基本クラスを拡張します

DDPMは、以下から段階的にサンプリングすることにより、ノイズを繰り返し除去して画像をサンプリングします。

26class DDPMSampler(DiffusionSampler):
49    model: LatentDiffusion
  • model ノイズを予測するモデルです
51    def __init__(self, model: LatentDiffusion):
55        super().__init__(model)

サンプリングステップ

58        self.time_steps = np.asarray(list(range(self.n_steps)))
59
60        with torch.no_grad():

62            alpha_bar = self.model.alpha_bar

スケジュール

64            beta = self.model.beta

66            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

69            self.sqrt_alpha_bar = alpha_bar ** .5

71            self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5

73            self.sqrt_recip_alpha_bar = alpha_bar ** -.5

75            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5

78            variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)

クランピング・ログ・オン

80            self.log_var = torch.log(torch.clamp(variance, min=1e-20))

82            self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)

84            self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)

サンプリングループ

  • shape フォームで生成されたイメージの形状です [batch_size, channels, height, width]
  • cond 条件付き埋め込みです
  • temperature はノイズ温度 (ランダムノイズにこれを掛けます)
  • x_last です。指定しない場合は、ランダムノイズが使用されます。
  • uncond_scale 無条件ガイダンススケールです。これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
  • skip_steps スキップするタイムステップの数です。からサンプリングを開始します。そして、x_last その時です
86    @torch.no_grad()
87    def sample(self,
88               shape: List[int],
89               cond: torch.Tensor,
90               repeat_noise: bool = False,
91               temperature: float = 1.,
92               x_last: Optional[torch.Tensor] = None,
93               uncond_scale: float = 1.,
94               uncond_cond: Optional[torch.Tensor] = None,
95               skip_steps: int = 0,
96               ):

デバイスとバッチサイズの取得

113        device = self.model.device
114        bs = shape[0]

取得

117        x = x_last if x_last is not None else torch.randn(shape, device=device)

サンプリングするタイムステップ

120        time_steps = np.flip(self.time_steps)[skip_steps:]

サンプリングループ

123        for step in monit.iterate('Sample', time_steps):

タイムステップ

125            ts = x.new_full((bs,), step, dtype=torch.long)

[サンプル]

128            x, pred_x0, e_t = self.p_sample(x, cond, ts, step,
129                                            repeat_noise=repeat_noise,
130                                            temperature=temperature,
131                                            uncond_scale=uncond_scale,
132                                            uncond_cond=uncond_cond)

リターン

135        return x

からのサンプル

  • x 形が合っている [batch_size, channels, height, width]
  • c 形状の条件付き埋め込みです [batch_size, emb_size]
  • t 形が合っている [batch_size]
  • step はステップを整数で表したもの:repeat_noise: バッチ内のすべてのサンプルでノイズを同じにするかどうかを指定します
  • temperature はノイズ温度 (ランダムノイズにこれを掛けます)
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
137    @torch.no_grad()
138    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
139                 repeat_noise: bool = False,
140                 temperature: float = 1.,
141                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):

取得

157        e_t = self.get_eps(x, t, c,
158                           uncond_scale=uncond_scale,
159                           uncond_cond=uncond_cond)

バッチサイズを取得

162        bs = x.shape[0]

165        sqrt_recip_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_alpha_bar[step])

167        sqrt_recip_m1_alpha_bar = x.new_full((bs, 1, 1, 1), self.sqrt_recip_m1_alpha_bar[step])

現在の値で計算

172        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

175        mean_x0_coef = x.new_full((bs, 1, 1, 1), self.mean_x0_coef[step])

177        mean_xt_coef = x.new_full((bs, 1, 1, 1), self.mean_xt_coef[step])

計算

183        mean = mean_x0_coef * x0 + mean_xt_coef * x

185        log_var = x.new_full((bs, 1, 1, 1), self.log_var[step])

(最終段階のサンプリング処理)時は、ノイズを加えないでください。step 0 その時であることに注意してください

189        if step == 0:
190            noise = 0

バッチ内のすべてのサンプルに同じノイズが使用されている場合

192        elif repeat_noise:
193            noise = torch.randn((1, *x.shape[1:]))

サンプルごとに異なるノイズ

195        else:
196            noise = torch.randn(x.shape)

ノイズに温度を掛ける

199        noise = noise * temperature

からのサンプル

204        x_prev = mean + (0.5 * log_var).exp() * noise

207        return x_prev, x0, e_t

からのサンプル

  • x0 形が合っている [batch_size, channels, height, width]
  • index はタイムステップインデックス
  • noise ノイズは、
209    @torch.no_grad()
210    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):

ランダムノイズ (ノイズが指定されていない場合)

222        if noise is None:
223            noise = torch.randn_like(x0)

からのサンプル

226        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise