16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
class DDIMSampler(DiffusionSampler):
    r"""
    ## DDIM Sampler

    This extends the `DiffusionSampler` base class.

    DDIM samples images by repeatedly removing noise, sampling step by step with

    \begin{align}
    x_{\tau_{i-1}} &= \sqrt{\alpha_{\tau_{i-1}}}\Bigg(
            \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}
            \Bigg) \\
            &+ \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) \\
            &+ \sigma_{\tau_i} \epsilon_{\tau_i}
    \end{align}

    where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$
    of length $S$, and

    $$\sigma_{\tau_i} =
        \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
        \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

    Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from the DDPM paper.
    """

    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
        r"""
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        :param n_steps: is the number of DDIM sampling steps, $S$
        :param ddim_discretize: specifies how to extract $\tau$ from $[1, 2, \dots, T]$.
            It can be either `uniform` or `quad`.
        :param ddim_eta: is $\eta$ used to calculate $\sigma_{\tau_i}$. $\eta = 0$ makes the
            sampling process deterministic.
        :raises ValueError: for `uniform` discretization when `n_steps` is not in $[1, T]$
        :raises NotImplementedError: for any other `ddim_discretize` value
        """
        super().__init__(model)
        # Number of steps, $T$
        self.n_steps = model.n_steps

        # Calculate $\tau$ to be uniformly distributed across $[1, 2, \dots, T]$
        if ddim_discretize == 'uniform':
            # A non-positive step count or $S > T$ (stride 0) would otherwise fail
            # inside `//` / `range` with an unhelpful error.
            if not 0 < n_steps <= self.n_steps:
                raise ValueError(
                    f"n_steps must be in [1, {self.n_steps}] for uniform discretization, got {n_steps}")
            c = self.n_steps // n_steps
            # When $T$ is not a multiple of $S$, `range(0, T, c)` yields more than $S$
            # values (and the last one + 1 can reach $T$); keep exactly $S$ of them.
            time_steps = np.asarray(list(range(0, self.n_steps, c)))[:n_steps] + 1
            # With $S = T$ the `+ 1` shift still makes the last step $T$, one past the end
            # of `alpha_bar` (assuming it holds $T$ entries); clamp to the last valid index.
            self.time_steps = np.minimum(time_steps, self.n_steps - 1)
        # Calculate $\tau$ to be quadratically distributed across $[1, 2, \dots, T]$
        elif ddim_discretize == 'quad':
            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
        else:
            raise NotImplementedError(ddim_discretize)

        with torch.no_grad():
            # Get $\bar\alpha_t$ (called $\alpha_t$ in the DDIM paper)
            alpha_bar = self.model.alpha_bar

            # $\alpha_{\tau_i}$
            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
            # $\sqrt{\alpha_{\tau_i}}$
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # $\alpha_{\tau_{i-1}}$; the first entry uses $\bar\alpha$ at index 0
            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

            # $$\sigma_{\tau_i} =
            #   \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
            #   \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
            self.ddim_sigma = (ddim_eta *
                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

            # $\sqrt{1 - \alpha_{\tau_i}}$
            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        r"""
        ### Sampling loop

        :param shape: is the shape of the generated images in the form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_{\tau_S}$. If not provided, random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$; it is passed through to `get_eps`
        :param uncond_cond: is the conditional embedding for the empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip, $i'$. Sampling starts from
            $S - i'$, and `x_last` is then $x_{\tau_{S - i'}}$.
        :return: $x_0$, the result of the final step
        """
        # Get device and batch size
        device = self.model.device
        bs = shape[0]

        # Get $x_{\tau_S}$
        x = x_last if x_last is not None else torch.randn(shape, device=device)

        # Time steps to sample at: $\tau_{S - i'}, \tau_{S - i' - 1}, \dots, \tau_1$
        time_steps = np.flip(self.time_steps)[skip_steps:]

        for i, step in monit.enum('Sample', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            index = len(time_steps) - i - 1
            # Time step $\tau_i$, broadcast over the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    repeat_noise=repeat_noise,
                                    temperature=temperature,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

        # Return $x_0$
        return x

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
        r"""
        ### Sample $x_{\tau_{i-1}}$

        :param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $\tau_i$ of shape `[batch_size]`
        :param step: is the step $\tau_i$ as an integer
        :param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$; it is passed through to `get_eps`
        :param uncond_cond: is the conditional embedding for the empty prompt $c_u$
        :return: `(x_prev, pred_x0, e_t)`: $x_{\tau_{i-1}}$, the predicted $x_0$, and
            $\epsilon_\theta(x_{\tau_i})$
        """
        # Get $\epsilon_\theta(x_{\tau_i})$
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # Calculate $x_{\tau_{i-1}}$ and the predicted $x_0$
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature=temperature,
                                                      repeat_noise=repeat_noise)

        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                               temperature: float,
                               repeat_noise: bool):
        r"""
        ### Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$

        :param e_t: is $\epsilon_\theta(x_{\tau_i})$
        :param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
        :param x: is $x_{\tau_i}$
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :return: `(x_prev, pred_x0)`
        """
        # $\alpha_{\tau_i}$
        alpha = self.ddim_alpha[index]
        # $\alpha_{\tau_{i-1}}$
        alpha_prev = self.ddim_alpha_prev[index]
        # $\sigma_{\tau_i}$
        sigma = self.ddim_sigma[index]
        # $\sqrt{1 - \alpha_{\tau_i}}$
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

        # Current prediction for $x_0$,
        # $$\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$$
        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        # Direction pointing to $x_t$,
        # $$\sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$$
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

        # No noise is added when $\sigma_{\tau_i} = 0$ (e.g. $\eta = 0$)
        if sigma == 0.:
            noise = 0.
        # If the same noise is used for all samples in the batch
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # $$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}}\,\hat{x}_0
        #   + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})
        #   + \sigma_{\tau_i} \epsilon_{\tau_i}$$
        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

        return x_prev, pred_x0

    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        r"""
        ### Sample from $q_{\sigma,\tau}(x_{\tau_i}|x_0)$

        $$q_{\sigma,\tau}(x_{\tau_i}|x_0) =
            \mathcal{N}\Big(x_{\tau_i}; \sqrt{\alpha_{\tau_i}} x_0, (1-\alpha_{\tau_i}) \mathbf{I}\Big)$$

        :param x0: is $x_0$ (presumably `[batch_size, channels, height, width]`, like `x` elsewhere)
        :param index: is the time step index $i$ into $[\tau_1, \dots, \tau_S]$
        :param noise: is the noise $\epsilon$; standard normal noise shaped like `x0` is drawn if None
        """
        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample from $\mathcal{N}\Big(x_{\tau_i}; \sqrt{\alpha_{\tau_i}} x_0, (1-\alpha_{\tau_i}) \mathbf{I}\Big)$
        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

    @torch.no_grad()
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        r"""
        ### Painting loop

        :param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param t_start: is the sampling step to start from, $S'$
        :param orig: is the original image in latent space that is being in-painted.
            If this is not provided, it'll be an image-to-image transformation.
        :param mask: is the mask to keep the original image; required when `orig` is given
        :param orig_noise: is fixed noise to be added to the original image
        :param uncond_scale: is the unconditional guidance scale $s$; it is passed through to `get_eps`
        :param uncond_cond: is the conditional embedding for the empty prompt $c_u$
        :return: the result after the final painting step
        :raises ValueError: if `orig` is given without `mask`
        """
        # Blending below needs both tensors; fail before running the model rather than
        # with a TypeError after the first (expensive) denoising step.
        if orig is not None and mask is None:
            raise ValueError("mask must be provided when orig is given")

        # Get batch size
        bs = x.shape[0]

        # Time steps to sample at: $\tau_{S'}, \tau_{S' - 1}, \dots, \tau_1$
        time_steps = np.flip(self.time_steps[:t_start])

        for i, step in monit.enum('Paint', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            index = len(time_steps) - i - 1
            # Time step $\tau_i$, broadcast over the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

            # Replace the masked area with the original image
            if orig is not None:
                # Get the original image in latent space, noised to step index $i$.
                # NOTE(review): `x` now corresponds to $\tau_{i-1}$ while `orig_t` is noised
                # to $\tau_i$; confirm this one-step offset is intended.
                orig_t = self.q_sample(orig, index, noise=orig_noise)
                # Replace the masked area
                x = orig_t * mask + x * (1 - mask)

        return x