This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.
In simple terms, we take an image from the data and add noise to it step by step. Then we train a model to predict that noise at each step and use the model to generate images.
The following definitions and derivations show how this works. For details, please refer to the paper.
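The forward process adds noise to the data $x_0 \sim q(x_0)$ for $T$ timesteps.
$$q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t \mathbf{I}\big)$$
$$q(x_{1:T} | x_0) = \prod_{t=1}^{T} q(x_t | x_{t-1})$$
where $\beta_1, \dots, \beta_T$ is the variance schedule.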
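The following is a minimal sketch of the forward process. It assumes the linear schedule $\beta_t$ from $10^{-4}$ to $0.02$ over $T = 1000$ steps that the code further down uses, and it applies $q(x_t|x_{t-1})$ one step at a time. `forward_step` and `forward_chain` are hypothetical helpers for illustration.

```python
import torch

# Assumed linear variance schedule beta_1..beta_T (same values as the class below)
T = 1000
beta = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)


def forward_step(x_prev: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I).

    `t` is a 0-based index into `beta`, so t = 0 corresponds to beta_1.
    """
    noise = torch.randn_like(x_prev)
    return (1 - beta[t]).sqrt() * x_prev + beta[t].sqrt() * noise


def forward_chain(x0: torch.Tensor, steps: int) -> torch.Tensor:
    """Apply `steps` forward steps starting from x_0 and return the final sample."""
    x = x0
    for t in range(steps):
        x = forward_step(x, t)
    return x


torch.manual_seed(0)
x0 = torch.ones(4, 3, 8, 8, dtype=torch.float64)  # stand-in for a batch of images
xT = forward_chain(x0, T)
# After T steps almost no signal remains, so mean and std should be near 0 and 1
print(xT.mean().item(), xT.std().item())
```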
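We can sample $x_t$ at any timestep $t$ with,
$$q(x_t|x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t}\, x_0, (1-\bar\alpha_t) \mathbf{I}\big)$$
where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$.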
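Because of this closed form, training never has to run the chain step by step. The sketch below checks it numerically under the same assumed linear schedule. It samples $x_t$ once by iterating $q(x_s|x_{s-1})$ and once directly from $q(x_t|x_0)$, then compares the empirical mean and variance with $\sqrt{\bar\alpha_t}\,x_0$ and $1-\bar\alpha_t$.

```python
import torch

torch.manual_seed(0)
T = 1000
beta = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)  # alpha_bar[t-1] = prod_{s=1}^{t} alpha_s

t = 200          # 1-based timestep
n = 200_000      # number of independent scalar samples
x0 = torch.full((n,), 1.5, dtype=torch.float64)

# Iterative: apply q(x_s | x_{s-1}) for s = 1..t
x_iter = x0.clone()
for s in range(t):
    x_iter = alpha[s].sqrt() * x_iter + beta[s].sqrt() * torch.randn(n, dtype=torch.float64)

# Closed form: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
ab = alpha_bar[t - 1]
x_direct = ab.sqrt() * x0 + (1 - ab).sqrt() * torch.randn(n, dtype=torch.float64)

expected_mean = (ab.sqrt() * 1.5).item()
expected_var = (1 - ab).item()
print(x_iter.mean().item(), x_direct.mean().item(), expected_mean)
print(x_iter.var().item(), x_direct.var().item(), expected_var)
```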
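The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$ for $T$ time steps.
$$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\big)$$
$$p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} | x_t)$$
$$p_\theta(x_0) = \int p_\theta(x_{0:T})\, dx_{1:T}$$
$\theta$ are the parameters we train.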
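We minimize the variational bound on the negative log likelihood, that is, the negative ELBO obtained from Jensen's inequality.
$$\mathbb{E}\big[-\log p_\theta(x_0)\big] \le \mathbb{E}_q\Big[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\Big] = L$$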
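The loss can be rewritten as follows.
$$
\begin{aligned}
L &= \mathbb{E}_q\Big[-\log p(x_T) - \sum_{t=1}^{T} \log \frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\Big] \\
  &= \mathbb{E}_q\Big[D_{KL}\big(q(x_T|x_0) \,\Vert\, p(x_T)\big) + \sum_{t=2}^{T} D_{KL}\big(q(x_{t-1}|x_t,x_0) \,\Vert\, p_\theta(x_{t-1}|x_t)\big) - \log p_\theta(x_0|x_1)\Big]
\end{aligned}
$$
$D_{KL}\big(q(x_T|x_0) \,\Vert\, p(x_T)\big)$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.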
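Every term in the rewritten loss except the last is a KL divergence between two Gaussians, so each has a closed form. The sketch below covers the isotropic case $\mathcal{N}(\mu_1, \sigma_1^2 \mathbf{I})$ and $\mathcal{N}(\mu_2, \sigma_2^2 \mathbf{I})$ in $d$ dimensions. `gaussian_kl` is a hypothetical helper and is not part of the implementation.

```python
import math
import torch


def gaussian_kl(mu1: torch.Tensor, var1: float, mu2: torch.Tensor, var2: float) -> torch.Tensor:
    """KL( N(mu1, var1 I) || N(mu2, var2 I) ) over the last dimension of size d.

    Closed form for isotropic Gaussians:
        0.5 * [ d * log(var2 / var1) + (d * var1 + ||mu1 - mu2||^2) / var2 - d ]
    """
    d = mu1.shape[-1]
    sq_dist = ((mu1 - mu2) ** 2).sum(dim=-1)
    return 0.5 * (d * math.log(var2 / var1) + (d * var1 + sq_dist) / var2 - d)


mu = torch.zeros(3)
print(gaussian_kl(mu, 1.0, mu, 1.0))                          # identical distributions
print(gaussian_kl(torch.zeros(1), 1.0, torch.ones(1), 1.0))   # unit mean shift, unit variance
```

When the variances are fixed constants, only the $\lVert\mu_1-\mu_2\rVert^2/(2\sigma_2^2)$ term depends on the means. With $\sigma_2^2 = \sigma_t^2$, this is the form that $L_{t-1}$ takes in the derivation that follows.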
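The forward process posterior conditioned on $x_0$ is,
$$q(x_{t-1}|x_t, x_0) = \mathcal{N}\big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I}\big)$$
$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\, x_t$$
$$\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\, \beta_t$$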
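The posterior coefficients depend only on the schedule, so they can be precomputed for every $t$. The sketch below assumes the same linear schedule and the convention $\bar\alpha_0 = 1$. `posterior_params` is a hypothetical helper.

```python
import torch

T = 1000
beta = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
# alpha_bar_prev[t-1] = alpha_bar_{t-1}, with the convention alpha_bar_0 = 1
alpha_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), alpha_bar[:-1]])

# Coefficients of x_0 and x_t in tilde_mu_t(x_t, x_0)
coef_x0 = alpha_bar_prev.sqrt() * beta / (1 - alpha_bar)
coef_xt = alpha.sqrt() * (1 - alpha_bar_prev) / (1 - alpha_bar)
# Posterior variance tilde_beta_t
beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * beta


def posterior_params(x0: torch.Tensor, xt: torch.Tensor, t: int):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for a 1-based timestep t."""
    i = t - 1
    mean = coef_x0[i] * x0 + coef_xt[i] * xt
    return mean, beta_tilde[i]
```

Since $\bar\alpha_0 = 1$, this gives $\tilde\beta_1 = 0$ and $\tilde\mu_1 = x_0$. Also $\tilde\beta_t \le \beta_t$ for every $t$. The array `beta_tilde` corresponds to the $\tilde\beta_t$ option for $\sigma_t^2$ mentioned next; the code below uses $\sigma_t^2 = \beta_t$.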
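The paper sets $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$, where $\sigma_t^2$ is set to the constant $\beta_t$ or $\tilde\beta_t$.
Then,
$$p_\theta(x_{t-1}|x_t) = \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 \mathbf{I}\big)$$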
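For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, using $q(x_t|x_0)$,
$$x_t(x_0, \epsilon) = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon$$
$$x_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\, \epsilon\Big)$$
This gives,
$$
\begin{aligned}
L_{t-1} &= \mathbb{E}_q\Big[D_{KL}\big(q(x_{t-1}|x_t,x_0) \,\Vert\, p_\theta(x_{t-1}|x_t)\big)\Big] \\
&= \mathbb{E}_q\Big[\frac{1}{2\sigma_t^2} \big\Vert \tilde\mu_t(x_t, x_0) - \mu_\theta(x_t, t) \big\Vert^2\Big] + C \\
&= \mathbb{E}_{x_0, \epsilon}\Big[\frac{1}{2\sigma_t^2} \Big\Vert \frac{1}{\sqrt{\alpha_t}} \Big(x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon\Big) - \mu_\theta\big(x_t(x_0, \epsilon), t\big) \Big\Vert^2\Big] + C
\end{aligned}
$$
where $C$ is a constant that does not depend on $\theta$.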
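The substitution can be checked numerically. The sketch below draws $x_0$ and $\epsilon$ and forms $x_t(x_0,\epsilon)$. It then confirms two things: $x_0$ is recovered, and $\tilde\mu_t(x_t, x_0)$ equals $\frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon\big)$. It uses the same assumed linear schedule and an arbitrary timestep.

```python
import torch

torch.manual_seed(0)
T = 1000
beta = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

t = 500                      # 1-based timestep, so index t - 1 into the arrays
b, a, ab = beta[t - 1], alpha[t - 1], alpha_bar[t - 1]
ab_prev = alpha_bar[t - 2]   # alpha_bar_{t-1}; valid because t > 1

x0 = torch.randn(16, dtype=torch.float64)
eps = torch.randn(16, dtype=torch.float64)

# x_t(x_0, eps) = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps

# Invert it to recover x_0
x0_rec = (xt - (1 - ab).sqrt() * eps) / ab.sqrt()

# Posterior mean tilde_mu_t(x_t, x_0) from its definition
mu_tilde = ab_prev.sqrt() * b / (1 - ab) * x0 + a.sqrt() * (1 - ab_prev) / (1 - ab) * xt

# The same mean written in terms of x_t and eps only
mu_eps = (xt - b / (1 - ab).sqrt() * eps) / a.sqrt()

print(torch.allclose(x0_rec, x0), torch.allclose(mu_tilde, mu_eps))  # True True
```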
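Re-parameterizing with a model to predict noise,
$$\mu_\theta(x_t, t) = \tilde\mu_t\Big(x_t, \frac{1}{\sqrt{\bar\alpha_t}}\big(x_t - \sqrt{1-\bar\alpha_t}\, \epsilon_\theta(x_t, t)\big)\Big) = \frac{1}{\sqrt{\alpha_t}} \Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, \epsilon_\theta(x_t, t)\Big)$$
where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
This gives, up to a constant $C$ that does not depend on $\theta$,
$$L_{t-1} - C = \mathbb{E}_{x_0, \epsilon}\Bigg[\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)} \Big\Vert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon, t\big) \Big\Vert^2\Bigg]$$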
That is, we are training to predict the noise.
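To see what the simplification discards, the sketch below computes the per-timestep weight $\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)}$. It assumes the linear schedule and $\sigma_t^2 = \beta_t$, which is the choice the code below makes.

```python
import torch

T = 1000
beta = torch.linspace(0.0001, 0.02, T, dtype=torch.float64)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
sigma2 = beta  # sigma_t^2 = beta_t

# Weight in front of ||eps - eps_theta||^2 in L_{t-1}; index i corresponds to t = i + 1
weight = beta ** 2 / (2 * sigma2 * alpha * (1 - alpha_bar))

print(weight[0].item(), weight[T // 2].item(), weight[-1].item())
```

At $t = 1$ we have $1 - \bar\alpha_1 = \beta_1$, so the weight is $1/(2\alpha_1) \approx 0.5$. For large $t$ it is tens of times smaller. Setting every weight to 1 therefore shifts emphasis toward the noisier timesteps.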
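Simplified loss:
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\Big[\big\Vert \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon, t\big) \big\Vert^2\Big]$$
where $t$ is sampled uniformly from $1, \dots, T$. This minimizes $-\log p_\theta(x_0|x_1)$ when $t = 1$ and $L_{t-1}$ for $t > 1$, discarding the weighting in $L_{t-1}$. Discarding the weights increases the relative weight given to larger $t$, which have higher noise levels and are harder denoising tasks. The paper reports that this improves sample quality.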
This file implements the loss calculation and a basic sampling method that we use to generate images during training. The UNet model that gives $\epsilon_\theta(x_t, t)$ and the training code are in separate files. A further script generates samples and interpolations from a trained model.
```python
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather
```
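`gather` comes from the labml_nn package and is not shown here. The stand-in sketch below assumes that it picks the constant for each sample's timestep. It also assumes the result is reshaped to `(batch, 1, 1, 1)` so that it broadcasts over image tensors of shape `(batch, channels, height, width)`. `gather_sketch` is a hypothetical name.

```python
import torch


def gather_sketch(consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Pick consts[t[i]] for every sample i and reshape for broadcasting over images."""
    c = consts.gather(-1, t)          # shape: (batch,)
    return c.reshape(-1, 1, 1, 1)     # shape: (batch, 1, 1, 1)


alpha_bar = torch.cumprod(1.0 - torch.linspace(0.0001, 0.02, 1000), dim=0)
t = torch.tensor([0, 10, 999])        # one 0-based timestep per sample
x = torch.randn(3, 3, 32, 32)
# Each image is scaled by its own sqrt(alpha_bar_t)
scaled = gather_sketch(alpha_bar, t) ** 0.5 * x
print(gather_sketch(alpha_bar, t).shape)  # torch.Size([3, 1, 1, 1])
```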
```python
class DenoiseDiffusion:
    """Denoise Diffusion"""

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        """
        * `eps_model` is the eps_theta(x_t, t) model
        * `n_steps` is the number of diffusion steps T
        * `device` is the device to place constants on
        """
        super().__init__()
        self.eps_model = eps_model

        # Create linearly increasing variance schedule beta_1, ..., beta_T
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        # alpha_t = 1 - beta_t
        self.alpha = 1. - self.beta
        # alpha_bar_t = prod_{s=1}^{t} alpha_s
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # T
        self.n_steps = n_steps
        # sigma_t^2 = beta_t
        self.sigma2 = self.beta
```
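The sketch below shows what this schedule implies. It uses the same constants with $T = 1000$, the value used in the paper (here `n_steps` is a constructor argument). $\bar\alpha_t$ falls from nearly 1 to nearly 0, so $x_T$ keeps almost no signal from $x_0$.

```python
import torch

n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)

# Signal and noise scales of x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps
signal_scale = alpha_bar.sqrt()
noise_scale = (1 - alpha_bar).sqrt()

for t in (1, 250, 500, 750, 1000):
    i = t - 1
    print(f"t={t:4d}  sqrt(alpha_bar)={signal_scale[i].item():.4f}  "
          f"sqrt(1-alpha_bar)={noise_scale[i].item():.4f}")
```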
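```python
    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get q(x_t | x_0): mean sqrt(alpha_bar_t) x_0 and variance (1 - alpha_bar_t)"""
        # sqrt(alpha_bar_t) x_0
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        # 1 - alpha_bar_t
        var = 1 - gather(self.alpha_bar, t)
        return mean, var
```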
```python
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """Sample from q(x_t | x_0)"""
        # eps ~ N(0, I)
        if eps is None:
            eps = torch.randn_like(x0)
        # Get q(x_t | x_0)
        mean, var = self.q_xt_x0(x0, t)
        # Sample from q(x_t | x_0): x_t = mean + sqrt(var) * eps
        return mean + (var ** 0.5) * eps
```
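```python
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """Sample from p_theta(x_{t-1} | x_t)"""
        # eps_theta(x_t, t)
        eps_theta = self.eps_model(xt, t)
        # alpha_bar_t
        alpha_bar = gather(self.alpha_bar, t)
        # alpha_t
        alpha = gather(self.alpha, t)
        # beta_t / sqrt(1 - alpha_bar_t), using beta_t = 1 - alpha_t
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # mu_theta(x_t, t) = 1 / sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta)
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # sigma_t^2
        var = gather(self.sigma2, t)
        # eps ~ N(0, I)
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample x_{t-1} = mu_theta + sigma_t * eps
        return mean + (var ** .5) * eps
```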
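The method above performs a single reverse step. A full sampler starts from $x_T \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and applies the same update for $t = T, \dots, 1$, as in the standalone sketch below. `ZeroEps` is a hypothetical placeholder for a trained model. Unlike `p_sample` as written, this sketch adds no noise at the final step, following Algorithm 2 of the paper.

```python
import torch
from torch import nn


class ZeroEps(nn.Module):
    """Hypothetical placeholder for a trained eps_theta(x_t, t); always predicts zero noise."""

    def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(xt)


@torch.no_grad()
def sample(eps_model: nn.Module, shape, n_steps: int = 1000, device: str = "cpu") -> torch.Tensor:
    """Ancestral sampling: start at x_T ~ N(0, I) and apply p_theta(x_{t-1} | x_t) n_steps times."""
    beta = torch.linspace(0.0001, 0.02, n_steps, device=device)
    alpha = 1. - beta
    alpha_bar = torch.cumprod(alpha, dim=0)
    sigma2 = beta

    x = torch.randn(shape, device=device)  # x_T
    for i in reversed(range(n_steps)):     # 0-based index i = t - 1, from T - 1 down to 0
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        eps_theta = eps_model(x, t)
        eps_coef = (1 - alpha[i]) / (1 - alpha_bar[i]) ** .5
        mean = 1 / (alpha[i] ** 0.5) * (x - eps_coef * eps_theta)
        # No noise at the last step (t = 1), as in Algorithm 2 of the paper
        noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = mean + (sigma2[i] ** .5) * noise
    return x


x0 = sample(ZeroEps(), (2, 3, 8, 8), n_steps=100)
print(x0.shape)  # torch.Size([2, 3, 8, 8])
```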
```python
    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """Simplified loss L_simple"""
        # Get batch size
        batch_size = x0.shape[0]
        # Get random t for each sample in the batch (0-based, so index 0 is the paper's t = 1)
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        # eps ~ N(0, I)
        if noise is None:
            noise = torch.randn_like(x0)
        # Sample x_t from q(x_t | x_0)
        xt = self.q_sample(x0, t, eps=noise)
        # Get eps_theta(sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps, t)
        eps_theta = self.eps_model(xt, t)
        # MSE loss
        return F.mse_loss(noise, eps_theta)
```
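Apart from `loss`, a training step needs only a model and an optimizer. The self-contained sketch below runs that loop on toy 2-D data. `TinyEps`, `simple_loss`, the shortened schedule (`n_steps = 100`) and the dataset are all hypothetical stand-ins for the UNet and image data that the real training code uses. `simple_loss` repeats the steps of `loss`, broadcasting over `(batch, 2)` instead of image tensors.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
n_steps = 100
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha_bar = torch.cumprod(1. - beta, dim=0)


class TinyEps(nn.Module):
    """Hypothetical toy eps_theta for 2-D points: an MLP on (x_t, t / n_steps)."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, hidden), nn.SiLU(), nn.Linear(hidden, 2))

    def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_feat = (t.float() / n_steps).unsqueeze(-1)
        return self.net(torch.cat([xt, t_feat], dim=-1))


def simple_loss(model: nn.Module, x0: torch.Tensor) -> torch.Tensor:
    """L_simple for inputs of shape (batch, 2), mirroring DenoiseDiffusion.loss."""
    t = torch.randint(0, n_steps, (x0.shape[0],))
    noise = torch.randn_like(x0)
    ab = alpha_bar[t].unsqueeze(-1)  # (batch, 1) broadcasts over the 2 coordinates
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
    return F.mse_loss(noise, model(xt, t))


model = TinyEps()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
data = torch.randn(512, 2) * 0.1 + torch.tensor([1.0, -1.0])  # toy 2-D dataset

for step in range(200):
    loss = simple_loss(model, data)
    opt.zero_grad()
    loss.backward()
    opt.step()
print(loss.item())
```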