Latent diffusion models use an auto-encoder to map between image space and latent space. The diffusion model works on the latent space, which makes it much easier to train. This is based on the paper High-Resolution Image Synthesis with Latent Diffusion Models.
They use a pre-trained auto-encoder and train the diffusion U-Net on the latent space of that auto-encoder.
For a simpler diffusion implementation, refer to our DDPM implementation. We use the same notation for the $\alpha_t$, $\beta_t$ schedules, etc.
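As a rough illustration of how the pieces fit together during training, here is a minimal sketch of one training step. This is not part of this file; it assumes a `LatentDiffusion` instance (the class defined below) and a batch of images normalized to $[-1, 1]$.

```python
from typing import List

import torch
import torch.nn.functional as F


def training_step_sketch(model, images: torch.Tensor, prompts: List[str]):
    # Encode images into the (scaled) latent space with the frozen autoencoder
    x0 = model.autoencoder_encode(images)
    # Sample random time steps and Gaussian noise
    t = torch.randint(0, model.n_steps, (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    # Diffuse the latents: x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε
    alpha_bar = model.alpha_bar[t].view(-1, 1, 1, 1)
    xt = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
    # Get CLIP conditioning and predict the noise with the U-Net
    context = model.get_text_conditioning(prompts)
    eps_pred = model(xt, t, context)
    # Simple noise-prediction (MSE) loss
    return F.mse_loss(eps_pred, eps)
```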
```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional

from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
```
This is an empty wrapper class around the U-Net. We keep this to have the same model structure as CompVis/stable-diffusion so that we do not have to map the checkpoint weights explicitly.
```python
class DiffusionWrapper(nn.Module):
    def __init__(self, diffusion_model: UNetModel):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
        return self.diffusion_model(x, time_steps, context)
```
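As an illustration of why the wrapper matters, the snippet below (hypothetical, not part of this file) shows that the parameter names under `DiffusionWrapper` carry the `diffusion_model.` prefix used by CompVis/stable-diffusion checkpoints, so the weights can be loaded without renaming keys.

```python
# `unet_model` is an assumed, already constructed UNetModel instance
wrapper = DiffusionWrapper(unet_model)
# Keys start with "diffusion_model." (e.g. "diffusion_model.input_blocks.0.0.weight"),
# matching the checkpoint layout under "model." in CompVis/stable-diffusion
print(next(iter(wrapper.state_dict().keys())))
```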
This is the latent diffusion model. It contains the following components: the denoising U-Net (wrapped in DiffusionWrapper), the autoencoder, and the CLIP text embedder.
```python
class LatentDiffusion(nn.Module):
    model: DiffusionWrapper
    first_stage_model: Autoencoder
    cond_stage_model: CLIPTextEmbedder
```
* `unet_model` is the U-Net that predicts noise $\epsilon_\text{cond}(x_t, c)$ in latent space
* `autoencoder` is the AutoEncoder
* `clip_embedder` is the CLIP embeddings generator
* `latent_scaling_factor` is the scaling factor for the latent space. The encodings of the autoencoder are scaled by this before feeding into the U-Net.
* `n_steps` is the number of diffusion steps $T$.
* `linear_start` is the start of the $\beta$ schedule.
* `linear_end` is the end of the $\beta$ schedule.

```python
    def __init__(self,
                 unet_model: UNetModel,
                 autoencoder: Autoencoder,
                 clip_embedder: CLIPTextEmbedder,
                 latent_scaling_factor: float,
                 n_steps: int,
                 linear_start: float,
                 linear_end: float,
                 ):
        super().__init__()
```
Wrap the U-Net to keep the same model structure as CompVis/stable-diffusion.
```python
        self.model = DiffusionWrapper(unet_model)
```
Auto-encoder and scaling factor
```python
        self.first_stage_model = autoencoder
        self.latent_scaling_factor = latent_scaling_factor
        # CLIP text embedder used to generate the conditioning context
        self.cond_stage_model = clip_embedder
```
Number of steps
```python
        self.n_steps = n_steps
```
$\beta$ schedule
```python
        beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
        self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
        # α_t = 1 - β_t
        alpha = 1. - beta
        # ᾱ_t = ∏ α_s, the cumulative product over time steps
        alpha_bar = torch.cumprod(alpha, dim=0)
        self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
```
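For reference, $\sqrt{\beta_t}$ is linearly spaced between $\sqrt{\text{linear\_start}}$ and $\sqrt{\text{linear\_end}}$, and the buffers above correspond to the usual DDPM quantities (same notation as the DDPM implementation referenced earlier):

$$\alpha_t = 1 - \beta_t, \qquad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s, \qquad q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t)\mathbf{I}\big)$$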
```python
    @property
    def device(self):
        # Device of the model parameters
        return next(iter(self.model.parameters())).device
```
Get CLIP embeddings for a list of text prompts.

```python
    def get_text_conditioning(self, prompts: List[str]):
        return self.cond_stage_model(prompts)
```
The encoder output is a distribution. We sample from that and multiply by the scaling factor.
```python
    def autoencoder_encode(self, image: torch.Tensor):
        return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()
```
Decode images from the latent representation. The latents are divided by the scaling factor before decoding.

```python
    def autoencoder_decode(self, z: torch.Tensor):
        return self.first_stage_model.decode(z / self.latent_scaling_factor)
```
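A minimal usage sketch of the two methods above; `ldm` and `image` are assumed names, not part of this file.

```python
# `ldm` is an assumed LatentDiffusion instance; `image` is a batch of images in [-1, 1]
with torch.no_grad():
    z = ldm.autoencoder_encode(image)   # scaled latents, e.g. shape (B, 4, H/8, W/8)
    recon = ldm.autoencoder_decode(z)   # approximate reconstruction of `image`
```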
Predict noise $\epsilon_\text{cond}(x_t, c)$ given the latent representation $x_t$, time step $t$, and the conditioning context $c$.
```python
    def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
        return self.model(x, t, context)
```
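The samplers (implemented separately) call this forward pass; a common pattern is classifier-free guidance, sketched below. The names `ldm`, `x_t`, `cond`, `uncond`, and the guidance scale are illustrative assumptions; the unconditional context would typically be `ldm.get_text_conditioning([""])`.

```python
def guided_eps(ldm: LatentDiffusion, x_t: torch.Tensor, t: torch.Tensor,
               cond: torch.Tensor, uncond: torch.Tensor, scale: float = 7.5):
    # Predict noise with and without the text conditioning
    eps_cond = ldm(x_t, t, cond)
    eps_uncond = ldm(x_t, t, uncond)
    # Push the prediction towards the conditioned direction
    return eps_uncond + scale * (eps_cond - eps_uncond)
```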