# Sampling algorithms for stable diffusion

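We have implemented the following sampling algorithms:

• Denoising Diffusion Probabilistic Models (DDPM) sampling
• Denoising Diffusion Implicit Models (DDIM) sampling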

from typing import Optional, List

import torch

from labml_nn.diffusion.stable_diffusion.latent_diffusion import LatentDiffusion

## Base class for sampling algorithms

class DiffusionSampler:
    model: LatentDiffusion
• model is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
    def __init__(self, model: LatentDiffusion):
        super().__init__()

Set the model $\epsilon_\text{cond}(x_t, c)$

        self.model = model

Get the number of steps the model was trained with, $T$

        self.n_steps = model.n_steps

## Get $\epsilon(x_t, c)$

• x is $x_t$ of shape [batch_size, channels, height, width]
• t is $t$ of shape [batch_size]
• c is the conditional embeddings $c$ of shape [batch_size, emb_size]
• uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
• uncond_cond is the conditional embedding for the empty prompt, $c_u$
    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):

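When the scale $s = 1$, or no empty-prompt embedding $c_u$ is given, $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$, so a single model call suffices

        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)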

Duplicate $x_t$ and $t$ so both predictions run in one batch

        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)

Concatenate $c_u$ and $c$, with the unconditional half first

        c_in = torch.cat([uncond_cond, c])

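Get $\epsilon_\text{cond}(x_t, c_u)$ and $\epsilon_\text{cond}(x_t, c)$ by splitting the batched output in two

        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)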
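The duplicate, concatenate, and split pattern above evaluates the unconditional and conditional predictions in one forward pass. The pure-Python sketch below checks that batching this way gives the same two results as two separate calls. It uses a hypothetical per-sample `toy_model` as a stand-in for the noise predictor, with lists in place of tensors. `toy_model`, `run_batch`, and `eps_uncond_and_cond` are illustrative names, not part of labml_nn.

```python
from typing import List, Tuple

Vector = List[float]


def toy_model(x: Vector, t: int, c: float) -> Vector:
    # Hypothetical stand-in for the noise predictor; it treats each
    # batch element independently, which the batching trick relies on.
    return [0.5 * xi + 0.1 * c * t for xi in x]


def run_batch(xs: List[Vector], ts: List[int], cs: List[float]) -> List[Vector]:
    # One "forward pass" over a whole batch.
    return [toy_model(x, t, c) for x, t, c in zip(xs, ts, cs)]


def eps_uncond_and_cond(xs: List[Vector], ts: List[int], c: List[float],
                        c_u: List[float]) -> Tuple[List[Vector], List[Vector]]:
    n = len(xs)
    x_in = xs + xs  # like torch.cat([x] * 2)
    t_in = ts + ts  # like torch.cat([t] * 2)
    c_in = c_u + c  # like torch.cat([uncond_cond, c]): unconditional half first
    out = run_batch(x_in, t_in, c_in)
    return out[:n], out[n:]  # like .chunk(2): first half unconditional


xs = [[1.0, 2.0], [3.0, -1.0]]
ts = [10, 10]
e_uncond, e_cond = eps_uncond_and_cond(xs, ts, c=[1.0, 2.0], c_u=[0.0, 0.0])
```

The equivalence holds only when the model handles each batch element independently. The ordering in `c_in` must also match the order in which the output halves are unpacked.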

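Calculate

$$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s\big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$$

        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t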
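The guidance combination can be checked numerically. This sketch uses plain Python lists in place of tensors. It computes the combination the way `get_eps` does, and also in its expanded form $s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$. The function names are hypothetical.

```python
from typing import List

Vector = List[float]


def guided_eps(e_uncond: Vector, e_cond: Vector, s: float) -> Vector:
    # Same arithmetic as get_eps: e_uncond + s * (e_cond - e_uncond)
    return [u + s * (c - u) for u, c in zip(e_uncond, e_cond)]


def guided_eps_expanded(e_uncond: Vector, e_cond: Vector, s: float) -> Vector:
    # Algebraically equal form: s * e_cond - (s - 1) * e_uncond
    return [s * c - (s - 1.0) * u for u, c in zip(e_uncond, e_cond)]


e_uncond = [0.2, -0.4, 1.0]
e_cond = [1.0, 0.6, -0.5]
guided = guided_eps(e_uncond, e_cond, s=7.5)
```

Setting $s = 0$ returns the unconditional prediction and $s = 1$ returns the conditional one. Values of $s > 1$ extrapolate past the conditional prediction, away from the unconditional one.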

### Sampling Loop

• shape is the shape of the generated images in the form [batch_size, channels, height, width]
• cond is the conditional embeddings $c$
• temperature is the noise temperature (random noise gets multiplied by this)
• x_last is $x_T$. If it is not provided, random noise will be used.
• uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
• uncond_cond is the conditional embedding for the empty prompt, $c_u$
• skip_steps is the number of time steps to skip.
    def sample(self,
               shape: List[int],
               cond: torch.Tensor,
               repeat_noise: bool = False,
               temperature: float = 1.,
               x_last: Optional[torch.Tensor] = None,
               uncond_scale: float = 1.,
               uncond_cond: Optional[torch.Tensor] = None,
               skip_steps: int = 0,
               ):
        raise NotImplementedError()
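`sample` is left abstract in this base class. The following is a hypothetical skeleton of how a subclass might drive the loop, using plain lists and `random` instead of tensors. It makes three assumptions:

• `x_last` replaces the initial standard Gaussian noise.
• `temperature` scales the noise drawn at each step.
• `skip_steps` drops that many of the noisiest steps.

`step` stands in for one reverse-diffusion update (for example a DDPM or DDIM step) and is not a labml_nn function. `repeat_noise`, conditioning, and guidance are left out.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]
# Hypothetical signature: step(x, t, noise) returns the latent for the next, less noisy step
Step = Callable[[Vector, int, Vector], Vector]


def sample_loop(n_elems: int, time_steps: List[int], step: Step,
                temperature: float = 1.0, x_last: Optional[Vector] = None,
                skip_steps: int = 0, seed: int = 0) -> Vector:
    rng = random.Random(seed)
    # x_T: the supplied latent, or fresh standard Gaussian noise
    if x_last is not None:
        x = list(x_last)
    else:
        x = [rng.gauss(0.0, 1.0) for _ in range(n_elems)]
    # Visit steps from noisiest to cleanest, dropping the first `skip_steps`
    for t in list(reversed(time_steps))[skip_steps:]:
        # Fresh noise for this step, scaled by the temperature
        noise = [temperature * rng.gauss(0.0, 1.0) for _ in range(n_elems)]
        x = step(x, t, noise)
    return x


visited: List[int] = []


def record_step(x: Vector, t: int, noise: Vector) -> Vector:
    # Toy update that only records which time step it was called with
    visited.append(t)
    return x


result = sample_loop(3, [0, 1, 2, 3, 4], record_step,
                     x_last=[1.0, 2.0, 3.0], skip_steps=2)
```

With five time steps and `skip_steps=2`, the loop visits only steps 2, 1, and 0.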

### Painting Loop

• x is $x_{T'}$ of shape [batch_size, channels, height, width]
• cond is the conditional embeddings $c$
• t_start is the sampling step to start from, $T'$
• orig is the original image in latent space, which we are inpainting.
• mask is the mask to keep the original image.
• orig_noise is fixed noise to be added to the original image.
• uncond_scale is the unconditional guidance scale $s$. This is used for $\epsilon_\theta(x_t, c) = s\,\epsilon_\text{cond}(x_t, c) - (s - 1)\,\epsilon_\text{cond}(x_t, c_u)$
• uncond_cond is the conditional embedding for the empty prompt, $c_u$
    def paint(self, x: torch.Tensor, cond: torch.Tensor, t_start: int, *,
              orig: Optional[torch.Tensor] = None,
              mask: Optional[torch.Tensor] = None, orig_noise: Optional[torch.Tensor] = None,
              uncond_scale: float = 1.,
              uncond_cond: Optional[torch.Tensor] = None,
              ):
        raise NotImplementedError()
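One common way to use `orig`, `mask`, and `orig_noise` during painting works as follows. After each denoising step, noise the original latent to the current step with `q_sample` (using `orig_noise`), then paste it back wherever the mask is 1. The helper below is a hypothetical sketch of just that blending step, on plain lists. `orig_t` is assumed to already be the noised original.

```python
from typing import List

Vector = List[float]


def blend_known_region(x: Vector, orig_t: Vector, mask: Vector) -> Vector:
    # mask == 1: keep the noised original; mask == 0: keep the sampled latent
    return [o * m + xi * (1.0 - m) for xi, o, m in zip(x, orig_t, mask)]


sampled = [1.0, 1.0, 1.0]
orig_t = [5.0, 6.0, 7.0]
mask = [1.0, 0.0, 0.5]
blended = blend_known_region(sampled, orig_t, mask)
```

Reusing the same `orig_noise` at every step keeps the pasted-in region consistent across steps. Fractional mask values give a linear blend between the two latents.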

### Sample from $q(x_t|x_0)$

• x0 is $x_0$ of shape [batch_size, channels, height, width]
• index is the index of time step $t$
• noise is the noise, $\epsilon$
    def q_sample(self, x0: torch.Tensor, index: int, noise: Optional[torch.Tensor] = None):
        raise NotImplementedError()
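For a DDPM-style forward process, $q(x_t|x_0)$ has the closed form $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$, with $\bar\alpha_t = \prod_{s=1}^{t}(1 - \beta_s)$. The sketch below assumes a linearly spaced $\beta$ schedule purely for illustration. `linear_alpha_bar` and this stand-alone `q_sample` are hypothetical. A concrete sampler (such as DDIM, which samples on a subsequence of steps) may index its own $\bar\alpha$ values differently.

```python
import math
import random
from typing import List, Optional

Vector = List[float]


def linear_alpha_bar(n_steps: int, beta_start: float = 1e-4,
                     beta_end: float = 2e-2) -> List[float]:
    # Assumed schedule: betas spaced linearly, alpha_bar_t = prod(1 - beta_s)
    assert n_steps > 1
    alpha_bar, prod = [], 1.0
    for i in range(n_steps):
        beta = beta_start + (beta_end - beta_start) * i / (n_steps - 1)
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar


def q_sample(x0: Vector, index: int, alpha_bar: List[float],
             noise: Optional[Vector] = None, seed: int = 0) -> Vector:
    # Draw epsilon ~ N(0, I) when no fixed noise is supplied
    if noise is None:
        rng = random.Random(seed)
        noise = [rng.gauss(0.0, 1.0) for _ in x0]
    a = alpha_bar[index]
    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    return [math.sqrt(a) * xi + math.sqrt(1.0 - a) * ei
            for xi, ei in zip(x0, noise)]


alpha_bar = linear_alpha_bar(1000)
x_t = q_sample([1.0, -2.0], 500, alpha_bar)
```

Because $\bar\alpha_t$ shrinks as $t$ grows, later indices keep less of $x_0$ and more of $\epsilon$. Passing the same `noise` reproduces the same $x_t$, which is what fixed noise such as `orig_noise` relies on.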