18from typing import Optional, List
19
20import torch
21
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion

サンプリングアルゴリズムの基本クラス

25class DiffusionSampler:
29    model: LatentDiffusion
  • model ノイズを予測するモデルです
31    def __init__(self, model: LatentDiffusion):
35        super().__init__()

モデルを設定する

37        self.model = model

モデルのトレーニングに使用したステップ数を取得

39        self.n_steps = model.n_steps

取得

  • x 形が合っている [batch_size, channels, height, width]
  • t 形が合っている [batch_size]
  • c 形状の条件付き埋め込みです [batch_size, emb_size]
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
41    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
42                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):

体重計のとき

55        if uncond_cond is None or uncond_scale == 1.:
56            return self.model(x, t, c)

複製と

59        x_in = torch.cat([x] * 2)
60        t_in = torch.cat([t] * 2)

連結と

62        c_in = torch.cat([uncond_cond, c])

取得して

64        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)

計算

67        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

70        return e_t

サンプリングループ

  • shape フォームで生成されたイメージの形状です [batch_size, channels, height, width]
  • cond 条件付き埋め込みです
  • temperature はノイズ温度 (ランダムノイズにこれを掛けます)
  • x_last です。指定しない場合は、ランダムノイズが使用されます。
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
  • skip_steps スキップするタイムステップの数です。
72    def sample(self,
73               shape: List[int],
74               cond: torch.Tensor,
75               repeat_noise: bool = False,
76               temperature: float = 1.,
77               x_last: Optional[torch.Tensor] = None,
78               uncond_scale: float = 1.,
79               uncond_cond: Optional[torch.Tensor] = None,
80               skip_steps: int = 0,
81               ):
95        raise NotImplementedError()

ペインティングループ

  • x 形が合っている [batch_size, channels, height, width]
  • cond 条件付き埋め込みです
  • t_start 開始するサンプリングステップです
  • orig 現在ペイント中の潜在ページのオリジナル画像です。
  • mask 元の画像を残すためのマスクです。
  • orig_noise 元の画像に追加される固定ノイズです。
  • uncond_scale 無条件ガイダンススケールです これは次の用途に使用されます
  • uncond_cond 空のプロンプトの条件付き埋め込みです
97    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
98              orig: Optional[torch.Tensor] = None,
99              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
100              uncond_scale: float = 1.,
101              uncond_cond: Optional[torch.Tensor] = None,
102              ):
116        raise NotImplementedError()

からのサンプル

  • x0 形が合っている [batch_size, channels, height, width]
  • index はタイムステップインデックス
  • noise ノイズは、
118    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
126        raise NotImplementedError()