Denoising Diffusion Implicit Models (DDIM) Sampling

This implements DDIM sampling from the paper Denoising Diffusion Implicit Models

16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler

DDIM Sampler

This extends the DiffusionSampler base class.

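DDIM samples images by repeatedly removing noise, sampling step by step using

$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \Bigg( \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}} \Bigg) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i} \epsilon_{\tau_i}$$

where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and

$$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}} \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from DDPM.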

class DDIMSampler(DiffusionSampler):
    r"""
    ## DDIM Sampler

    This extends the `DiffusionSampler` base class.

    Every sampling step computes

    $$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \Bigg(
        \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}
        \Bigg)
        + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})
        + \sigma_{\tau_i} \epsilon_{\tau_i}$$

    where $\epsilon_{\tau_i}$ is random noise, $\tau$ is the list of time steps stored in
    `self.time_steps`, and

    $$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
        \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

    The $\alpha$ values are read from `model.alpha_bar`.
    """

    # Model used to predict the noise
    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
        r"""
        :param model: is the model to predict noise
        :param n_steps: is the number of DDIM sampling steps, $S$
        :param ddim_discretize: specifies how to extract $\tau$ from $[1,2,\dots,T]$.
            It can be either `uniform` or `quad`.
        :param ddim_eta: is $\eta$ used to calculate $\sigma_{\tau_i}$.
            $\eta = 0$ makes $\sigma_{\tau_i} = 0$, so no random noise is added while sampling.
        :raises ValueError: if `ddim_discretize` is `uniform` and `n_steps` is not in $[1, T]$
        :raises NotImplementedError: if `ddim_discretize` is neither `uniform` nor `quad`
        """
        super().__init__(model)
        # Number of steps of the underlying model, $T$
        self.n_steps = model.n_steps

        # Calculate $\tau$ to be uniformly distributed across $[1,2,\dots,T]$
        if ddim_discretize == 'uniform':
            # Outside this range the stride below is undefined (division by zero),
            # negative (empty schedule) or zero (invalid range step).
            if not 1 <= n_steps <= self.n_steps:
                raise ValueError(
                    f"n_steps must be between 1 and {self.n_steps} for 'uniform' discretization, got {n_steps}")
            # Stride between consecutive time steps. When $T$ is not divisible by `n_steps`
            # the schedule ends up with more than `n_steps` entries.
            c = self.n_steps // n_steps
            # NOTE(review): when the stride is 1 the last time step is $T$; that index is out of
            # range if `model.alpha_bar` has exactly $T$ entries - confirm against the model.
            self.time_steps = np.arange(0, self.n_steps, c) + 1
        # Calculate $\tau$ to be quadratically distributed across $[1,2,\dots,T]$
        elif ddim_discretize == 'quad':
            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
        else:
            raise NotImplementedError(ddim_discretize)

        with torch.no_grad():
            # Get ${\bar\alpha}_t$
            alpha_bar = self.model.alpha_bar

            # $\alpha_{\tau_i}$
            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
            # $\sqrt{\alpha_{\tau_i}}$
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # $\alpha_{\tau_{i-1}}$; the first entry has no predecessor in $\tau$, so `alpha_bar[0]` is used
            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

            # $$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}
            # \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
            self.ddim_sigma = (ddim_eta *
                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

            # $\sqrt{1 - \alpha_{\tau_i}}$
            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        r"""
        ### Sampling Loop

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: specifies whether the same noise should be used for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is the tensor to start denoising from. If not provided random noise
            of the given `shape` will be used.
        :param uncond_scale: is the unconditional guidance scale $s$; it is forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$; it is forwarded to `get_eps`
        :param skip_steps: is the number of time steps to skip $i'$. Sampling starts from
            step $S - i'$ of the schedule, and `x_last` is then $x_{\tau_{S - i'}}$.
        :return: the tensor after the last sampling step
        """

        # Get device
        device = self.model.device

        # Get the starting tensor
        x = x_last if x_last is not None else torch.randn(shape, device=device)
        # Take the batch size from the tensor that is actually denoised, so the time-step
        # tensor matches it even when `x_last` is given.
        bs = x.shape[0]

        # Time steps to sample at, from the largest to the smallest, after skipping
        time_steps = np.flip(self.time_steps)[skip_steps:]

        for i, step in monit.enum('Sample', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            index = len(time_steps) - i - 1
            # Time step $\tau_i$, repeated for every sample in the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$; the predicted $x_0$ and the noise estimate are not needed here
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    repeat_noise=repeat_noise,
                                    temperature=temperature,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

        # Return $x_0$
        return x

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
        r"""
        ### Sample $x_{\tau_{i-1}}$

        :param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $\tau_i$ of shape `[batch_size]`
        :param step: is the step $\tau_i$ as an integer (not used by this implementation)
        :param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
        :param repeat_noise: specifies whether the noise should be same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$; it is forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$; it is forwarded to `get_eps`
        :return: the tuple `(x_prev, pred_x0, e_t)`
        """

        # Get $\epsilon_\theta(x_{\tau_i})$
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # Calculate $x_{\tau_{i-1}}$ and predicted $x_0$
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature=temperature,
                                                      repeat_noise=repeat_noise)

        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                               temperature: float,
                               repeat_noise: bool):
        r"""
        ### Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$

        :param e_t: is the noise estimate $\epsilon_\theta(x_{\tau_i})$
        :param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
        :param x: is $x_{\tau_i}$
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param repeat_noise: specifies whether the noise should be same for all samples in the batch
        :return: the tuple `(x_prev, pred_x0)`
        """

        # $\alpha_{\tau_i}$
        alpha = self.ddim_alpha[index]
        # $\alpha_{\tau_{i-1}}$
        alpha_prev = self.ddim_alpha_prev[index]
        # $\sigma_{\tau_i}$
        sigma = self.ddim_sigma[index]
        # $\sqrt{1 - \alpha_{\tau_i}}$
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

        # Current prediction for $x_0$,
        # $$\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}$$
        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        # Direction pointing to $x_t$,
        # $$\sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$$
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

        # No noise is added when $\sigma_{\tau_i} = 0$ (i.e. $\eta = 0$)
        if sigma == 0.:
            noise = 0.
        # If same noise is used for all samples in the batch (broadcast over the batch dimension)
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # $$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \cdot \text{pred\_x0}
        # + \text{dir\_xt} + \sigma_{\tau_i} \epsilon_{\tau_i}$$
        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

        return x_prev, pred_x0

    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        r"""
        ### Sample a noised $x_{\tau_i}$ from $x_0$

        $$x_{\tau_i} = \sqrt{\alpha_{\tau_i}} x_0 + \sqrt{1 - \alpha_{\tau_i}} \epsilon$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $\tau_i$ index $i$
        :param noise: is the noise, $\epsilon$. Random noise is drawn if it is not specified.
        :return: the noised tensor $x_{\tau_i}$
        """

        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        # $\sqrt{\alpha_{\tau_i}} x_0 + \sqrt{1 - \alpha_{\tau_i}} \epsilon$
        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

    @torch.no_grad()
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        r"""
        ### Painting Loop

        :param x: is the tensor to start from, of shape `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param t_start: is the sampling step to start from; only the first `t_start` entries
            of the time-step schedule are used
        :param orig: is the original image in latent space which we are inpainting.
            If this is not provided, it'll be an image to image transformation.
        :param mask: is the mask to keep the original image; it is multiplied with the noised original,
            and `1 - mask` with the sample. Required when `orig` is given.
        :param orig_noise: is fixed noise to be added to the original image
        :param uncond_scale: is the unconditional guidance scale $s$; it is forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$; it is forwarded to `get_eps`
        :return: the tensor after the last painting step
        :raises ValueError: if `orig` is given without `mask`
        """
        # Fail before running the model rather than with a `TypeError` after the first step
        if orig is not None and mask is None:
            raise ValueError("mask must be provided when orig is given")

        # Get batch size
        bs = x.shape[0]

        # Time steps to sample at, from $\tau_{S'}$ down to $\tau_1$
        time_steps = np.flip(self.time_steps[:t_start])

        for i, step in monit.enum('Paint', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            index = len(time_steps) - i - 1
            # Time step $\tau_i$, repeated for every sample in the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

            # Replace the masked area with original image
            if orig is not None:
                # Get the noised version of the original image in latent space at this index
                orig_t = self.q_sample(orig, index, noise=orig_noise)
                # Replace the masked area
                x = orig_t * mask + x * (1 - mask)

        return x