# 去噪扩散概率模型 (DDPM)

## 损失

### 简化损失

162from typing import Tuple, Optional
163
164import torch
165import torch.nn.functional as F
166import torch.utils.data
167from torch import nn
168
169from labml_nn.diffusion.ddpm.utils import gather

## 去噪扩散

class DenoiseDiffusion:
    """Denoising diffusion probabilistic model (DDPM).

    Holds the linear noise schedule, the noise-prediction model
    eps_theta(x_t, t), and implements the forward process q, the reverse
    sampling step p_theta, and the simplified training loss.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device,
                 beta_start: float = 0.0001, beta_end: float = 0.02):
        """
        :param eps_model: the eps_theta(x_t, t) model that predicts the noise in x_t
        :param n_steps: number of diffusion steps T
        :param device: the device on which the schedule constants are placed
        :param beta_start: first value of the linear variance schedule
            (default is the value that was previously hard-coded)
        :param beta_end: last value of the linear variance schedule
            (default is the value that was previously hard-coded)
        """
        super().__init__()
        self.eps_model = eps_model

        # Linearly increasing variance schedule beta_1 ... beta_T.
        self.beta = torch.linspace(beta_start, beta_end, n_steps).to(device)
        # alpha_t = 1 - beta_t
        self.alpha = 1. - self.beta
        # alpha_bar_t = prod_{s <= t} alpha_s (cumulative product over steps)
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # T
        self.n_steps = n_steps
        # Reverse-process variance sigma_t^2, chosen equal to beta_t.
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the q(x_t | x_0) distribution.

        q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)

        :param x0: clean input x_0
        :param t: step indices into the schedule
        :returns: ``(mean, var)`` of the Gaussian
        """
        # NOTE(review): `gather` (labml_nn utils) is assumed to select the
        # schedule value at step t for each sample and shape it so it
        # broadcasts against x0 — confirm in the helper.
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        var = 1 - gather(self.alpha_bar, t)
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """Sample x_t from q(x_t | x_0) via reparameterisation.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps

        :param x0: clean input x_0
        :param t: step indices into the schedule
        :param eps: noise to use; drawn from N(0, I) with x0's shape when ``None``
        """
        if eps is None:
            eps = torch.randn_like(x0)

        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """Sample x_{t-1} from p_theta(x_{t-1} | x_t).

        mean = 1 / sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
        var  = sigma_t^2 (= beta_t)

        Following DDPM Algorithm 2, no noise is added for samples whose step
        is t == 0 (the final denoising step), so the mean is returned for them.

        :param xt: current noisy sample x_t
        :param t: step indices; zero-based, so 0 is the last reverse step
        """
        eps_theta = self.eps_model(xt, t)
        alpha_bar = gather(self.alpha_bar, t)
        alpha = gather(self.alpha, t)
        # beta_t / sqrt(1 - alpha_bar_t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = gather(self.sigma2, t)

        # randn_like keeps xt's dtype as well as its device.
        eps = torch.randn_like(xt)
        # Per-sample mask: 0 where t == 0 (z = 0 on the final step), 1 otherwise.
        # Reshaped to (batch, 1, ..., 1) so it broadcasts against xt.
        nonzero_mask = (t != 0).to(xt.dtype).reshape(-1, *([1] * (xt.dim() - 1)))
        return mean + nonzero_mask * (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """Simplified DDPM loss: || eps - eps_theta(x_t, t) ||^2.

        A random step t is drawn independently for each sample in the batch.

        :param x0: batch of clean inputs (batch dimension first)
        :param noise: noise used to build x_t; drawn from N(0, I) when ``None``
        """
        batch_size = x0.shape[0]
        # One uniformly random step in [0, n_steps) per sample.
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)

        # Noisy sample x_t built with the known noise, so the model's target is `noise`.
        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_model(xt, t)

        # MSE loss
        return F.mse_loss(noise, eps_theta)