18from typing import Optional, List
19
20import torch
21
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion

采样算法的基类

25class DiffusionSampler:
29    model: LatentDiffusion
  • model 是预测噪声的模型
31    def __init__(self, model: LatentDiffusion):
35        super().__init__()

设置模型

37        self.model = model

获取模型训练的步数

39        self.n_steps = model.n_steps

获取

  • x 是形状的[batch_size, channels, height, width]
  • t 是形状的[batch_size]
  • c 是形状的条件嵌入[batch_size, emb_size]
  • uncond_scale 是无条件指导量表。这用于
  • uncond_cond 是空提示的条件嵌入
41    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
42                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):

当体重秤时

55        if uncond_cond is None or uncond_scale == 1.:
56            return self.model(x, t, c)

复制

59        x_in = torch.cat([x] * 2)
60        t_in = torch.cat([t] * 2)

串联

62        c_in = torch.cat([uncond_cond, c])

获取

64        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)

计算

67        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

70        return e_t

采样回路

  • shape 是表单中生成的图像的形状[batch_size, channels, height, width]
  • cond 是条件嵌入
  • temperature 是噪声温度(随机噪声乘以此值)
  • x_last。如果未提供,将使用随机噪声。
  • uncond_scale 是无条件指导量表。这用于
  • uncond_cond 是空提示的条件嵌入
  • skip_steps 是要跳过的时间步数。
72    def sample(self,
73               shape: List[int],
74               cond: torch.Tensor,
75               repeat_noise: bool = False,
76               temperature: float = 1.,
77               x_last: Optional[torch.Tensor] = None,
78               uncond_scale: float = 1.,
79               uncond_cond: Optional[torch.Tensor] = None,
80               skip_steps: int = 0,
81               ):
95        raise NotImplementedError()

绘画循环

  • x 是形状的[batch_size, channels, height, width]
  • cond 是条件嵌入
  • t_start 是开始时的采样步骤,
  • orig 是我们正在绘制的潜在页面中的原始图像。
  • mask 是保留原始图像的掩码。
  • orig_noise 是要添加到原始图像的固定噪点。
  • uncond_scale 是无条件指导量表。这用于
  • uncond_cond 是空提示的条件嵌入
97    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
98              orig: Optional[torch.Tensor] = None,
99              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
100              uncond_scale: float = 1.,
101              uncond_cond: Optional[torch.Tensor] = None,
102              ):
116        raise NotImplementedError()

样本来自

  • x0 是形状的[batch_size, channels, height, width]
  • index 是时间步长指数
  • noise 是噪音,
118    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
126        raise NotImplementedError()