For a simpler DDPM implementation, refer to our DDPM implementation. We use the same notation for the $\alpha_t$ and $\beta_t$ schedules, etc.
16from typing import Optional, List
17
18import numpy as np
19import torch
20
21from labml import monit
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion
23from labml_nn.diffusion.stable_diffusion.sampler import DiffusionSampler
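This extends the `DiffusionSampler` base class.

DDPM samples images by repeatedly removing noise, sampling step by step from $p_\theta(x_{t-1} | x_t)$:

\begin{align}
p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
\mu_\theta(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t} x_0
 + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta
\end{align}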
class DDPMSampler(DiffusionSampler):
    r"""
    ## DDPM Sampler

    Extends the `DiffusionSampler` base class.

    DDPM samples images by repeatedly removing noise, sampling step by step from
    $p_\theta(x_{t-1} | x_t)$:

    \begin{align}
    p_\theta(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1}; \mu_\theta(x_t, t), \tilde\beta_t \mathbf{I} \big) \\
    \mu_\theta(x_t, t) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t} x_0
      + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} x_t \\
    \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t \\
    x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta
    \end{align}

    Time steps are stored 0-based: index `step` corresponds to $t = step + 1$.
    """

    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion):
        r"""
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        """
        super().__init__(model)

        # Sampling steps $1, 2, \dots, T$, stored as 0-based indices $0, \dots, T - 1$.
        # NOTE(review): `self.n_steps` is not assigned in this class; it is
        # expected to be set by `DiffusionSampler.__init__` — confirm.
        self.time_steps = np.arange(self.n_steps)

        with torch.no_grad():
            # $\bar\alpha_t$
            alpha_bar = self.model.alpha_bar
            # $\beta_t$ schedule
            beta = self.model.beta
            # $\bar\alpha_{t-1}$: shift by one and prepend $1$ for the first step
            alpha_bar_prev = torch.cat([alpha_bar.new_tensor([1.]), alpha_bar[:-1]])

            # $\sqrt{\bar\alpha_t}$
            self.sqrt_alpha_bar = alpha_bar ** .5
            # $\sqrt{1 - \bar\alpha_t}$
            self.sqrt_1m_alpha_bar = (1. - alpha_bar) ** .5
            # $\frac{1}{\sqrt{\bar\alpha_t}}$
            self.sqrt_recip_alpha_bar = alpha_bar ** -.5
            # $\sqrt{\frac{1}{\bar\alpha_t} - 1}$
            self.sqrt_recip_m1_alpha_bar = (1 / alpha_bar - 1) ** .5

            # $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t$
            variance = beta * (1. - alpha_bar_prev) / (1. - alpha_bar)
            # Clamped $\log \tilde\beta_t$. At the first step the prepended
            # $\bar\alpha_{t-1} = 1$ makes $\tilde\beta_t = 0$, so the clamp avoids $\log 0$.
            self.log_var = torch.log(torch.clamp(variance, min=1e-20))
            # $\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$
            self.mean_x0_coef = beta * (alpha_bar_prev ** .5) / (1. - alpha_bar)
            # $\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}$, with $\alpha_t = 1 - \beta_t$
            self.mean_xt_coef = (1. - alpha_bar_prev) * ((1 - beta) ** 0.5) / (1. - alpha_bar)

    @torch.no_grad()
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        r"""
        ### Sampling loop

        :param shape: is the shape of the generated images in the form
            `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: specifies whether the same noise is used for all
            samples in the batch at each step
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_T$. If not provided, random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$; it is
            forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for the empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip, $t'$. Sampling
            starts from $T - t'$, and `x_last` is then $x_{T - t'}$.
        :return: the sample after the final denoising step
        """
        # Get device and batch size
        device = self.model.device
        bs = shape[0]

        # $x_T$ (or $x_{T - t'}$ when steps are skipped)
        x = x_last if x_last is not None else torch.randn(shape, device=device)

        # 0-based time steps in descending order, dropping the first `skip_steps`
        time_steps = np.flip(self.time_steps)[skip_steps:]

        # Sampling loop, wrapped by labml's `monit.iterate`
        for step in monit.iterate('Sample', time_steps):
            # Time step $t$ for every sample in the batch
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{t-1}$; the predicted $x_0$ and $\epsilon_\theta$ are not needed here
            x, _, _ = self.p_sample(x, cond, ts, step,
                                    repeat_noise=repeat_noise,
                                    temperature=temperature,
                                    uncond_scale=uncond_scale,
                                    uncond_cond=uncond_cond)

        return x

    @staticmethod
    def _per_sample(x: torch.Tensor, value) -> torch.Tensor:
        """Fill a `[batch_size, 1, ..., 1]` tensor with `value` so it broadcasts against `x`."""
        return x.new_full((x.shape[0],) + (1,) * (x.dim() - 1), value)

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step: int,
                 repeat_noise: bool = False,
                 temperature: float = 1.,
                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
        r"""
        ### Sample $x_{t-1}$ from $p_\theta(x_{t-1} | x_t)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param t: is $t$ of shape `[batch_size]`
        :param step: is the step as an integer (0-based, so `0` means $t = 1$)
        :param repeat_noise: specifies whether the noise should be the same for
            all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$; it is
            forwarded to `get_eps`
        :param uncond_cond: is the conditional embedding for the empty prompt $c_u$
        :return: a tuple `(x_prev, x0, e_t)`: the sample $x_{t-1}$, the predicted
            $x_0$, and the predicted noise $\epsilon_\theta(x_t, t)$
        """
        # Get $\epsilon_\theta(x_t, t)$.
        # NOTE(review): `get_eps` is not defined in this class; presumably inherited
        # from `DiffusionSampler`.
        e_t = self.get_eps(x, t, c,
                           uncond_scale=uncond_scale,
                           uncond_cond=uncond_cond)

        # $\frac{1}{\sqrt{\bar\alpha_t}}$ and $\sqrt{\frac{1}{\bar\alpha_t} - 1}$, one per sample
        sqrt_recip_alpha_bar = self._per_sample(x, self.sqrt_recip_alpha_bar[step])
        sqrt_recip_m1_alpha_bar = self._per_sample(x, self.sqrt_recip_m1_alpha_bar[step])

        # Predicted $x_0 = \frac{1}{\sqrt{\bar\alpha_t}} x_t - \Big(\sqrt{\frac{1}{\bar\alpha_t} - 1}\Big)\epsilon_\theta$
        x0 = sqrt_recip_alpha_bar * x - sqrt_recip_m1_alpha_bar * e_t

        # Coefficients of $x_0$ and $x_t$ in the posterior mean $\mu_\theta(x_t, t)$
        mean_x0_coef = self._per_sample(x, self.mean_x0_coef[step])
        mean_xt_coef = self._per_sample(x, self.mean_xt_coef[step])

        # $\mu_\theta(x_t, t)$
        mean = mean_x0_coef * x0 + mean_xt_coef * x
        # $\log \tilde\beta_t$
        log_var = self._per_sample(x, self.log_var[step])

        # Do not add noise at the final step ($t = 1$, i.e. `step == 0`).
        # Noise is created on `x`'s device and dtype so the update below does not
        # mix CPU and GPU (or float32 and half) tensors.
        if step == 0:
            noise = 0
        # Same noise for all samples in the batch (broadcast over dim 0)
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device, dtype=x.dtype)
        # Different noise for each sample
        else:
            noise = torch.randn_like(x)

        # Multiply noise by the temperature
        noise = noise * temperature

        # $x_{t-1} = \mu_\theta(x_t, t) + \sqrt{\tilde\beta_t}\,\epsilon$
        x_prev = mean + (0.5 * log_var).exp() * noise

        return x_prev, x0, e_t

    @torch.no_grad()
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        r"""
        ### Sample from $q(x_t | x_0) = \mathcal{N}\big(x_t; \sqrt{\bar\alpha_t} x_0, (1 - \bar\alpha_t) \mathbf{I}\big)$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step index (0-based, same indexing as `p_sample`'s `step`)
        :param noise: is the noise $\epsilon$; random noise like `x0` is drawn if not given
        :return: $\sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$
        """
        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0)

        # Sample from $q(x_t | x_0)$
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise