18from typing import Optional, List
19
20import torch
21
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
25class DiffusionSampler:
29 model: LatentDiffusion
model
ノイズを予測するモデルです 31 def __init__(self, model: LatentDiffusion):
35 super().__init__()
モデルを設定する
37 self.model = model
モデルのトレーニングに使用したステップ数を取得
39 self.n_steps = model.n_steps
x
形が合っている [batch_size, channels, height, width]
t
形が合っている [batch_size]
c
形状の条件付き埋め込みです [batch_size, emb_size]
uncond_scale
無条件ガイダンススケールです これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです 41 def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
42 uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
体重計のとき
55 if uncond_cond is None or uncond_scale == 1.:
56 return self.model(x, t, c)
複製と
59 x_in = torch.cat([x] * 2)
60 t_in = torch.cat([t] * 2)
連結と
62 c_in = torch.cat([uncond_cond, c])
取得して
64 e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
計算
67 e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
70 return e_t
shape
フォームで生成されたイメージの形状です [batch_size, channels, height, width]
cond
条件付き埋め込みです temperature
はノイズ温度 (ランダムノイズにこれを掛けます)x_last
です。指定しない場合は、ランダムノイズが使用されます。uncond_scale
無条件ガイダンススケールです これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです skip_steps
スキップするタイムステップの数です。72 def sample(self,
73 shape: List[int],
74 cond: torch.Tensor,
75 repeat_noise: bool = False,
76 temperature: float = 1.,
77 x_last: Optional[torch.Tensor] = None,
78 uncond_scale: float = 1.,
79 uncond_cond: Optional[torch.Tensor] = None,
80 skip_steps: int = 0,
81 ):
95 raise NotImplementedError()
x
形が合っている [batch_size, channels, height, width]
cond
条件付き埋め込みです t_start
開始するサンプリングステップです orig
現在ペイント中の潜在ページのオリジナル画像です。mask
元の画像を残すためのマスクです。orig_noise
元の画像に追加される固定ノイズです。uncond_scale
無条件ガイダンススケールです これは次の用途に使用されます uncond_cond
空のプロンプトの条件付き埋め込みです 97 def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
98 orig: Optional[torch.Tensor] = None,
99 mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
100 uncond_scale: float = 1.,
101 uncond_cond: Optional[torch.Tensor] = None,
102 ):
116 raise NotImplementedError()
118 def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
126 raise NotImplementedError()