安定した拡散を実現するオートエンコーダー

これは、画像空間と潜在空間のマッピングに使用されるオートエンコーダモデルを実装しています。

チェックポイントを直接読み込めるように、CompVis/Stable-Diffusionからモデル定義と命名を変更していません

18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn

オートエンコーダ

これはエンコーダモジュールとデコーダモジュールで構成されています。

25class Autoencoder(nn.Module):
  • encoder はエンコーダです
  • decoder デコーダです
  • emb_channels は量子化された埋め込み空間の次元数です
  • z_channels は埋め込みスペースのチャンネル数です
32    def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
39        super().__init__()
40        self.encoder = encoder
41        self.decoder = decoder

埋め込み空間から量子化された埋め込み空間のモーメントにマッピングするための畳み込み (平均と対数分散)

44        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)

量子化された埋め込み空間から埋め込み空間にマッピングする畳み込み

47        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

画像を潜在表現にエンコード

  • img 形状のあるイメージテンソルです [batch_size, img_channels, img_height, img_width]
49    def encode(self, img: torch.Tensor) -> 'GaussianDistribution':

図形付きの埋め込みをする [batch_size, z_channels * 2, z_height, z_height]

56        z = self.encoder(img)

量子化された埋め込み空間のモーメントを取得

58        moments = self.quant_conv(z)

ディストリビューションを返す

60        return GaussianDistribution(moments)

潜在表現から画像をデコード

  • z 形を使った潜在表現です [batch_size, emb_channels, z_height, z_height]
62    def decode(self, z: torch.Tensor):

量子化された表現から埋め込み空間にマッピング

69        z = self.post_quant_conv(z)

形状の画像をデコード [batch_size, channels, height, width]

71        return self.decoder(z)

エンコーダモジュール

74class Encoder(nn.Module):
  • channels は最初の畳み込み層のチャネル数です
  • channel_multipliers は後続のブロックのチャンネル数の乗数です
  • n_resnet_blocks は各解像度での再ネット層の数です
  • in_channels は画像内のチャンネル数
  • z_channels は埋め込みスペースのチャンネル数です
79    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
80                 in_channels: int, z_channels: int):
89        super().__init__()

解像度の異なるブロック数。解像度は、各トップレベルブロックの最後で半分になります

93        n_resolutions = len(channel_multipliers)

画像をマップする最初の畳み込みレイヤー channels

96        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

各トップレベルブロックのチャンネル数

99        channels_list = [m * channels for m in [1] + channel_multipliers]

トップレベルブロックのリスト

102        self.down = nn.ModuleList()

トップレベルブロックを作成

104        for i in range(n_resolutions):

各トップレベルブロックは複数のResNetブロックとダウンサンプリングで構成されています

106            resnet_blocks = nn.ModuleList()

ResNet ブロックを追加

108            for _ in range(n_resnet_blocks):
109                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
110                channels = channels_list[i + 1]

トップレベルブロック

112            down = nn.Module()
113            down.block = resnet_blocks

最後を除く各トップレベルブロックの最後でのダウンサンプリング

115            if i != n_resolutions - 1:
116                down.downsample = DownSample(channels)
117            else:
118                down.downsample = nn.Identity()

120            self.down.append(down)

最後の ResNet ブロックには注意が必要です。

123        self.mid = nn.Module()
124        self.mid.block_1 = ResnetBlock(channels, channels)
125        self.mid.attn_1 = AttnBlock(channels)
126        self.mid.block_2 = ResnetBlock(channels, channels)

畳み込みによる埋め込み空間にマッピング

129        self.norm_out = normalization(channels)
130        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
  • img 形状のあるイメージテンソルです [batch_size, img_channels, img_height, img_width]
132    def forward(self, img: torch.Tensor):

channels 最初の畳み込みでにマッピング

138        x = self.conv_in(img)

トップレベルブロック

141        for down in self.down:

ResNet ブロック

143            for block in down.block:
144                x = block(x)

ダウンサンプリング

146            x = down.downsample(x)

最後の ResNet ブロックには注意が必要です。

149        x = self.mid.block_1(x)
150        x = self.mid.attn_1(x)
151        x = self.mid.block_2(x)

正規化して埋め込みスペースにマッピング

154        x = self.norm_out(x)
155        x = swish(x)
156        x = self.conv_out(x)

159        return x

デコーダモジュール

162class Decoder(nn.Module):
  • channels は最後の畳み込み層のチャネル数です
  • channel_multipliers 前のブロックのチャンネル数の乗算係数 (逆順)
  • n_resnet_blocks は各解像度での再ネット層の数です
  • out_channels は画像内のチャンネル数
  • z_channels は埋め込みスペースのチャンネル数です
167    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
168                 out_channels: int, z_channels: int):
177        super().__init__()

解像度の異なるブロック数。解像度は、各トップレベルブロックの最後で半分になります

181        num_resolutions = len(channel_multipliers)

各最上位ブロック内のチャンネル数 (逆順)

184        channels_list = [m * channels for m in channel_multipliers]

最上位ブロック内のチャンネル数

187        channels = channels_list[-1]

埋め込みスペースをマッピングする最初の畳み込みレイヤー channels

190        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

ResNet ブロックには注意が必要

193        self.mid = nn.Module()
194        self.mid.block_1 = ResnetBlock(channels, channels)
195        self.mid.attn_1 = AttnBlock(channels)
196        self.mid.block_2 = ResnetBlock(channels, channels)

トップレベルブロックのリスト

199        self.up = nn.ModuleList()

トップレベルブロックを作成

201        for i in reversed(range(num_resolutions)):

各トップレベルブロックは複数のResNetブロックとアップサンプリングで構成されています

203            resnet_blocks = nn.ModuleList()

ResNet ブロックを追加

205            for _ in range(n_resnet_blocks + 1):
206                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
207                channels = channels_list[i]

トップレベルブロック

209            up = nn.Module()
210            up.block = resnet_blocks

最初のブロックを除く各トップレベルブロックの最後でのアップサンプリング

212            if i != 0:
213                up.upsample = UpSample(channels)
214            else:
215                up.upsample = nn.Identity()

チェックポイントと一致するようにプリペンドを付ける

217            self.up.insert(0, up)

畳み込みによる画像空間にマッピング

220        self.norm_out = normalization(channels)
221        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
  • z 形状付きの埋め込みテンソルです [batch_size, z_channels, z_height, z_height]
223    def forward(self, z: torch.Tensor):

channels 最初の畳み込みでにマッピング

229        h = self.conv_in(z)

ResNet ブロックには注意が必要

232        h = self.mid.block_1(h)
233        h = self.mid.attn_1(h)
234        h = self.mid.block_2(h)

トップレベルブロック

237        for up in reversed(self.up):

ResNet ブロック

239            for block in up.block:
240                h = block(h)

アップサンプリング

242            h = up.upsample(h)

正規化して画像空間にマッピング

245        h = self.norm_out(h)
246        h = swish(h)
247        img = self.conv_out(h)

250        return img

ガウス分布

253class GaussianDistribution:
  • parameters は図形の埋め込みの平均と分散の対数です [batch_size, z_channels * 2, z_height, z_height]
258    def __init__(self, parameters: torch.Tensor):

分割平均と分散対数

264        self.mean, log_var = torch.chunk(parameters, 2, dim=1)

差異の対数をクランプ

266        self.log_var = torch.clamp(log_var, -30.0, 20.0)

標準偏差の計算

268        self.std = torch.exp(0.5 * self.log_var)
270    def sample(self):

ディストリビューションからのサンプル

272        return self.mean + self.std * torch.randn_like(self.std)

アテンションブロック

275class AttnBlock(nn.Module):
  • channels はチャネル数
280    def __init__(self, channels: int):
284        super().__init__()

グループ正規化

286        self.norm = normalization(channels)

クエリ、キー、値のマッピング

288        self.q = nn.Conv2d(channels, channels, 1)
289        self.k = nn.Conv2d(channels, channels, 1)
290        self.v = nn.Conv2d(channels, channels, 1)

最終畳み込み層

292        self.proj_out = nn.Conv2d(channels, channels, 1)

アテンションスケーリングファクター

294        self.scale = channels ** -0.5
  • x 形状のテンソルです [batch_size, channels, height, width]
296    def forward(self, x: torch.Tensor):

ノーマライズ x

301        x_norm = self.norm(x)

クエリ、キー、ベクターの埋め込みを取得

303        q = self.q(x_norm)
304        k = self.k(x_norm)
305        v = self.v(x_norm)

形状を変えてクエリ、キー、ベクターの埋め込みをからへ [batch_size, channels, height, width] [batch_size, channels, height * width]

310        b, c, h, w = q.shape
311        q = q.view(b, c, h * w)
312        k = k.view(b, c, h * w)
313        v = v.view(b, c, h * w)

コンピュート

316        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
317        attn = F.softmax(attn, dim=2)

コンピュート

320        out = torch.einsum('bij,bcj->bci', attn, v)

形状を変えて元に戻す [batch_size, channels, height, width]

323        out = out.view(b, c, h, w)

最終畳み込み層

325        out = self.proj_out(out)

残余接続を追加

328        return x + out

アップサンプリングレイヤー

331class UpSample(nn.Module):
  • channels はチャネル数
335    def __init__(self, channels: int):
339        super().__init__()

コンボリューションマッピング

341        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
  • x 形状付きの入力フィーチャマップです [batch_size, channels, height, width]
343    def forward(self, x: torch.Tensor):

次の倍までのアップサンプリング

348        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

コンボリューションを適用

350        return self.conv(x)

ダウンサンプリングレイヤー

353class DownSample(nn.Module):
  • channels はチャネル数
357    def __init__(self, channels: int):
361        super().__init__()

ストライドの長さがの畳み込みから、の係数でダウンサンプリングします

363        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
  • x 形状付きの入力フィーチャマップです [batch_size, channels, height, width]
365    def forward(self, x: torch.Tensor):

パディングを追加

370        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

コンボリューションを適用

372        return self.conv(x)

リネットブロック

375class ResnetBlock(nn.Module):
  • in_channels は入力のチャンネル数
  • out_channels は出力のチャンネル数
379    def __init__(self, in_channels: int, out_channels: int):
384        super().__init__()

最初の正規化と畳み込み層

386        self.norm1 = normalization(in_channels)
387        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)

2 番目の正規化と畳み込み層

389        self.norm2 = normalization(out_channels)
390        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)

in_channels out_channels 残留接続用のマッピングレイヤへ

392        if in_channels != out_channels:
393            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
394        else:
395            self.nin_shortcut = nn.Identity()
  • x 形状付きの入力フィーチャマップです [batch_size, channels, height, width]
397    def forward(self, x: torch.Tensor):
402        h = x

最初の正規化と畳み込み層

405        h = self.norm1(h)
406        h = swish(h)
407        h = self.conv1(h)

2 番目の正規化と畳み込み層

410        h = self.norm2(h)
411        h = swish(h)
412        h = self.conv2(h)

残差をマッピングして追加

415        return self.nin_shortcut(x) + h

スウィッシュアクティベーション

418def swish(x: torch.Tensor):
424    return x * torch.sigmoid(x)

グループ正規化

これはヘルパー関数で、グループの数は固定されています。eps

427def normalization(channels: int):
433    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)