用于稳定扩散的自动编码器

这实现了用于在图像空间和潜在空间之间进行映射的自动编码器模型。

我们保持了 compvis/Stable-Difusi on 的模型定义和命名不变,这样我们就可以直接加载检查点。

18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn

自动编码器

它由编码器和解码器模块组成。

25class Autoencoder(nn.Module):
  • encoder 是编码器
  • decoder 是解码器
  • emb_channels 是量化嵌入空间中的维数
  • z_channels 是嵌入空间中的通道数
32    def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
39        super().__init__()
40        self.encoder = encoder
41        self.decoder = decoder

从嵌入空间到量化嵌入空间矩的卷积到映射(均值和对数方差)

44        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)

卷积将从量化嵌入空间映射回嵌入空间

47        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

将图像编码为潜在表示

  • img 是带有形状的图像张量[batch_size, img_channels, img_height, img_width]
49    def encode(self, img: torch.Tensor) -> 'GaussianDistribution':

获取带有形状的嵌入物[batch_size, z_channels * 2, z_height, z_height]

56        z = self.encoder(img)

获取量化嵌入空间中的瞬间

58        moments = self.quant_conv(z)

返回分布

60        return GaussianDistribution(moments)

从潜在表现中解码图像

  • z 是带有形状的潜在表示形式[batch_size, emb_channels, z_height, z_height]
62    def decode(self, z: torch.Tensor):

从量化表示映射到嵌入空间

69        z = self.post_quant_conv(z)

解码形状的图像[batch_size, channels, height, width]

71        return self.decoder(z)

编码器模块

74class Encoder(nn.Module):
  • channels 是第一个卷积层中的通道数
  • channel_multipliers 是后续区组中信道数量的乘法因子
  • n_resnet_blocks 是每种分辨率下的 resnet 层数
  • in_channels 是图像中的通道数
  • z_channels 是嵌入空间中的通道数
79    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
80                 in_channels: int, z_channels: int):
89        super().__init__()

不同分辨率的区块数。每个顶层方块的结尾处分辨率减半

93        n_resolutions = len(channel_multipliers)

将图像映射到的初始卷积层channels

96        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

每个顶级区块中的频道数

99        channels_list = [m * channels for m in [1] + channel_multipliers]

顶级区块列表

102        self.down = nn.ModuleList()

创建顶级区块

104        for i in range(n_resolutions):

每个顶级区块由多个 ResNet 模块和向下采样组成

106            resnet_blocks = nn.ModuleList()

添加 ResNet 区块

108            for _ in range(n_resnet_blocks):
109                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
110                channels = channels_list[i + 1]

顶级区块

112            down = nn.Module()
113            down.block = resnet_blocks

在每个顶级区块的末尾处向下采样(最后一个区块除外)

115            if i != n_resolutions - 1:
116                down.downsample = DownSample(channels)
117            else:
118                down.downsample = nn.Identity()

120            self.down.append(down)

最后一个值得注意的 ResNet 封锁

123        self.mid = nn.Module()
124        self.mid.block_1 = ResnetBlock(channels, channels)
125        self.mid.attn_1 = AttnBlock(channels)
126        self.mid.block_2 = ResnetBlock(channels, channels)

卷积映射到嵌入空间

129        self.norm_out = normalization(channels)
130        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
  • img 是带有形状的图像张量[batch_size, img_channels, img_height, img_width]
132    def forward(self, img: torch.Tensor):

channels 使用初始卷积映射到

138        x = self.conv_in(img)

顶级区块

141        for down in self.down:

ResNet 区块

143            for block in down.block:
144                x = block(x)

向下采样

146            x = down.downsample(x)

最后一个值得注意的 ResNet 封锁

149        x = self.mid.block_1(x)
150        x = self.mid.attn_1(x)
151        x = self.mid.block_2(x)

归一化并映射到嵌入空间

154        x = self.norm_out(x)
155        x = swish(x)
156        x = self.conv_out(x)

159        return x

解码器模块

162class Decoder(nn.Module):
  • channels 是最终卷积层中的通道数
  • channel_multipliers 是前面区块中信道数的乘法因子,顺序相反
  • n_resnet_blocks 是每种分辨率下的 resnet 层数
  • out_channels 是图像中的通道数
  • z_channels 是嵌入空间中的通道数
167    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
168                 out_channels: int, z_channels: int):
177        super().__init__()

不同分辨率的区块数。每个顶层方块的结尾处分辨率减半

181        num_resolutions = len(channel_multipliers)

每个顶级块中的通道数,按相反顺序排列

184        channels_list = [m * channels for m in channel_multipliers]

顶级区块中的频道数

187        channels = channels_list[-1]

将嵌入空间映射到的初始卷积层channels

190        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

ResNet 要注意封锁

193        self.mid = nn.Module()
194        self.mid.block_1 = ResnetBlock(channels, channels)
195        self.mid.attn_1 = AttnBlock(channels)
196        self.mid.block_2 = ResnetBlock(channels, channels)

顶级区块列表

199        self.up = nn.ModuleList()

创建顶级区块

201        for i in reversed(range(num_resolutions)):

每个顶级区块由多个 ResNet 模块和向上采样组成

203            resnet_blocks = nn.ModuleList()

添加 ResNet 区块

205            for _ in range(n_resnet_blocks + 1):
206                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
207                channels = channels_list[i]

顶级区块

209            up = nn.Module()
210            up.block = resnet_blocks

在每个顶级区块的结尾处向上采样(第一个除外)

212            if i != 0:
213                up.upsample = UpSample(channels)
214            else:
215                up.upsample = nn.Identity()

预先设置以与检查点保持一致

217            self.up.insert(0, up)

使用卷积映射到图像空间

220        self.norm_out = normalization(channels)
221        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
  • z 是带有形状的嵌入张量[batch_size, z_channels, z_height, z_height]
223    def forward(self, z: torch.Tensor):

channels 使用初始卷积映射到

229        h = self.conv_in(z)

ResNet 要注意封锁

232        h = self.mid.block_1(h)
233        h = self.mid.attn_1(h)
234        h = self.mid.block_2(h)

顶级区块

237        for up in reversed(self.up):

ResNet 区块

239            for block in up.block:
240                h = block(h)

向上采样

242            h = up.upsample(h)

归一化并映射到图像空间

245        h = self.norm_out(h)
246        h = swish(h)
247        img = self.conv_out(h)

250        return img

高斯分布

253class GaussianDistribution:
  • parameters 是形状嵌入的方差的均值和对数[batch_size, z_channels * 2, z_height, z_height]
258    def __init__(self, parameters: torch.Tensor):

分割均值和方差对数

264        self.mean, log_var = torch.chunk(parameters, 2, dim=1)

限制方差日志

266        self.log_var = torch.clamp(log_var, -30.0, 20.0)

计算标准差

268        self.std = torch.exp(0.5 * self.log_var)
270    def sample(self):

来自分布的样本

272        return self.mean + self.std * torch.randn_like(self.std)

注意方块

275class AttnBlock(nn.Module):
  • channels 是频道数
280    def __init__(self, channels: int):
284        super().__init__()

群组标准化

286        self.norm = normalization(channels)

查询、键和值映射

288        self.q = nn.Conv2d(channels, channels, 1)
289        self.k = nn.Conv2d(channels, channels, 1)
290        self.v = nn.Conv2d(channels, channels, 1)

最终卷积层

292        self.proj_out = nn.Conv2d(channels, channels, 1)

注意力缩放系数

294        self.scale = channels ** -0.5
  • x 是形状的张量[batch_size, channels, height, width]
296    def forward(self, x: torch.Tensor):

标准化x

301        x_norm = self.norm(x)

获取查询、键和向量嵌入

303        q = self.q(x_norm)
304        k = self.k(x_norm)
305        v = self.v(x_norm)

重塑为查询,键嵌入和向量嵌入从[batch_size, channels, height, width][batch_size, channels, height * width]

310        b, c, h, w = q.shape
311        q = q.view(b, c, h * w)
312        k = k.view(b, c, h * w)
313        v = v.view(b, c, h * w)

计算

316        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
317        attn = F.softmax(attn, dim=2)

计算

320        out = torch.einsum('bij,bcj->bci', attn, v)

重塑回原状[batch_size, channels, height, width]

323        out = out.view(b, c, h, w)

最终卷积层

325        out = self.proj_out(out)

添加剩余连接

328        return x + out

向上采样层

331class UpSample(nn.Module):
  • channels 是频道数
335    def __init__(self, channels: int):
339        super().__init__()

卷积映射

341        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
  • x 是带有形状的输入要素图[batch_size, channels, height, width]
343    def forward(self, x: torch.Tensor):

按系数向上采样

348        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

应用卷积

350        return self.conv(x)

向下采样层

353class DownSample(nn.Module):
  • channels 是频道数
357    def __init__(self, channels: int):
361        super().__init__()

卷积,步长为向下采样的系数为

363        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
  • x 是带有形状的输入要素图[batch_size, channels, height, width]
365    def forward(self, x: torch.Tensor):

添加内边距

370        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

应用卷积

372        return self.conv(x)

ResNet 区块

375class ResnetBlock(nn.Module):
  • in_channels 是输入中的通道数
  • out_channels 是输出中的通道数
379    def __init__(self, in_channels: int, out_channels: int):
384        super().__init__()

第一个归一化和卷积层

386        self.norm1 = normalization(in_channels)
387        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)

第二个归一化和卷积层

389        self.norm2 = normalization(out_channels)
390        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)

in_channels 到剩余连接的out_channels 映射层

392        if in_channels != out_channels:
393            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
394        else:
395            self.nin_shortcut = nn.Identity()
  • x 是带有形状的输入要素图[batch_size, channels, height, width]
397    def forward(self, x: torch.Tensor):
402        h = x

第一个归一化和卷积层

405        h = self.norm1(h)
406        h = swish(h)
407        h = self.conv1(h)

第二个归一化和卷积层

410        h = self.norm2(h)
411        h = swish(h)
412        h = self.conv2(h)

映射并添加残差

415        return self.nin_shortcut(x) + h

Swish 激活

418def swish(x: torch.Tensor):
424    return x * torch.sigmoid(x)

群组标准化

这是一个辅助函数,具有固定数量的组和eps

427def normalization(channels: int):
433    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)