18from typing import Optional, List
19
20import torch
21
22from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion

Base class for sampling algorithms

class DiffusionSampler:
    """
    ## Base class for sampling algorithms

    Stores the noise-prediction model and provides `get_eps`, which applies
    classifier-free guidance. `sample`, `paint` and `q_sample` raise
    `NotImplementedError` here and are meant to be overridden by concrete
    samplers.
    """
    # The model used to predict noise
    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion):
        r"""
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        """
        super().__init__()
        # Set the model $\epsilon_\text{cond}(x_t, c)$
        self.model = model
        # Get number of steps the model was trained with, $T$
        self.n_steps = model.n_steps

    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]) -> torch.Tensor:
        r"""
        ## Get $\epsilon_\theta(x_t, c)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param t: is $t$ of shape `[batch_size]`
        :param c: is the conditional embeddings $c$ of shape `[batch_size, emb_size]`
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u)
            + s \big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$.
            Its first (batch) dimension must either match that of `c`, or be `1`,
            in which case the single embedding is shared by the whole batch.
        :return: the guided noise prediction $\epsilon_\theta(x_t, c)$
        """
        # When the scale $s = 1$ (or no empty-prompt embedding is given)
        # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
        # so a single forward pass on the un-duplicated batch is enough.
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)

        # Let one empty-prompt embedding serve the whole batch; without this the
        # `torch.cat` below fails on the batch-size mismatch.
        if uncond_cond.shape[0] == 1 and c.shape[0] != 1:
            uncond_cond = uncond_cond.expand(c.shape[0], *uncond_cond.shape[1:])

        # Duplicate $x_t$ and $t$ so both predictions come from one batched call
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        # Concatenate $c_u$ and $c$; the unconditional half comes first, which
        # fixes the order of the two chunks unpacked below
        c_in = torch.cat([uncond_cond, c])
        # Get $\epsilon_\text{cond}(x_t, c_u)$ and $\epsilon_\text{cond}(x_t, c)$
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        # Calculate
        # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u)
        #   + s \big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$$
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)

        return e_t

    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        r"""
        ### Sampling Loop

        :param shape: is the shape of the generated images in the form
            `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param repeat_noise: NOTE(review): not described in the original docs;
            presumably selects reusing the same noise across the batch — confirm
            against the overriding samplers
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is the latent to start sampling from ($x_T$).
            If not provided random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$, used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u)
            + s \big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        :param skip_steps: is the number of time steps to skip.
        """
        raise NotImplementedError()

    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        r"""
        ### Painting Loop

        :param x: is $x_{t'}$ of shape `[batch_size, channels, height, width]`
        :param cond: is the conditional embeddings $c$
        :param t_start: is the sampling step to start from, $t'$
        :param orig: is the original image in latent space which we are in-painting.
        :param mask: is the mask to keep the original image.
        :param orig_noise: is fixed noise to be added to the original image.
        :param uncond_scale: is the unconditional guidance scale $s$, used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u)
            + s \big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        :param uncond_cond: is the conditional embedding for empty prompt $c_u$
        """
        raise NotImplementedError()

    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        r"""
        ### Sample from $q(x_t|x_0)$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step index $t$
        :param noise: is the noise, $\epsilon$ (how `None` is handled is up to
            the overriding sampler)
        """
        raise NotImplementedError()