潜在拡散モデルでは、オートエンコーダーを使用して画像空間と潜在空間をマッピングします。拡散モデルは潜在空間で機能するため、トレーニングがはるかに簡単になります。これは、潜在拡散モデルを用いた論文の高解像度画像合成に基づいています
。事前にトレーニングされたオートエンコーダーを使用し、事前トレーニング済みのオートエンコーダーの潜在空間で拡散 U-Net をトレーニングします。
より単純な拡散実装については、DDPM 実装を参照してください。スケジュールなどにも同じ表記を使います
。24from typing import List
25
26import torch
27import torch.nn as nn
28
29from labml_nn.diffusion.stable_diffusion.model.autoencoder import Autoencoder
30from labml_nn.diffusion.stable_diffusion.model.clip_embedder import CLIPTextEmbedder
31from labml_nn.diffusion.stable_diffusion.model.unet import UNetModel
これは U-Net 周辺の空のラッパークラスです。チェックポイントの重みを明示的にマッピングする必要がないように、これを compVis/Stable-Diffusion と同じモデル構造にしておきます
。34class DiffusionWrapper(nn.Module):
42 def __init__(self, diffusion_model: UNetModel):
43 super().__init__()
44 self.diffusion_model = diffusion_model
46 def forward(self, x: torch.Tensor, time_steps: torch.Tensor, context: torch.Tensor):
47 return self.diffusion_model(x, time_steps, context)
50class LatentDiffusion(nn.Module):
60 model: DiffusionWrapper
61 first_stage_model: Autoencoder
62 cond_stage_model: CLIPTextEmbedder
unet_model
潜伏空間のノイズを予測するU-Netですautoencoder
はオートエンコーダですclip_embedder
CLIP 埋め込みジェネレータですlatent_scaling_factor
は潜在空間のスケーリング係数です。オートエンコーダのエンコーディングは、U-Netにフィードする前にこれによってスケーリングされますn_steps
は拡散ステップの数です。linear_start
スケジュールの始まりです。linear_end
スケジュールは終了です。64 def __init__(self,
65 unet_model: UNetModel,
66 autoencoder: Autoencoder,
67 clip_embedder: CLIPTextEmbedder,
68 latent_scaling_factor: float,
69 n_steps: int,
70 linear_start: float,
71 linear_end: float,
72 ):
84 super().__init__()
CompVis/Stable-Diffusionと同じモデル構造を保つために U-Net をラップしてください。
87 self.model = DiffusionWrapper(unet_model)
オートエンコーダとスケーリングファクター
89 self.first_stage_model = autoencoder
90 self.latent_scaling_factor = latent_scaling_factor
92 self.cond_stage_model = clip_embedder
ステップ数
95 self.n_steps = n_steps
スケジュール
98 beta = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_steps, dtype=torch.float64) ** 2
99 self.beta = nn.Parameter(beta.to(torch.float32), requires_grad=False)
101 alpha = 1. - beta
103 alpha_bar = torch.cumprod(alpha, dim=0)
104 self.alpha_bar = nn.Parameter(alpha_bar.to(torch.float32), requires_grad=False)
106 @property
107 def device(self):
111 return next(iter(self.model.parameters())).device
113 def get_text_conditioning(self, prompts: List[str]):
117 return self.cond_stage_model(prompts)
119 def autoencoder_encode(self, image: torch.Tensor):
126 return self.latent_scaling_factor * self.first_stage_model.encode(image).sample()
128 def autoencoder_decode(self, z: torch.Tensor):
134 return self.first_stage_model.decode(z / self.latent_scaling_factor)
136 def forward(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor):
145 return self.model(x, t, context)