去噪扩散概率模型 (DDPM)

Open In Colab

这是《去噪扩散概率模型》论文的 PyTorch 实现/教程。

简而言之,我们从数据中获取图像并逐步添加噪点。然后,我们训练一个模型来预测每个步骤的噪声,并使用该模型生成图像。

以下定义和派生说明了其工作原理。详情请参阅论文

转发进程

时间步长内,转发过程会给数据增加噪音。

方差计划在哪里

我们可以随时采样

在哪里

反向处理

相反的过程会从四个时间步长开始消除噪音。

是我们训练的参数。

损失

我们根据负对数概率优化 ELBO(来自简森不等式)。

损失可以改写如下。

是恒定的,因为我们保持不变。

计算

后验的前向过程是,

论文将其中设置为常量.

然后,

对于给定的噪音,使用

这给了,

使用模型重新参数化以预测噪声

其中是预测给定值的学习函数

这给了,

也就是说,我们正在训练预测噪音。

简化损失

这样可以最大限度地减少弃权重的时间时间。丢弃权重会增加给出更高的权重(噪声等级更高),从而提高样本质量。

该文件实现了损失计算和基本采样方法,我们在训练期间使用该方法生成图像。

这是提供训练代码UNet 模型此文件可以从经过训练的模型生成样本和插值。

162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather

降噪扩散

172class DenoiseDiffusion:
  • eps_model模特
  • n_steps
  • device 是用来放置常量的设备
177    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
183        super().__init__()
184        self.eps_model = eps_model

创建线性增加的差异计划

187        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

190        self.alpha = 1. - self.beta

192        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

194        self.n_steps = n_steps

196        self.sigma2 = self.beta

获取分发

198    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

收集和计算

208        mean = gather(self.alpha_bar, t) ** 0.5 * x0

210        var = 1 - gather(self.alpha_bar, t)

212        return mean, var

样本来自

214    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):

224        if eps is None:
225            eps = torch.randn_like(x0)

得到

228        mean, var = self.q_xt_x0(x0, t)

样本来自

230        return mean + (var ** 0.5) * eps

样本来自

232    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):

246        eps_theta = self.eps_model(xt, t)

收集

248        alpha_bar = gather(self.alpha_bar, t)

250        alpha = gather(self.alpha, t)

252        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

255        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

257        var = gather(self.sigma2, t)

260        eps = torch.randn(xt.shape, device=xt.device)

样本

262        return mean + (var ** .5) * eps

简化损失

264    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):

获取批次大小

273        batch_size = x0.shape[0]

批次中每个样本的随机获取

275        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

278        if noise is None:
279            noise = torch.randn_like(x0)

的样本

282        xt = self.q_sample(x0, t, eps=noise)

得到

284        eps_theta = self.eps_model(xt, t)

MSE 亏损

287        return F.mse_loss(noise, eps_theta)