Denoising Diffusion Probabilistic Models (DDPM)


This is a PyTorch implementation/tutorial of the paper Denoising Diffusion Probabilistic Models.

In simple terms, we take an image from the data and add noise to it step by step. We then train a model to predict that noise at each step, and use the model to generate images.

The following definitions and derivations show how this works. For details please refer to the paper.

Forward Process

The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.

$$q(x_t | x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1 - \beta_t}\, x_{t-1},\ \beta_t \mathbf{I}\big) \qquad q(x_{1:T} | x_0) = \prod_{t=1}^{T} q(x_t | x_{t-1})$$

where $\beta_1, \dots, \beta_T$ is the variance schedule.

We can sample $x_t$ at any timestep $t$ with,

$$q(x_t | x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) \mathbf{I}\big)$$

where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$.
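As a concrete illustration, here is a minimal sketch (assuming the linear $\beta$ schedule used later in this file) of sampling $x_t$ directly from $x_0$ with this closed form; DenoiseDiffusion.q_sample below does the same thing with per-sample timesteps.

import torch

# Linear variance schedule (assumption: same constants as the class below)
n_steps = 1_000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)

# Sample x_t for a single timestep t from a stand-in x_0
x0 = torch.randn(1, 3, 32, 32)
t = 500
eps = torch.randn_like(x0)
xt = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps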

Reverse Process

The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T;\ \mathbf{0}, \mathbf{I})$ for $T$ time steps.

$$p_\theta(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big) \qquad p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} | x_t)$$

$\theta$ are the parameters we train.
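A minimal sketch of ancestral sampling with this reverse process, assuming a hypothetical mu_model that predicts $\mu_\theta(x_t, t)$ and a fixed variance $\sigma_t^2$ (the class below uses the equivalent noise-prediction parameterization derived later):

import torch
from torch import nn

def reverse_sample(mu_model: nn.Module, sigma2: torch.Tensor, shape, n_steps: int):
    # Start from x_T ~ N(0, I)
    xt = torch.randn(shape)
    for t in reversed(range(n_steps)):
        # x_{t-1} ~ N(mu_theta(x_t, t), sigma_t^2 I)
        mean = mu_model(xt, torch.full((shape[0],), t, dtype=torch.long))
        eps = torch.randn_like(xt)
        xt = mean + sigma2[t].sqrt() * eps
    return xt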

Loss

We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.

$$\mathbb{E}\big[-\log p_\theta(x_0)\big] \le \mathbb{E}_q\Big[ -\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} | x_0)} \Big] = L$$

The loss can be rewritten as follows.

$$L = \mathbb{E}_q\Big[ D_{KL}\big(q(x_T | x_0) \,\|\, p(x_T)\big) + \sum_{t > 1} D_{KL}\big(q(x_{t-1} | x_t, x_0) \,\|\, p_\theta(x_{t-1} | x_t)\big) - \log p_\theta(x_0 | x_1) \Big]$$

$D_{KL}\big(q(x_T | x_0) \,\|\, p(x_T)\big)$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.

Computing $L_{t-1} = D_{KL}\big(q(x_{t-1} | x_t, x_0) \,\|\, p_\theta(x_{t-1} | x_t)\big)$

The forward process posterior conditioned on $x_0$ is,

$$q(x_{t-1} | x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t \mathbf{I}\big)$$

$$\tilde\mu_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\, \beta_t}{1 - \bar\alpha_t} x_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t \qquad \tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$$

The paper sets $\Sigma_\theta(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to the constants $\beta_t$ or $\tilde\beta_t$.

Then,

$$L_{t-1} = \mathbb{E}_q\Big[ \frac{1}{2\sigma_t^2} \big\| \tilde\mu_t(x_t, x_0) - \mu_\theta(x_t, t) \big\|^2 \Big] + C$$

For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, using $q(x_t | x_0)$

$$x_t(x_0, \epsilon) = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon \qquad x_0 = \frac{1}{\sqrt{\bar\alpha_t}} \Big( x_t(x_0, \epsilon) - \sqrt{1 - \bar\alpha_t}\, \epsilon \Big)$$

This gives,

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\bigg[ \frac{1}{2\sigma_t^2} \Big\| \frac{1}{\sqrt{\alpha_t}} \Big( x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon \Big) - \mu_\theta\big(x_t(x_0, \epsilon), t\big) \Big\|^2 \bigg]$$
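For reference, a small sketch of the forward-process posterior $q(x_{t-1} | x_t, x_0)$ defined above, assuming the same constants as the class below; alpha_bar_prev and q_posterior are illustrative names, not part of this file.

import torch

n_steps = 1_000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
# \bar\alpha_{t-1}, with \bar\alpha_0 = 1
alpha_bar_prev = torch.cat([torch.ones(1), alpha_bar[:-1]])

def q_posterior(x0: torch.Tensor, xt: torch.Tensor, t: int):
    # Mean and variance of q(x_{t-1} | x_t, x_0) for a single integer t
    mean = (alpha_bar_prev[t].sqrt() * beta[t] / (1 - alpha_bar[t])) * x0 \
         + (alpha[t].sqrt() * (1 - alpha_bar_prev[t]) / (1 - alpha_bar[t])) * xt
    var = (1 - alpha_bar_prev[t]) / (1 - alpha_bar[t]) * beta[t]
    return mean, var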

Re-parameterizing with a model to predict noise

$$\mu_\theta(x_t, t) = \tilde\mu_t\Big( x_t,\ \frac{1}{\sqrt{\bar\alpha_t}} \big( x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t) \big) \Big) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_\theta(x_t, t) \Big)$$

where $\epsilon_\theta(x_t, t)$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.

This gives,

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\bigg[ \frac{\beta_t^2}{2 \sigma_t^2 \alpha_t (1 - \bar\alpha_t)} \big\| \epsilon - \epsilon_\theta\big( \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\ t \big) \big\|^2 \bigg]$$

That is, we are training $\epsilon_\theta$ to predict the noise.

Simplified loss

$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\Big[ \big\| \epsilon - \epsilon_\theta\big( \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\ t \big) \big\|^2 \Big]$$

This minimizes $-\log p_\theta(x_0 | x_1)$ when $t = 1$, and $L_{t-1}$ for $t > 1$, while discarding the weighting $\frac{\beta_t^2}{2 \sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$ in $L_{t-1}$. Discarding the weights increases the relative weight given to higher $t$ (which have higher noise levels), which improves sample quality.
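To see the effect of dropping the weighting term, here is a small sketch (with $\sigma_t^2 = \beta_t$ and the linear schedule used below) that evaluates $\frac{\beta_t^2}{2 \sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$ at a few timesteps; the weight is far larger at small $t$ than at large $t$, so removing it shifts relative emphasis toward the noisier steps.

import torch

n_steps = 1_000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
sigma2 = beta

weight = beta ** 2 / (2 * sigma2 * alpha * (1 - alpha_bar))
print(weight[0].item(), weight[n_steps // 2].item(), weight[-1].item())
# roughly 0.5 at the first step, on the order of 0.005-0.01 at later steps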

This file implements the loss calculation and a basic sampling method that we use to generate images during training.

Here is the UNet model that gives $\epsilon_\theta(x_t, t)$ and the training code. This file can generate samples and interpolations from a trained model.

from typing import Tuple, Optional

import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

from labml_nn.diffusion.ddpm.utils import gather

Denoise Diffusion

class DenoiseDiffusion:

  • eps_model is the $\epsilon_\theta(x_t, t)$ model
  • n_steps is the number of time steps $T$
  • device is the device to place constants on

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        super().__init__()
        self.eps_model = eps_model

Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule

        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)

        # $\alpha_t = 1 - \beta_t$
        self.alpha = 1. - self.beta

        # $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

        # Total number of time steps $T$
        self.n_steps = n_steps

        # $\sigma_t^2 = \beta_t$
        self.sigma2 = self.beta

Get $q(x_t|x_0)$ distribution

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

gather $\bar\alpha_t$ and compute $\sqrt{\bar\alpha_t}\, x_0$

        mean = gather(self.alpha_bar, t) ** 0.5 * x0

        # variance $(1 - \bar\alpha_t) \mathbf{I}$
        var = 1 - gather(self.alpha_bar, t)

        return mean, var

Sample from $q(x_t|x_0)$

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ if noise is not given
        if eps is None:
            eps = torch.randn_like(x0)

get $q(x_t|x_0)$

        mean, var = self.q_xt_x0(x0, t)

Sample from $q(x_t|x_0)$

        return mean + (var ** 0.5) * eps

Sample from $p_\theta(x_{t-1}|x_t)$

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):

        # $\epsilon_\theta(x_t, t)$
        eps_theta = self.eps_model(xt, t)

gather $\bar\alpha_t$

        alpha_bar = gather(self.alpha_bar, t)

        # $\alpha_t$
        alpha = gather(self.alpha, t)

        # $\frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5

        # $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon_\theta(x_t, t) \Big)$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)

        # $\sigma^2$
        var = gather(self.sigma2, t)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)

Sample

        return mean + (var ** .5) * eps

Simplified Loss

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):

Get batch size

        batch_size = x0.shape[0]

Get random $t$ for each sample in the batch

        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)

        # Sample $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ if noise is not given
        if noise is None:
            noise = torch.randn_like(x0)

Sample $x_t$ from $q(x_t|x_0)$

        xt = self.q_sample(x0, t, eps=noise)

Get $\epsilon_\theta\big(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, t\big)$

        eps_theta = self.eps_model(xt, t)

MSE loss

        return F.mse_loss(noise, eps_theta)
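A hypothetical usage sketch (the tiny DummyEps stand-in, batch shapes, and optimizer settings are placeholders, not part of this file; it reuses the imports above): one training step with the simplified loss, followed by ancestral sampling with p_sample.

class DummyEps(nn.Module):
    # Placeholder for the UNet eps_model: maps (x_t, t) -> predicted noise
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        return self.conv(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
eps_model = DummyEps().to(device)
diffusion = DenoiseDiffusion(eps_model=eps_model, n_steps=1_000, device=device)
optimizer = torch.optim.Adam(eps_model.parameters(), lr=2e-4)

# One training step on a random batch standing in for data
x0 = torch.randn(8, 3, 32, 32, device=device)
optimizer.zero_grad()
loss = diffusion.loss(x0)
loss.backward()
optimizer.step()

# Generate samples: start from x_T ~ N(0, I) and apply p_sample for t = T-1, ..., 0
with torch.no_grad():
    x = torch.randn(4, 3, 32, 32, device=device)
    for t_ in reversed(range(diffusion.n_steps)):
        t = x.new_full((x.shape[0],), t_, dtype=torch.long)
        x = diffusion.p_sample(x, t)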