This implements DDIM sampling from the paper Denoising Diffusion Implicit Models
16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
This extends the DiffusionSampler
base class.
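DDIM samples images by repeatedly removing noise, sampling step by step using
$$x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}}\left(\frac{x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}}\,\epsilon_\theta(x_{\tau_i})}{\sqrt{\alpha_{\tau_i}}}\right) + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2}\cdot\epsilon_\theta(x_{\tau_i}) + \sigma_{\tau_i}\,\epsilon_{\tau_i}$$
where $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and
$$\sigma_{\tau_i} = \eta\,\sqrt{\frac{1 - \alpha_{\tau_{i-1}}}{1 - \alpha_{\tau_i}}}\,\sqrt{1 - \frac{\alpha_{\tau_i}}{\alpha_{\tau_{i-1}}}}$$
Note that $\alpha_t$ in the DDIM paper refers to $\bar\alpha_t$ from DDPM.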
class DDIMSampler(DiffusionSampler):
    r"""
    DDIM sampler, extending the `DiffusionSampler` base class.

    Images are sampled by repeatedly removing noise, one step at a time, using

        x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \cdot \hat{x}_0
                         + \sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})
                         + \sigma_{\tau_i} \epsilon_{\tau_i}

    where the current prediction of the clean image is

        \hat{x}_0 = (x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \epsilon_\theta(x_{\tau_i})) / \sqrt{\alpha_{\tau_i}}

    $\epsilon_{\tau_i}$ is random noise, $\tau$ is a subsequence of $[1, 2, \dots, T]$ of length $S$, and

        \sigma_{\tau_i} = \eta \sqrt{(1 - \alpha_{\tau_{i-1}}) / (1 - \alpha_{\tau_i})}
                               \sqrt{1 - \alpha_{\tau_i} / \alpha_{\tau_{i-1}}}

    Note that $\alpha_t$ in the DDIM paper is the cumulative product that this code reads
    from `model.alpha_bar`.
    """

    # The model used to predict the noise
    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion, n_steps: int, ddim_discretize: str = "uniform", ddim_eta: float = 0.):
        r"""
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        :param n_steps: is the number of DDIM sampling steps, $S$
        :param ddim_discretize: specifies how to extract $\tau$ from $[1, 2, \dots, T]$.
            It can be either `uniform` or `quad`.
        :param ddim_eta: is $\eta$ used to calculate $\sigma_{\tau_i}$.
            $\eta = 0$ makes the sampling process deterministic.

        :raises ValueError: if `n_steps` is smaller than one, or (for `uniform`) larger than
            the number of steps of the model
        :raises NotImplementedError: if `ddim_discretize` is neither `uniform` nor `quad`
        """
        super().__init__(model)
        # Number of steps, $T$
        self.n_steps = model.n_steps

        # A non-positive step count has no meaningful schedule
        # (and would otherwise surface as a division by zero below)
        if n_steps < 1:
            raise ValueError(f"n_steps must be at least 1, got {n_steps}")

        # Calculate $\tau$ to be uniformly distributed across $[1, 2, \dots, T]$
        if ddim_discretize == 'uniform':
            # The stride would be zero, which `range`/`arange` cannot handle
            if n_steps > self.n_steps:
                raise ValueError(f"n_steps ({n_steps}) cannot exceed the number of "
                                 f"model steps ({self.n_steps}) for 'uniform' discretization")
            c = self.n_steps // n_steps
            self.time_steps = np.arange(0, self.n_steps, c) + 1
        # Calculate $\tau$ to be quadratically distributed across $[1, 2, \dots, T]$
        elif ddim_discretize == 'quad':
            self.time_steps = ((np.linspace(0, np.sqrt(self.n_steps * .8), n_steps)) ** 2).astype(int) + 1
        else:
            raise NotImplementedError(ddim_discretize)

        with torch.no_grad():
            # Get the cumulative $\alpha$ schedule of the model
            alpha_bar = self.model.alpha_bar

            # The `+ 1` offset above can push the last time step one past the end of
            # `alpha_bar` (e.g. when the stride is 1); drop anything that cannot be indexed.
            self.time_steps = self.time_steps[self.time_steps < alpha_bar.shape[0]]

            # $\alpha_{\tau_i}$
            self.ddim_alpha = alpha_bar[self.time_steps].clone().to(torch.float32)
            # $\sqrt{\alpha_{\tau_i}}$
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # $\alpha_{\tau_{i-1}}$; the first entry uses `alpha_bar[0]`
            self.ddim_alpha_prev = torch.cat([alpha_bar[0:1], alpha_bar[self.time_steps[:-1]]])

            # $\sigma_{\tau_i} = \eta \sqrt{(1 - \alpha_{\tau_{i-1}}) / (1 - \alpha_{\tau_i})}
            #                         \sqrt{1 - \alpha_{\tau_i} / \alpha_{\tau_{i-1}}}$
            self.ddim_sigma = (ddim_eta *
                               ((1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                                (1 - self.ddim_alpha / self.ddim_alpha_prev)) ** .5)

            # $\sqrt{1 - \alpha_{\tau_i}}$
            self.ddim_sqrt_one_minus_alpha = (1. - self.ddim_alpha) ** .5

    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        r"""
        ### Sampling Loop

        :param shape: is the shape of the generated images in the
            form `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_{\tau_S}$. If not provided random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$; it is forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip $i'$. We start sampling from $S - i'$.
            And `x_last` is then $x_{\tau_{S - i'}}$.

        :return: the final sample $x_0$
        """

        # Get device and batch size
        device = self.model.device
        bs = shape[0]

        # Get $x_{\tau_S}$
        x = x_last if x_last is not None else torch.randn(shape, device=device)

        # Time steps to sample at $\tau_{S - i'}, \tau_{S - i' - 1}, \dots, \tau_1$
        time_steps = np.flip(self.time_steps)[skip_steps:]

        for i, step in monit.enum('Sample', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            index = len(time_steps) - i - 1
            # Time step $\tau_i$
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$; the predicted $x_0$ and the noise estimate are not needed here
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    repeat_noise=repeat_noise,
                                    temperature=temperature,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

        # Return $x_0$
        return x

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int, index: int, *,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
        r"""
        ### Sample $x_{\tau_{i-1}}$

        :param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $\tau_i$ of shape `[batch_size]`
        :param step: is the step $\tau_i$ as an integer (not used by this implementation)
        :param index: is index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$; it is forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$

        :return: the tuple `(x_prev, pred_x0, e_t)`
        """

        # Get $\epsilon_\theta(x_{\tau_i})$
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # Calculate $x_{\tau_{i-1}}$ and predicted $x_0$
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature=temperature,
                                                      repeat_noise=repeat_noise)

        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor, *,
                               temperature: float,
                               repeat_noise: bool):
        r"""
        ### Sample $x_{\tau_{i-1}}$ given $\epsilon_\theta(x_{\tau_i})$

        :param e_t: is the noise estimate $\epsilon_\theta(x_{\tau_i})$
        :param index: is index $i$ into the precomputed DDIM schedule
        :param x: is $x_{\tau_i}$
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch

        :return: the tuple `(x_prev, pred_x0)`
        """

        # $\alpha_{\tau_i}$
        alpha = self.ddim_alpha[index]
        # $\alpha_{\tau_{i-1}}$
        alpha_prev = self.ddim_alpha_prev[index]
        # $\sigma_{\tau_i}$
        sigma = self.ddim_sigma[index]
        # $\sqrt{1 - \alpha_{\tau_i}}$
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

        # Current prediction for $x_0$,
        # $(x_{\tau_i} - \sqrt{1 - \alpha_{\tau_i}} \epsilon_\theta(x_{\tau_i})) / \sqrt{\alpha_{\tau_i}}$
        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        # Direction pointing to $x_t$,
        # $\sqrt{1 - \alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2} \cdot \epsilon_\theta(x_{\tau_i})$
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

        # No noise is added, when $\sigma_{\tau_i} = 0$ (e.g. $\eta = 0$)
        if sigma == 0.:
            noise = 0.
        # If same noise is used for all samples in the batch
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        # Different noise for each sample
        else:
            noise = torch.randn(x.shape, device=x.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        # $x_{\tau_{i-1}} = \sqrt{\alpha_{\tau_{i-1}}} \hat{x}_0 + \text{dir}_{x_t} + \sigma_{\tau_i} \epsilon_{\tau_i}$
        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

        return x_prev, pred_x0

    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        r"""
        ### Sample a noised version $x_{\tau_i}$ of $x_0$

        $$x_{\tau_i} = \sqrt{\alpha_{\tau_i}} x_0 + \sqrt{1 - \alpha_{\tau_i}} \epsilon$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $\tau_i$ index $i$
        :param noise: is the noise, $\epsilon$. Random noise is drawn if it is not specified.

        :return: the noised sample $x_{\tau_i}$
        """

        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        # Combine the scaled image and the scaled noise
        return self.ddim_alpha_sqrt[index] * x0 + self.ddim_sqrt_one_minus_alpha[index] * noise

    @torch.no_grad()
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        r"""
        ### Painting Loop

        :param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param t_start: is the sampling step to start from, $S'$
        :param orig: is the original image in latent space which we are inpainting.
            If this is not provided, it'll be an image to image transformation.
        :param mask: is the mask to keep the original image. It is multiplied with the noised
            original, so it has to be given whenever `orig` is given.
        :param orig_noise: is fixed noise to be added to the original image.
        :param uncond_scale: is the unconditional guidance scale $s$; it is forwarded to `get_eps`
            through `p_sample`
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$

        :return: the final sample
        """
        # Get batch size
        bs = x.shape[0]

        # Time steps to sample at $\tau_{S'}, \tau_{S' - 1}, \dots, \tau_1$
        time_steps = np.flip(self.time_steps[:t_start])

        for i, step in monit.enum('Paint', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            index = len(time_steps) - i - 1
            # Time step $\tau_i$
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$
            x, _, _ = self.p_sample(x, cond, ts, step, index=index,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

            # Replace the masked area with original image
            if orig is not None:
                # Get the noised version of the original image in latent space
                orig_t = self.q_sample(orig, index, noise=orig_noise)
                # Replace the masked area
                x = orig_t * mask + x * (1 - mask)

        return x