这实现了用于在图像空间和潜在空间之间进行映射的自动编码器模型。
我们保持了 compvis/Stable-Difusi on 的模型定义和命名不变,这样我们就可以直接加载检查点。
18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn
25class Autoencoder(nn.Module):
encoder
是编码器decoder
是解码器emb_channels
是量化嵌入空间中的维数z_channels
是嵌入空间中的通道数32 def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
39 super().__init__()
40 self.encoder = encoder
41 self.decoder = decoder
从嵌入空间到量化嵌入空间矩的卷积到映射(均值和对数方差)
44 self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
卷积将从量化嵌入空间映射回嵌入空间
47 self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
49 def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
获取带有形状的嵌入物[batch_size, z_channels * 2, z_height, z_height]
56 z = self.encoder(img)
获取量化嵌入空间中的瞬间
58 moments = self.quant_conv(z)
返回分布
60 return GaussianDistribution(moments)
62 def decode(self, z: torch.Tensor):
从量化表示映射到嵌入空间
69 z = self.post_quant_conv(z)
解码形状的图像[batch_size, channels, height, width]
71 return self.decoder(z)
74class Encoder(nn.Module):
channels
是第一个卷积层中的通道数channel_multipliers
是后续区组中信道数量的乘法因子n_resnet_blocks
是每种分辨率下的 resnet 层数in_channels
是图像中的通道数z_channels
是嵌入空间中的通道数79 def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
80 in_channels: int, z_channels: int):
89 super().__init__()
不同分辨率的区块数。每个顶层方块的结尾处分辨率减半
93 n_resolutions = len(channel_multipliers)
将图像映射到的初始卷积层channels
96 self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
每个顶级区块中的频道数
99 channels_list = [m * channels for m in [1] + channel_multipliers]
顶级区块列表
102 self.down = nn.ModuleList()
创建顶级区块
104 for i in range(n_resolutions):
每个顶级区块由多个 ResNet 模块和向下采样组成
106 resnet_blocks = nn.ModuleList()
添加 ResNet 区块
108 for _ in range(n_resnet_blocks):
109 resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
110 channels = channels_list[i + 1]
顶级区块
112 down = nn.Module()
113 down.block = resnet_blocks
在每个顶级区块的末尾处向下采样(最后一个区块除外)
115 if i != n_resolutions - 1:
116 down.downsample = DownSample(channels)
117 else:
118 down.downsample = nn.Identity()
120 self.down.append(down)
最后一个值得注意的 ResNet 封锁
123 self.mid = nn.Module()
124 self.mid.block_1 = ResnetBlock(channels, channels)
125 self.mid.attn_1 = AttnBlock(channels)
126 self.mid.block_2 = ResnetBlock(channels, channels)
用卷积映射到嵌入空间
129 self.norm_out = normalization(channels)
130 self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
img
是带有形状的图像张量[batch_size, img_channels, img_height, img_width]
132 def forward(self, img: torch.Tensor):
channels
使用初始卷积映射到
138 x = self.conv_in(img)
顶级区块
141 for down in self.down:
ResNet 区块
143 for block in down.block:
144 x = block(x)
向下采样
146 x = down.downsample(x)
最后一个值得注意的 ResNet 封锁
149 x = self.mid.block_1(x)
150 x = self.mid.attn_1(x)
151 x = self.mid.block_2(x)
归一化并映射到嵌入空间
154 x = self.norm_out(x)
155 x = swish(x)
156 x = self.conv_out(x)
159 return x
162class Decoder(nn.Module):
channels
是最终卷积层中的通道数channel_multipliers
是前面区块中信道数的乘法因子,顺序相反n_resnet_blocks
是每种分辨率下的 resnet 层数out_channels
是图像中的通道数z_channels
是嵌入空间中的通道数167 def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
168 out_channels: int, z_channels: int):
177 super().__init__()
不同分辨率的区块数。每个顶层方块的结尾处分辨率减半
181 num_resolutions = len(channel_multipliers)
每个顶级块中的通道数,按相反顺序排列
184 channels_list = [m * channels for m in channel_multipliers]
顶级区块中的频道数
187 channels = channels_list[-1]
将嵌入空间映射到的初始卷积层channels
190 self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
ResNet 要注意封锁
193 self.mid = nn.Module()
194 self.mid.block_1 = ResnetBlock(channels, channels)
195 self.mid.attn_1 = AttnBlock(channels)
196 self.mid.block_2 = ResnetBlock(channels, channels)
顶级区块列表
199 self.up = nn.ModuleList()
创建顶级区块
201 for i in reversed(range(num_resolutions)):
每个顶级区块由多个 ResNet 模块和向上采样组成
203 resnet_blocks = nn.ModuleList()
添加 ResNet 区块
205 for _ in range(n_resnet_blocks + 1):
206 resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
207 channels = channels_list[i]
顶级区块
209 up = nn.Module()
210 up.block = resnet_blocks
在每个顶级区块的结尾处向上采样(第一个除外)
212 if i != 0:
213 up.upsample = UpSample(channels)
214 else:
215 up.upsample = nn.Identity()
预先设置以与检查点保持一致
217 self.up.insert(0, up)
使用卷积映射到图像空间
220 self.norm_out = normalization(channels)
221 self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
z
是带有形状的嵌入张量[batch_size, z_channels, z_height, z_height]
223 def forward(self, z: torch.Tensor):
channels
使用初始卷积映射到
229 h = self.conv_in(z)
ResNet 要注意封锁
232 h = self.mid.block_1(h)
233 h = self.mid.attn_1(h)
234 h = self.mid.block_2(h)
顶级区块
237 for up in reversed(self.up):
ResNet 区块
239 for block in up.block:
240 h = block(h)
向上采样
242 h = up.upsample(h)
归一化并映射到图像空间
245 h = self.norm_out(h)
246 h = swish(h)
247 img = self.conv_out(h)
250 return img
253class GaussianDistribution:
parameters
是形状嵌入的方差的均值和对数[batch_size, z_channels * 2, z_height, z_height]
258 def __init__(self, parameters: torch.Tensor):
分割均值和方差对数
264 self.mean, log_var = torch.chunk(parameters, 2, dim=1)
限制方差日志
266 self.log_var = torch.clamp(log_var, -30.0, 20.0)
计算标准差
268 self.std = torch.exp(0.5 * self.log_var)
270 def sample(self):
来自分布的样本
272 return self.mean + self.std * torch.randn_like(self.std)
275class AttnBlock(nn.Module):
channels
是频道数280 def __init__(self, channels: int):
284 super().__init__()
群组标准化
286 self.norm = normalization(channels)
查询、键和值映射
288 self.q = nn.Conv2d(channels, channels, 1)
289 self.k = nn.Conv2d(channels, channels, 1)
290 self.v = nn.Conv2d(channels, channels, 1)
最终卷积层
292 self.proj_out = nn.Conv2d(channels, channels, 1)
注意力缩放系数
294 self.scale = channels ** -0.5
x
是形状的张量[batch_size, channels, height, width]
296 def forward(self, x: torch.Tensor):
标准化x
301 x_norm = self.norm(x)
获取查询、键和向量嵌入
303 q = self.q(x_norm)
304 k = self.k(x_norm)
305 v = self.v(x_norm)
重塑为查询,键嵌入和向量嵌入从[batch_size, channels, height, width]
为[batch_size, channels, height * width]
310 b, c, h, w = q.shape
311 q = q.view(b, c, h * w)
312 k = k.view(b, c, h * w)
313 v = v.view(b, c, h * w)
计算
316 attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
317 attn = F.softmax(attn, dim=2)
计算
320 out = torch.einsum('bij,bcj->bci', attn, v)
重塑回原状[batch_size, channels, height, width]
323 out = out.view(b, c, h, w)
最终卷积层
325 out = self.proj_out(out)
添加剩余连接
328 return x + out
331class UpSample(nn.Module):
channels
是频道数335 def __init__(self, channels: int):
339 super().__init__()
卷积映射
341 self.conv = nn.Conv2d(channels, channels, 3, padding=1)
x
是带有形状的输入要素图[batch_size, channels, height, width]
343 def forward(self, x: torch.Tensor):
按系数向上采样
348 x = F.interpolate(x, scale_factor=2.0, mode="nearest")
应用卷积
350 return self.conv(x)
353class DownSample(nn.Module):
channels
是频道数357 def __init__(self, channels: int):
361 super().__init__()
卷积,步长为向下采样的系数为
363 self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
x
是带有形状的输入要素图[batch_size, channels, height, width]
365 def forward(self, x: torch.Tensor):
添加内边距
370 x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
应用卷积
372 return self.conv(x)
375class ResnetBlock(nn.Module):
in_channels
是输入中的通道数out_channels
是输出中的通道数379 def __init__(self, in_channels: int, out_channels: int):
384 super().__init__()
第一个归一化和卷积层
386 self.norm1 = normalization(in_channels)
387 self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
第二个归一化和卷积层
389 self.norm2 = normalization(out_channels)
390 self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
in_channels
到剩余连接的out_channels
映射层
392 if in_channels != out_channels:
393 self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
394 else:
395 self.nin_shortcut = nn.Identity()
x
是带有形状的输入要素图[batch_size, channels, height, width]
397 def forward(self, x: torch.Tensor):
402 h = x
第一个归一化和卷积层
405 h = self.norm1(h)
406 h = swish(h)
407 h = self.conv1(h)
第二个归一化和卷积层
410 h = self.norm2(h)
411 h = swish(h)
412 h = self.conv2(h)
映射并添加残差
415 return self.nin_shortcut(x) + h
418def swish(x: torch.Tensor):
424 return x * torch.sigmoid(x)
427def normalization(channels: int):
433 return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)