降噪扩散隐含模型 (DDIM) 采样

这实现了来自论文 “降噪扩散隐式模型” 的 DDIM 采样

16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

DDIM 采样器

这扩展了DiffusionSampler 基类

DDPM 通过逐步采样来反复消除噪点来对图像进行采样,

其中,是随机噪声,是长度为的子序列

请注意,在 DDIM 论文中,指的是来DDPM 的论文。

26class DDIMSampler(DiffusionSampler):
52    model: LatentDiffusion
  • model 是预测噪声的模型
  • n_steps 是 DDIM 采样步骤的数量,
  • ddim_discretize 指定如何从中提取。可以是uniformquad
  • ddim_eta 用于计算使采样过程具有确定性。
54    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
63        super().__init__(model)

步数,

65        self.n_steps = model.n_steps

计算得均匀分布在各处

68        if ddim_discretize == 'uniform':
69            c = self.n_steps // n_steps
70            self.time_steps = np.asarray(list(range(0, self.n_steps, c))) + 1

计算以二次分布

72        elif ddim_discretize == 'quad':
73            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
74        else:
75            raise NotImplementedError(ddim_discretize)
76
77        with torch.no_grad():

获取

79            alpha_bar = self.model.alpha_bar

82            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)

84            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)

86            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

91            self.ddim_sigma = (ddim_eta *
92                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
93                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

96            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

采样回路

  • shape 是表单中生成的图像的形状[batch_size, channels, height, width]
  • cond 是条件嵌入
  • temperature 是噪声温度(随机噪声乘以此值)
  • x_last。如果未提供,将使用随机噪声。
  • uncond_scale 是无条件指导量表。这用于
  • uncond_cond 是空提示的条件嵌入
  • skip_steps 是要跳过的时间步数。我们从开始采样。然后x_last 就是这样
98    @torch.no_grad()
99    def sample(self,
100               shape: List[int],
101               cond: torch.Tensor,
102               repeat_noise: bool = False,
103               temperature: float = 1.,
104               x_last: Optional[torch.Tensor] = None,
105               uncond_scale: float = 1.,
106               uncond_cond: Optional[torch.Tensor] = None,
107               skip_steps: int = 0,
108               ):

获取设备和批次大小

125        device = self.model.device
126        bs = shape[0]

获取

129        x = x_last if x_last is not None else torch.randn(shape, device=device)

采样的时间步长

132        time_steps = np.flip(self.time_steps)[skip_steps:]
133
134        for i, step in monit.enum('Sample', time_steps):

列表中的索引

136            index = len(time_steps) - i - 1

时间步长

138            ts = x.new_full((bs,), step, dtype=torch.long)

示例

141            x, pred_x0, e_t = self.p_sample(x, cond, ts, step, index=index,
142                                            repeat_noise=repeat_noise,
143                                            temperature=temperature,
144                                            uncond_scale=uncond_scale,
145                                            uncond_cond=uncond_cond)

返回

148        return x

示例

  • x 是形状的[batch_size, channels, height, width]
  • c 是形状的条件嵌入[batch_size, emb_size]
  • t 是形状的[batch_size]
  • step 是整数形式的步长
  • index 是列表中的索引
  • repeat_noise 指定批次中所有样本的噪声是否应相同
  • temperature 是噪声温度(随机噪声乘以此值)
  • uncond_scale 是无条件指导量表。这用于
  • uncond_cond 是空提示的条件嵌入
150    @torch.no_grad()
151    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
152                 repeat_noise: bool = False,
153                 temperature: float = 1.,
154                 uncond_scale: float = 1.,
155                 uncond_cond: Optional[torch.Tensor] = None):

获取

172        e_t = self.get_eps(x, t, c,
173                           uncond_scale=uncond_scale,
174                           uncond_cond=uncond_cond)

计算和预测

177        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
178                                                      temperature=temperature,
179                                                      repeat_noise=repeat_noise)

182        return x_prev, pred_x0, e_t

给出的样本

184    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
185                               temperature: float,
186                               repeat_noise: bool):

192        alpha = self.ddim_alpha[index]

194        alpha_prev = self.ddim_alpha_prev[index]

196        sigma = self.ddim_sigma[index]

198        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

目前的预测

202        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)

指向的方向

205        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

在以下情况下不添加任何噪音

208        if sigma == 0.:
209            noise = 0.

如果批次中的所有样品都使用相同的噪声

211        elif repeat_noise:
212            noise = torch.randn((1, *x.shape[1:]), device=x.device)

每个样本的噪声不同

214        else:
215            noise = torch.randn(x.shape, device=x.device)

将噪声乘以温度

218        noise = noise * temperature

227        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

230        return x_prev, pred_x0

样本来自

  • x0 是形状的[batch_size, channels, height, width]
  • index 是时间步长指数
  • noise 是噪音,
232    @torch.no_grad()
233    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):

如果未指定噪声,则为随机噪声

246        if noise is None:
247            noise = torch.randn_like(x0)

样本来自

252        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

绘画循环

  • x 是形状的[batch_size, channels, height, width]
  • cond 是条件嵌入
  • t_start 是开始时的采样步骤,
  • orig 是我们正在绘制的潜在页面中的原始图像。如果未提供,则将是图像到图像的转换。
  • mask 是保留原始图像的掩码。
  • orig_noise 是要添加到原始图像的固定噪点。
  • uncond_scale 是无条件指导量表。这用于
  • uncond_cond 是空提示的条件嵌入
254    @torch.no_grad()
255    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
256              orig: Optional[torch.Tensor] = None,
257              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
258              uncond_scale: float = 1.,
259              uncond_cond: Optional[torch.Tensor] = None,
260              ):

获取批次大小

276        bs = x.shape[0]

采样的时间步长

279        time_steps = np.flip(self.time_steps[:t_start])
280
281        for i, step in monit.enum('Paint', time_steps):

列表中的索引

283            index = len(time_steps) - i - 1

时间步长

285            ts = x.new_full((bs,), step, dtype=torch.long)

示例

288            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
289                                    uncond_scale=uncond_scale,
290                                    uncond_cond=uncond_cond)

将蒙版区域替换为原始图像

293            if orig is not None:

在潜在空间中获取原始图像

295                orig_t = self.q_sample(orig, index, noise=orig_noise)

替换被屏蔽的区域

297                x = orig_t * mask + x * (1 - mask)

300        return x