We have implemented the following sampling algorithms:
from typing import Optional, List

import torch

from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion


class DiffusionSampler:
    model: LatentDiffusion
model is the model used to predict the noise
    def __init__(self, model: LatentDiffusion):
        super().__init__()
Set the model
        self.model = model
Get the number of steps the model was trained with
        self.n_steps = model.n_steps
x is of shape [batch_size, channels, height, width]
t is of shape [batch_size]
c is the conditional embeddings of shape [batch_size, emb_size]
uncond_scale is the unconditional guidance scale. It is used to compute e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
uncond_cond is the conditional embedding for the empty prompt
    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
When the scale is 1, or no unconditional embedding is given, return the conditional prediction directly
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)
Duplicate x and t
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
Concatenate uncond_cond and c
        c_in = torch.cat([uncond_cond, c])
Get the unconditional and conditional noise predictions
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
Calculate the guided noise prediction
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t
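To make the guidance arithmetic concrete, here is a minimal sketch that applies the same formula as the last step of get_eps to plain Python floats instead of tensors. The function name `guide` and the sample values are illustrative assumptions and are not part of the sampler.

```python
def guide(e_uncond: float, e_cond: float, uncond_scale: float) -> float:
    """Same arithmetic as the last step of get_eps, on scalars instead of tensors."""
    return e_uncond + uncond_scale * (e_cond - e_uncond)


# Hypothetical noise predictions for a single latent element
e_uncond, e_cond = 0.25, 0.75

# Scale 1 reproduces the conditional prediction
same_as_cond = guide(e_uncond, e_cond, 1.0)
# Scale 0 reproduces the unconditional prediction
same_as_uncond = guide(e_uncond, e_cond, 0.0)
# A scale above 1 moves past the conditional prediction, away from the unconditional one
extrapolated = guide(e_uncond, e_cond, 3.0)
```

With a scale of 1 the result equals the conditional prediction, which is why get_eps can skip the second model evaluation in that case.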
shape is the shape of the generated images in the form [batch_size, channels, height, width]
cond is the conditional embeddings
temperature is the noise temperature (random noise gets multiplied by this)
x_last is the starting latent. If not provided, random noise will be used.
uncond_scale is the unconditional guidance scale. It scales the difference between the conditional and unconditional noise predictions in get_eps.
uncond_cond is the conditional embedding for the empty prompt
skip_steps is the number of time steps to skip.
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        raise NotImplementedError()
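The base class leaves sample abstract, so the effect of skip_steps depends on the concrete sampler. As a rough sketch, assuming a sampler that walks the time steps from the noisiest to the cleanest and drops the first skip_steps of them, the visited steps could be computed like this. The helper `sampling_steps` is hypothetical and only illustrates that assumption.

```python
from typing import List


def sampling_steps(n_steps: int, skip_steps: int = 0) -> List[int]:
    """Time steps visited by a sampling loop, from noisiest to cleanest.

    Assumption: skipped steps are dropped from the noisy end of the schedule.
    """
    return list(range(n_steps - 1, -1, -1))[skip_steps:]


# A toy schedule with 10 steps, skipping the 3 noisiest ones
steps = sampling_steps(n_steps=10, skip_steps=3)
```

Under this assumption, x_last would need to be a latent that already matches the noise level of the first visited step.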
x is of shape [batch_size, channels, height, width]
cond is the conditional embeddings
t_start is the sampling step to start from
orig is the original image in latent space which we are inpainting.
mask is the mask to keep the original image.
orig_noise is fixed noise to be added to the original image.
uncond_scale is the unconditional guidance scale. It scales the difference between the conditional and unconditional noise predictions in get_eps.
uncond_cond is the conditional embedding for the empty prompt
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        raise NotImplementedError()
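How a concrete sampler uses orig and mask is not specified here. One common way to inpaint, sketched below on plain Python lists rather than tensors, is to overwrite the masked region of the current sample with the original latent (noised to the current step) after each denoising step. The helper `blend`, the convention that a mask value of 1 means "keep the original", and the sample values are all assumptions made for illustration.

```python
from typing import List


def blend(x: List[float], orig_t: List[float], mask: List[float]) -> List[float]:
    """Keep orig_t where mask is 1 and the sampled value x where mask is 0."""
    return [o * m + xi * (1.0 - m) for xi, o, m in zip(x, orig_t, mask)]


x = [0.5, -0.5, 2.0, 4.0]      # current sample (hypothetical values)
orig_t = [1.0, 1.0, 1.0, 1.0]  # original latent noised to the current step (hypothetical)
mask = [1.0, 1.0, 0.0, 0.0]    # 1 = keep original, 0 = repaint

blended = blend(x, orig_t, mask)
```

The first two elements come from the original and the last two from the sample, so only the unmasked region is repainted.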
x0 is of shape [batch_size, channels, height, width]
index is the time step index
noise is the noise
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        raise NotImplementedError()
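The base class leaves q_sample abstract as well. In DDPM-style diffusion, the forward process has the closed form x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise, where a_bar_t is the cumulative product of (1 - beta). The sketch below applies that formula to a single float. The linear beta schedule, its endpoints, and the names `alpha_bar` and `q_sample_scalar` are assumptions for illustration, and a concrete sampler may use a different schedule.

```python
import math
import random
from typing import List, Optional


def alpha_bar(betas: List[float]) -> List[float]:
    """Cumulative product of (1 - beta_t) over the schedule."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def q_sample_scalar(x0: float, index: int, a_bar: List[float],
                    noise: Optional[float] = None) -> float:
    """Noise a single value x0 to time step `index` (scalar stand-in for a tensor)."""
    if noise is None:
        # Draw standard Gaussian noise when none is supplied
        noise = random.gauss(0.0, 1.0)
    return math.sqrt(a_bar[index]) * x0 + math.sqrt(1.0 - a_bar[index]) * noise


# Hypothetical linear beta schedule with 1000 steps
n_steps = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (n_steps - 1) for i in range(n_steps)]
a_bar = alpha_bar(betas)

# With zero noise, only the signal scaling sqrt(a_bar) remains
x_early = q_sample_scalar(1.0, 0, a_bar, noise=0.0)
x_late = q_sample_scalar(1.0, n_steps - 1, a_bar, noise=0.0)
```

Passing a fixed noise value, as orig_noise does for paint, makes the result deterministic. Here the signal is almost untouched at the first step and almost gone at the last one.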