This implements DDIM sampling from the paper Denoising Diffusion Implicit Models
16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSamplerThis extends the DiffusionSampler
 base class.
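DDIM samples images by repeatedly removing noise, sampling step by step using

$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \Bigg( \frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \, \epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}} \Bigg) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i} \epsilon_{\tau_i}$$

where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and

$$\sigma_{\tau_i} = \eta \sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}} \sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$

Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from DDPM.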
class DDIMSampler(DiffusionSampler):
    """
    DDIM sampler, extending the `DiffusionSampler` base class.

    A sub-sequence `tau` of the model's time steps is selected at construction
    time and the matching schedules are precomputed. Each sampling step then
    computes

        x_prev = sqrt(alpha_prev) * pred_x0
                 + sqrt(1 - alpha_prev - sigma^2) * eps
                 + sigma * noise

    with

        pred_x0 = (x - sqrt(1 - alpha) * eps) / sqrt(alpha)
        sigma   = eta * sqrt((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev))

    where `alpha` is read from the model's `alpha_bar` at the selected steps.
    """

    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
        """
        :param model: is the model to predict noise
        :param n_steps: is the number of DDIM sampling steps, `S`
        :param ddim_discretize: specifies how to extract `tau` from the model's time steps.
            It can be either `uniform` or `quad`.
        :param ddim_eta: is `eta`, used to calculate `sigma`. `eta = 0` gives `sigma = 0`,
            so no noise is added during sampling (deterministic).
        :raises ValueError: if `n_steps` is less than 1, or (for `uniform`) larger than
            the number of model time steps
        :raises NotImplementedError: if `ddim_discretize` is not `uniform` or `quad`
        """
        super().__init__(model)
        # Number of steps of the underlying model, `T`
        self.n_steps = model.n_steps

        if n_steps < 1:
            raise ValueError(f"n_steps must be at least 1, got {n_steps}")

        # Calculate `tau` to be uniformly distributed across the model's time steps
        if ddim_discretize == 'uniform':
            # A stride of zero would make the range below invalid
            if n_steps > self.n_steps:
                raise ValueError(f"n_steps ({n_steps}) cannot exceed the model's "
                                 f"number of time steps ({self.n_steps})")
            c = self.n_steps // n_steps
            self.time_steps = np.arange(0, self.n_steps, c) + 1
        # Calculate `tau` to be quadratically distributed across the model's time steps
        elif ddim_discretize == 'quad':
            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
        else:
            raise NotImplementedError(ddim_discretize)

        with torch.no_grad():
            # Cumulative alpha schedule of the model
            alpha_bar = self.model.alpha_bar

            # alpha at each selected step, `alpha_{tau_i}`
            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
            # `sqrt(alpha_{tau_i})`
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # alpha at the previous selected step, `alpha_{tau_{i-1}}`;
            # the first entry falls back to `alpha_bar[0]`
            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

            # `sigma_{tau_i} = eta * sqrt((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev))`
            self.ddim_sigma = (ddim_eta *
                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

            # `sqrt(1 - alpha_{tau_i})`
            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        """
        Run the sampling loop and return the final sample.

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings
        :param repeat_noise: specifies whether the noise should be the same for all
            samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is the starting sample. If not provided random noise will be used.
        :param uncond_scale: is the unconditional guidance scale (forwarded to `get_eps`)
        :param uncond_cond: is the conditional embedding for empty prompt
        :param skip_steps: is the number of time steps to skip, `i'`. Sampling starts from
            step `S - i'`, and `x_last` is then the sample at that step.
        """
        # Get device and batch size
        device = self.model.device
        bs = shape[0]

        # Starting sample
        x = x_last if x_last is not None else torch.randn(shape, device=device)

        # Time steps to sample at, from the noisiest to the cleanest
        time_steps = np.flip(self.time_steps)[skip_steps:]

        for i, step in monit.enum('Sample', time_steps):
            # Index `i` of the current step in the list of selected time steps
            index = len(time_steps) - i - 1
            # Time step, repeated for every sample in the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample the previous step; the predicted `x0` and noise are not needed here
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    repeat_noise=repeat_noise,
                                    temperature=temperature,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

        return x

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
        """
        Take one sampling step.

        :param x: is the current sample of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings of shape `[batch_size, emb_size]`
        :param t: is the time step of shape `[batch_size]`
        :param step: is the time step as an integer (not used by this implementation)
        :param index: is the index `i` in the list of selected time steps
        :param repeat_noise: specifies whether the noise should be the same for all
            samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale (forwarded to `get_eps`)
        :param uncond_cond: is the conditional embedding for empty prompt
        :return: the tuple `(x_prev, pred_x0, e_t)`
        """
        # Predicted noise
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # Calculate the previous sample and the predicted `x0`
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature=temperature,
                                                      repeat_noise=repeat_noise)

        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                               temperature: float,
                               repeat_noise: bool):
        """
        Compute the previous sample and the predicted `x0` from the predicted noise.

        :param e_t: is the predicted noise for `x`
        :param index: is the index `i` in the list of selected time steps
        :param x: is the current sample
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param repeat_noise: specifies whether the noise should be the same for all
            samples in the batch
        :return: the tuple `(x_prev, pred_x0)`
        """
        # `alpha_{tau_i}`
        alpha = self.ddim_alpha[index]
        # `alpha_{tau_{i-1}}`
        alpha_prev = self.ddim_alpha_prev[index]
        # `sigma_{tau_i}`
        sigma = self.ddim_sigma[index]
        # `sqrt(1 - alpha_{tau_i})`
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

        # Current prediction for `x0`
        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        # Direction pointing to the current sample
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

        # No noise is added when `sigma` is zero
        if sigma == 0.:
            noise = 0.
        # Same noise for all samples in the batch (broadcast over the batch dimension)
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # `x_prev = sqrt(alpha_prev) * pred_x0 + dir_xt + sigma * noise`
        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

        return x_prev, pred_x0

    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        """
        Add noise to `x0` to obtain a sample at the selected step `index`.

        :param x0: is the clean sample of shape `[batch_size, channels, height, width]`
        :param index: is the time step index `i`
        :param noise: is the noise to add; drawn with `torch.randn_like(x0)` if not given
        :return: `sqrt(alpha_{tau_i}) * x0 + sqrt(1 - alpha_{tau_i}) * noise`
        """
        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

    @torch.no_grad()
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        """
        Run the painting loop (in-painting, or image-to-image when `orig` is not given).

        :param x: is the starting sample of shape `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings
        :param t_start: is the sampling step to start from, `S'`
        :param orig: is the original image in latent space which we are inpainting.
            If this is not provided, it'll be an image to image transformation.
        :param mask: is the mask to keep the original image. Required when `orig` is given.
        :param orig_noise: is fixed noise to be added to the original image.
        :param uncond_scale: is the unconditional guidance scale (forwarded to `get_eps`)
        :param uncond_cond: is the conditional embedding for empty prompt
        :raises ValueError: if `orig` is given without `mask`
        """
        # Fail before any sampling work instead of on the first blend
        if orig is not None and mask is None:
            raise ValueError("mask must be provided when orig is given")

        # Get batch size
        bs = x.shape[0]

        # Time steps to sample at, from step `t_start` down to the first one
        time_steps = np.flip(self.time_steps[:t_start])

        for i, step in monit.enum('Paint', time_steps):
            # Index `i` of the current step in the list of selected time steps
            index = len(time_steps) - i - 1
            # Time step, repeated for every sample in the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample the previous step
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

            # Replace the masked area with the original image
            if orig is not None:
                # Noised version of the original image in latent space.
                # NOTE(review): `x` was just produced with `ddim_alpha_prev[index]`, while
                # `q_sample` noises `orig` with `ddim_alpha[index]` - confirm this one-step
                # offset in noise level is intended.
                orig_t = self.q_sample(orig, index, noise=orig_noise)
                # Keep the original where `mask` is 1 and the sample where it is 0
                x = orig_t * mask + x * (1 - mask)

        return x