ノイズ除去拡散暗黙モデル (DDIM) サンプリシットサンプリシット

これは、論文「ノイズ除去拡散暗黙モデル」からのDDIMサンプリングを実装しています。

16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

DDIM サンプラー

DiffusionSampler これは基本クラスを拡張します

DDPMは、以下を使用して段階的にサンプリングすることにより、ノイズを繰り返し除去することによって画像をサンプリングします。

ここで、はランダムノイズ、は長さのサブシーケンス

なお、 DDIMの論文ではDDPMのものを指しています。

26class DDIMSampler(DiffusionSampler):
52    model: LatentDiffusion
  • model ノイズを予測するモデルです
  • n_steps は DDIM サンプリングステップの数、
  • ddim_discretize からの抽出方法を指定しますuniform またはのどちらでもかまいませんquad
  • ddim_eta 計算に使用されますサンプリングプロセスを決定的にします
54    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
63        super().__init__(model)

ステップ数、

65        self.n_steps = model.n_steps

全体に均等に分散するように計算

68        if ddim_discretize == 'uniform':
69            c = self.n_steps // n_steps
70            self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1

2 次分布になるように計算する

72        elif ddim_discretize == 'quad':
73            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
74        else:
75            raise NotImplementedError(ddim_discretize)
76
77        with torch.no_grad():

取得

79            alpha_bar = self.model.alpha_bar

82            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)

84            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)

86            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

91            self.ddim_sigma = (ddim_eta *
92                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
93                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

96            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

サンプリングループ

  • shape フォームで生成されたイメージの形状です [batch_size, channels, height, width]
  • cond 条件付き埋め込みです
  • temperature はノイズ温度 (ランダムノイズにこれを掛けます)
  • x_last です。指定しない場合は、ランダムノイズが使用されます。
  • uncond_scale 無条件ガイダンススケールです。これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
  • skip_steps スキップするタイムステップの数です。からサンプリングを開始します。そして、x_last その時です
98    @torch.no_grad()
99    def sample(self,
100               shape: List[int],
101               cond: torch.Tensor,
102               repeat_noise: bool = False,
103               temperature: float = 1.,
104               x_last: Optional[torch.Tensor] = None,
105               uncond_scale: float = 1.,
106               uncond_cond: Optional[torch.Tensor] = None,
107               skip_steps: int = 0,
108               ):

デバイスとバッチサイズの取得

125        device = self.model.device
126        bs = shape[0]

取得

129        x = x_last if x_last is not None else torch.randn(shape, device=device)

サンプリングするタイムステップ

132        time_steps = np.flip(self.time_steps)[skip_steps:]
133
134        for i, step in monit.enum('Sample', time_steps):

リスト内のインデックス

136            index = len(time_steps) - i - 1

タイムステップ

138            ts = x.new_full((bs,), step, dtype=torch.long)

[サンプル]

141            x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
142                                            repeat_noise=repeat_noise,
143                                            temperature=temperature,
144                                            uncond_scale=uncond_scale,
145                                            uncond_cond=uncond_cond)

リターン

148        return x

[サンプル]

  • x 形が合っている [batch_size, channels, height, width]
  • c 形状の条件付き埋め込みです [batch_size, emb_size]
  • t 形が合っている [batch_size]
  • step はステップを整数で表したものです
  • index リスト内のインデックスです
  • repeat_noise バッチ内のすべてのサンプルでノイズを同じにするかどうかを指定しました
  • temperature はノイズ温度 (ランダムノイズにこれを掛けます)
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
150    @torch.no_grad()
151    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
152                 repeat_noise: bool = False,
153                 temperature: float = 1.,
154                 uncond_scale: float = 1.,
155                 uncond_cond: Optional[torch.Tensor] = None):

取得

172        e_t = self.get_eps(x, t, c,
173                           uncond_scale=uncond_scale,
174                           uncond_cond=uncond_cond)

計算と予測

177        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
178                                                      temperature=temperature,
179                                                      repeat_noise=repeat_noise)

182        return x_prev, pred_x0, e_t

サンプル提供

184    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
185                               temperature: float,
186                               repeat_noise: bool):

192        alpha = self.ddim_alpha[index]

194        alpha_prev = self.ddim_alpha_prev[index]

196        sigma = self.ddim_sigma[index]

198        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

の現在の予測

202        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)

を指す方向

205        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

次の場合、ノイズは追加されません

208        if sigma == 0.:
209            noise = 0.

バッチ内のすべてのサンプルに同じノイズが使用されている場合

211        elif repeat_noise:
212            noise = torch.randn((1, *x.shape[1:]), device=x.device)

サンプルごとに異なるノイズ

214        else:
215            noise = torch.randn(x.shape, device=x.device)

ノイズに温度を掛ける

218        noise = noise * temperature

227        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

230        return x_prev, pred_x0

からのサンプル

  • x0 形が合っている [batch_size, channels, height, width]
  • index はタイムステップインデックス
  • noise ノイズは、
232    @torch.no_grad()
233    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):

ランダムノイズ (ノイズが指定されていない場合)

246        if noise is None:
247            noise = torch.randn_like(x0)

からのサンプル

252        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

ペインティングループ

  • x 形が合っている [batch_size, channels, height, width]
  • cond 条件付き埋め込みです
  • t_start 開始するサンプリングステップです
  • orig 現在ペイント中の潜在ページのオリジナル画像です。これが指定されていない場合は、画像から画像への変換になります。
  • mask 元の画像を残すためのマスクです。
  • orig_noise 元の画像に追加される固定ノイズです。
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
254    @torch.no_grad()
255    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
256              orig: Optional[torch.Tensor] = None,
257              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
258              uncond_scale: float = 1.,
259              uncond_cond: Optional[torch.Tensor] = None,
260              ):

バッチサイズを取得

276        bs = x.shape[0]

サンプリングするタイムステップ

279        time_steps = np.flip(self.time_steps[:t_start])
280
281        for i, step in monit.enum('Paint', time_steps):

リスト内のインデックス

283            index = len(time_steps) - i - 1

タイムステップ

285            ts = x.new_full((bs,), step, dtype=torch.long)

[サンプル]

288            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
289                                    uncond_scale=uncond_scale,
290                                    uncond_cond=uncond_cond)

マスクされた領域を元の画像に置き換える

293            if orig is not None:

潜在空間のオリジナル画像を取得

295                orig_t = self.q_sample(orig, index, noise=orig_noise)

マスクされた領域を置き換える

297                x = orig_t * mask + x * (1 - mask)

300        return x