18from typing import Optional, List
19
20import torch
21
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
25class DiffusionSampler:
29 model: LatentDiffusion
model
是预测噪声的模型31 def __init__(self, model: LatentDiffusion):
35 super().__init__()
设置模型
37 self.model = model
获取模型训练的步数
39 self.n_steps = model.n_steps
x
是形状的[batch_size, channels, height, width]
t
是形状的[batch_size]
c
是形状的条件嵌入[batch_size, emb_size]
uncond_scale
是无条件指导量表。这用于uncond_cond
是空提示的条件嵌入41 def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
42 uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
当体重秤时
55 if uncond_cond is None or uncond_scale == 1.:
56 return self.model(x, t, c)
复制和
59 x_in = torch.cat([x] * 2)
60 t_in = torch.cat([t] * 2)
串联和
62 c_in = torch.cat([uncond_cond, c])
获取和
64 e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
计算
67 e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
70 return e_t
shape
是表单中生成的图像的形状[batch_size, channels, height, width]
cond
是条件嵌入temperature
是噪声温度(随机噪声乘以此值)x_last
是。如果未提供,将使用随机噪声。uncond_scale
是无条件指导量表。这用于uncond_cond
是空提示的条件嵌入skip_steps
是要跳过的时间步数。72 def sample(self,
73 shape: List[int],
74 cond: torch.Tensor,
75 repeat_noise: bool = False,
76 temperature: float = 1.,
77 x_last: Optional[torch.Tensor] = None,
78 uncond_scale: float = 1.,
79 uncond_cond: Optional[torch.Tensor] = None,
80 skip_steps: int = 0,
81 ):
95 raise NotImplementedError()
x
是形状的[batch_size, channels, height, width]
cond
是条件嵌入t_start
是开始时的采样步骤,orig
是我们正在绘制的潜在页面中的原始图像。mask
是保留原始图像的掩码。orig_noise
是要添加到原始图像的固定噪点。uncond_scale
是无条件指导量表。这用于uncond_cond
是空提示的条件嵌入97 def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
98 orig: Optional[torch.Tensor] = None,
99 mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
100 uncond_scale: float = 1.,
101 uncond_cond: Optional[torch.Tensor] = None,
102 ):
116 raise NotImplementedError()
118 def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
126 raise NotImplementedError()