这是《去噪扩散概率模型》论文的 PyTorch 实现/教程。
简而言之,我们从数据中获取图像并逐步添加噪点。然后,我们训练一个模型来预测每个步骤的噪声,并使用该模型生成图像。
以下定义和派生说明了其工作原理。详情请参阅论文。
在时间步长内,转发过程会给数据增加噪音。
方差计划在哪里。
我们可以随时采样,
在哪里和
相反的过程会从四个时间步长开始消除噪音。
是我们训练的参数。
我们根据负对数概率优化 ELBO(来自简森不等式)。
损失可以改写如下。
是恒定的,因为我们保持不变。
后验的前向过程是,
论文将其中设置为常量或.
然后,
对于给定的噪音,使用
这给了,
使用模型重新参数化以预测噪声
其中是预测给定值的学习函数。
这给了,
也就是说,我们正在训练预测噪音。
这样可以最大限度地减少放弃权重的时间和时间。丢弃权重会增加给出更高的权重(噪声等级更高),从而提高样本质量。
该文件实现了损失计算和基本采样方法,我们在训练期间使用该方法生成图像。
162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather
172class DenoiseDiffusion:
eps_model
是模特n_steps
是device
是用来放置常量的设备177 def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
183 super().__init__()
184 self.eps_model = eps_model
创建线性增加的差异计划
187 self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
190 self.alpha = 1. - self.beta
192 self.alpha_bar = torch.cumprod(self.alpha, dim=0)
194 self.n_steps = n_steps
196 self.sigma2 = self.beta
198 def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
210 var = 1 - gather(self.alpha_bar, t)
212 return mean, var
214 def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
224 if eps is None:
225 eps = torch.randn_like(x0)
得到
228 mean, var = self.q_xt_x0(x0, t)
样本来自
230 return mean + (var ** 0.5) * eps
232 def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
246 eps_theta = self.eps_model(xt, t)
250 alpha = gather(self.alpha, t)
252 eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
255 mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
257 var = gather(self.sigma2, t)
260 eps = torch.randn(xt.shape, device=xt.device)
样本
262 return mean + (var ** .5) * eps
264 def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
获取批次大小
273 batch_size = x0.shape[0]
批次中每个样本的随机获取
275 t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
278 if noise is None:
279 noise = torch.randn_like(x0)
的样本
282 xt = self.q_sample(x0, t, eps=noise)
得到
284 eps_theta = self.eps_model(xt, t)
MSE 亏损
287 return F.mse_loss(noise, eps_theta)