潜在扩散模型

潜在扩散模型使用自动编码器在图像空间和潜在空间之间进行映射。扩散模型适用于潜在空间,这使得训练变得容易得多。它基于带有潜在扩散模型的纸质高分辨率图像合成

它们使用预训练的自动编码器,在预训练的自动编码器的潜在空间上训练扩散 U-Net。

有关更简单的扩散实现,请参阅我们的 DDPM 实现。我们对时间表等使用相同的符号。

24from typing import List
25
26import torch
27import torch.nn as nn
28
29from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
30from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
31from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel

这是围绕 U-Net 的空包装类。我们保持它与 compVIS/Stable- Difusion 相同的模型结构,这样我们就不必明确地映射检查点权重

34class DiffusionWrapper(nn.Module):
42    def __init__(self, diffusion_model: UNetModel):
43        super().__init__()
44        self.diffusion_model = diffusion_model
46    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
47        return self.diffusion_model(x, time_steps, context)

潜在扩散模型

它包含以下组件:

50class LatentDiffusion(nn.Module):
60    model: DiffusionWrapper
61    first_stage_model: Autoencoder
62    cond_stage_model: CLIPTextEmbedder
  • unet_model 是预测潜在空间中噪声U-Ne t
  • autoencoder自动编码器
  • clip_embedderCLIP 嵌入生成器
  • latent_scaling_factor 是潜在空间的缩放系数。在馈入 U-Net 之前,自动编码器的编码会按此进行缩放。
  • n_steps 是扩散步骤的数量
  • linear_start时间表的开始。
  • linear_end时间表的结束。
64    def __init__(self,
65                 unet_model: UNetModel,
66                 autoencoder: Autoencoder,
67                 clip_embedder: CLIPTextEmbedder,
68                 latent_scaling_factor: float,
69                 n_steps: int,
70                 linear_start: float,
71                 linear_end: float,
72                 ):
84        super().__init__()

封装 U-Net 以保持与 compVIS/Stable- Difusion 相同的模型结构。

87        self.model = DiffusionWrapper(unet_model)

自动编码器和缩放系数

89        self.first_stage_model = autoencoder
90        self.latent_scaling_factor = latent_scaling_factor
92        self.cond_stage_model = clip_embedder

步数

95        self.n_steps = n_steps

时间表

98        beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
99        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)

101        alpha = 1. - beta

103        alpha_bar = torch.cumprod(alpha, dim=0)
104        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)

获取设备模型

106    @property
107    def device(self):
111        return next(iter(self.model.parameters())).device

获取 CLIP 嵌入以获取文本提示列表

113    def get_text_conditioning(self, prompts: List[str]):
117        return self.cond_stage_model(prompts)

获取图像的缩放潜在空间表示

编码器输出是分布式。我们从中取样并乘以缩放系数。

119    def autoencoder_encode(self, image: torch.Tensor):
126        return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()

从潜在表示中获取图像

我们按缩放系数向下缩放,然后解码。

128    def autoencoder_decode(self, z: torch.Tensor):
134        return self.first_stage_model.decode(z / self.latent_scaling_factor)

预测噪音

根据潜在表示、时间步长和条件环境预测噪声

136    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
145        return self.model(x, t, context)