This implements the auto-encoder model used to map between image space and latent space.
We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
from typing import List

import torch
import torch.nn.functional as F
from torch import nn

class Autoencoder(nn.Module):
encoder is the encoder
decoder is the decoder
emb_channels is the number of dimensions in the quantized embedding space
z_channels is the number of channels in the embedding space
def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
super().__init__()
self.encoder = encoder
self.decoder = decoder
Convolution to map from embedding space to quantized embedding space moments (mean and log variance)
self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
Convolution to map from quantized embedding space back to embedding space
self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
img is the image tensor with shape [batch_size, img_channels, img_height, img_width]
def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
Get embeddings with shape [batch_size, z_channels * 2, z_height, z_height]
z = self.encoder(img)
Get the moments in the quantized embedding space
moments = self.quant_conv(z)
Return the distribution
return GaussianDistribution(moments)
z is the latent representation with shape [batch_size, emb_channels, z_height, z_height]
def decode(self, z: torch.Tensor):
Map to embedding space from the quantized representation
z = self.post_quant_conv(z)
Decode the image of shape [batch_size, channels, height, width]
return self.decoder(z)
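As a quick orientation before the individual modules, here is a minimal usage sketch of the encode, sample and decode roundtrip, using the Encoder, Decoder and GaussianDistribution defined below. The hyperparameters are assumptions chosen to resemble the Stable Diffusion configuration, not values read from a checkpoint, so treat the shapes as illustrative.
encoder = Encoder(channels=128, channel_multipliers=[1, 2, 4, 4],
                  n_resnet_blocks=2, in_channels=3, z_channels=4)
decoder = Decoder(channels=128, channel_multipliers=[1, 2, 4, 4],
                  n_resnet_blocks=2, out_channels=3, z_channels=4)
autoencoder = Autoencoder(encoder=encoder, decoder=decoder, emb_channels=4, z_channels=4)
img = torch.randn(1, 3, 256, 256)   # dummy image batch
dist = autoencoder.encode(img)      # GaussianDistribution over latents
z = dist.sample()                   # latent sample of shape [1, 4, 32, 32]
out = autoencoder.decode(z)         # reconstruction of shape [1, 3, 256, 256]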
class Encoder(nn.Module):
channels is the number of channels in the first convolution layer
channel_multipliers are the multiplicative factors for the number of channels in the subsequent blocks
n_resnet_blocks is the number of resnet layers at each resolution
in_channels is the number of channels in the image
z_channels is the number of channels in the embedding space
def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
             in_channels: int, z_channels: int):
super().__init__()
Number of blocks of different resolutions. The resolution is halved at the end of each top-level block
n_resolutions = len(channel_multipliers)
Initial convolution layer that maps the image to channels
self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
Number of channels in each top level block
channels_list = [m * channels for m in [1] + channel_multipliers]
List of top-level blocks
self.down = nn.ModuleList()
Create top-level blocks
for i in range(n_resolutions):
Each top level block consists of multiple ResNet Blocks and down-sampling
resnet_blocks = nn.ModuleList()
Add ResNet Blocks
for _ in range(n_resnet_blocks):
resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
channels = channels_list[i + 1]
Top-level block
down = nn.Module()
down.block = resnet_blocks
Down-sampling at the end of each top level block except the last
if i != n_resolutions - 1:
down.downsample = DownSample(channels)
else:
down.downsample = nn.Identity()
self.down.append(down)
Final ResNet blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(channels, channels)
self.mid.attn_1 = AttnBlock(channels)
self.mid.block_2 = ResnetBlock(channels, channels)
Map to embedding space with a convolution
self.norm_out = normalization(channels)
self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
img is the image tensor with shape [batch_size, img_channels, img_height, img_width]
def forward(self, img: torch.Tensor):
Map to channels with the initial convolution
x = self.conv_in(img)
Top-level blocks
for down in self.down:
ResNet Blocks
for block in down.block:
x = block(x)
Down-sampling
x = down.downsample(x)
Final ResNet blocks with attention
x = self.mid.block_1(x)
x = self.mid.attn_1(x)
x = self.mid.block_2(x)
Normalize and map to embedding space
x = self.norm_out(x)
x = swish(x)
x = self.conv_out(x)
return x
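A small sanity check of the encoder output shape, with assumed hyperparameters: the resolution is halved after every top-level block except the last, so four channel multipliers give three down-sampling steps (256 → 128 → 64 → 32), and the output carries 2 * z_channels channels because the mean and log variance are stacked.
enc = Encoder(channels=64, channel_multipliers=[1, 2, 4, 4],
              n_resnet_blocks=1, in_channels=3, z_channels=4)
moments = enc(torch.randn(1, 3, 256, 256))
assert moments.shape == (1, 2 * 4, 32, 32)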
class Decoder(nn.Module):
channels is the number of channels in the final convolution layer
channel_multipliers are the multiplicative factors for the number of channels in the previous blocks, in reverse order
n_resnet_blocks is the number of resnet layers at each resolution
out_channels is the number of channels in the image
z_channels is the number of channels in the embedding space
def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
             out_channels: int, z_channels: int):
super().__init__()
Number of blocks of different resolutions. The resolution is doubled at the end of each top-level block
num_resolutions = len(channel_multipliers)
Number of channels in each top level block, in the reverse order
channels_list = [m * channels for m in channel_multipliers]
Number of channels in the last top-level block, where decoding starts (the lowest resolution)
channels = channels_list[-1]
Initial convolution layer that maps the embedding space to channels
self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
ResNet blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(channels, channels)
self.mid.attn_1 = AttnBlock(channels)
self.mid.block_2 = ResnetBlock(channels, channels)
List of top-level blocks
self.up = nn.ModuleList()
Create top-level blocks
for i in reversed(range(num_resolutions)):
Each top level block consists of multiple ResNet Blocks and up-sampling
resnet_blocks = nn.ModuleList()
Add ResNet Blocks
for _ in range(n_resnet_blocks + 1):
resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
channels = channels_list[i]
Top-level block
up = nn.Module()
up.block = resnet_blocks
Up-sampling at the end of each top level block except the first
if i != 0:
up.upsample = UpSample(channels)
else:
up.upsample = nn.Identity()
Prepend to be consistent with the checkpoint
self.up.insert(0, up)
Map to image space with a convolution
self.norm_out = normalization(channels)
self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
z is the embedding tensor with shape [batch_size, z_channels, z_height, z_height]
def forward(self, z: torch.Tensor):
Map to channels with the initial convolution
h = self.conv_in(z)
ResNet blocks with attention
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
Top-level blocks
for up in reversed(self.up):
ResNet Blocks
for block in up.block:
h = block(h)
Up-sampling
h = up.upsample(h)
Normalize and map to image space
h = self.norm_out(h)
h = swish(h)
img = self.conv_out(h)
return img
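Conversely, the decoder doubles the resolution at every top-level block except the last, so under the same assumed hyperparameters a 32×32 latent maps back to a 256×256 image:
dec = Decoder(channels=64, channel_multipliers=[1, 2, 4, 4],
              n_resnet_blocks=1, out_channels=3, z_channels=4)
img = dec(torch.randn(1, 4, 32, 32))
assert img.shape == (1, 3, 256, 256)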
class GaussianDistribution:
parameters are the means and log of variances of the embedding of shape [batch_size, z_channels * 2, z_height, z_height]
def __init__(self, parameters: torch.Tensor):
Split mean and log of variance
self.mean, log_var = torch.chunk(parameters, 2, dim=1)
Clamp the log of variances
self.log_var = torch.clamp(log_var, -30.0, 20.0)
Calculate standard deviation
self.std = torch.exp(0.5 * self.log_var)
def sample(self):
Sample from the distribution
return self.mean + self.std * torch.randn_like(self.std)
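sample() is the reparameterization trick: a latent is drawn as z = mean + std * eps with eps ~ N(0, I), so gradients flow through mean and std while the randomness stays in eps. A hedged, written-out equivalent (the input shape is just an example):
dist = GaussianDistribution(torch.randn(1, 8, 32, 32))   # 2 * 4 channels: mean and log variance
eps = torch.randn_like(dist.std)                         # noise independent of the parameters
z = dist.mean + dist.std * eps                           # same computation as dist.sample()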
class AttnBlock(nn.Module):
channels is the number of channels
def __init__(self, channels: int):
super().__init__()
Group normalization
self.norm = normalization(channels)
Query, key and value mappings
self.q = nn.Conv2d(channels, channels, 1)
self.k = nn.Conv2d(channels, channels, 1)
self.v = nn.Conv2d(channels, channels, 1)
Final convolution layer
self.proj_out = nn.Conv2d(channels, channels, 1)
Attention scaling factor
self.scale = channels ** -0.5
x is the tensor of shape [batch_size, channels, height, width]
def forward(self, x: torch.Tensor):
Normalize x
x_norm = self.norm(x)
Get query, key and value embeddings
q = self.q(x_norm)
k = self.k(x_norm)
v = self.v(x_norm)
Reshape query, key and value embeddings from [batch_size, channels, height, width] to [batch_size, channels, height * width]
b, c, h, w = q.shape
q = q.view(b, c, h * w)
k = k.view(b, c, h * w)
v = v.view(b, c, h * w)
Compute the scaled dot-product attention weights, softmax(Q Kᵀ / sqrt(d_key)), over the flattened spatial positions
attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
attn = F.softmax(attn, dim=2)
Compute the weighted sum of the values, softmax(Q Kᵀ / sqrt(d_key)) V
out = torch.einsum('bij,bcj->bci', attn, v)
Reshape back to [batch_size, channels, height, width]
out = out.view(b, c, h, w)
Final convolution layer
out = self.proj_out(out)
Add residual connection
return x + out
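The two einsum calls are plain scaled dot-product attention over the height * width spatial positions, with channels acting as the feature dimension. A self-contained, hedged equivalent using batched matrix products (the sizes below are made up for illustration):
b, c, hw = 2, 64, 16 * 16
q, k, v = torch.randn(b, c, hw), torch.randn(b, c, hw), torch.randn(b, c, hw)
attn = F.softmax(torch.bmm(q.transpose(1, 2), k) * c ** -0.5, dim=-1)   # [b, hw, hw]
out = torch.bmm(v, attn.transpose(1, 2))                                # [b, c, hw], matches the einsum version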
class UpSample(nn.Module):
channels is the number of channels
def __init__(self, channels: int):
super().__init__()
3×3 convolution mapping
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
x is the input feature map with shape [batch_size, channels, height, width]
def forward(self, x: torch.Tensor):
Up-sample by a factor of 2
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
Apply convolution
return self.conv(x)
class DownSample(nn.Module):
channels is the number of channels
def __init__(self, channels: int):
super().__init__()
3×3 convolution with a stride length of 2 to down-sample by a factor of 2
self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
x is the input feature map with shape [batch_size, channels, height, width]
def forward(self, x: torch.Tensor):
Add asymmetric padding (one extra pixel on the right and bottom)
x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
Apply convolution
return self.conv(x)
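The asymmetric (0, 1, 0, 1) padding adds one pixel only on the right and bottom, so the stride-2 convolution with no built-in padding halves an even input size exactly: for size n the output is (n + 1 - 3) // 2 + 1 = n // 2. A quick hedged check:
down = DownSample(channels=32)
y = down(torch.randn(1, 32, 64, 64))
assert y.shape == (1, 32, 32, 32)   # 64 -> 32, exactly halved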
class ResnetBlock(nn.Module):
in_channels is the number of channels in the input
out_channels is the number of channels in the output
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
First normalization and convolution layer
self.norm1 = normalization(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
Second normalization and convolution layer
self.norm2 = normalization(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
in_channels to out_channels mapping layer for the residual connection
if in_channels != out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
else:
self.nin_shortcut = nn.Identity()
x is the input feature map with shape [batch_size, channels, height, width]
def forward(self, x: torch.Tensor):
h = x
First normalization and convolution layer
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
Second normalization and convolution layer
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
Map and add residual
return self.nin_shortcut(x) + h
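When in_channels differs from out_channels, the 1×1 nin_shortcut convolution projects the residual input so that both terms of the sum have the same number of channels; otherwise the shortcut is the identity. A brief hedged check:
block = ResnetBlock(in_channels=64, out_channels=128)
y = block(torch.randn(1, 64, 32, 32))
assert y.shape == (1, 128, 32, 32)   # shortcut maps 64 channels to 128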
Swish activation, x * sigmoid(x)
def swish(x: torch.Tensor):
return x * torch.sigmoid(x)
Group normalization with 32 groups
def normalization(channels: int):
return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)