これは、画像空間と潜在空間のマッピングに使用されるオートエンコーダモデルを実装しています。
チェックポイントを直接読み込めるように、CompVis/Stable-Diffusionからモデル定義と命名を変更していません。
18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn25class Autoencoder(nn.Module):encoder
はエンコーダですdecoder
デコーダですemb_channels
は量子化された埋め込み空間の次元数ですz_channels
は埋め込みスペースのチャンネル数です32 def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):39 super().__init__()
40 self.encoder = encoder
41 self.decoder = decoder埋め込み空間から量子化された埋め込み空間のモーメントにマッピングするための畳み込み (平均と対数分散)
44 self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)量子化された埋め込み空間から埋め込み空間にマッピングする畳み込み
47 self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)49 def encode(self, img: torch.Tensor) -> 'GaussianDistribution':図形付きの埋め込みをする [batch_size, z_channels * 2, z_height, z_height]
56 z = self.encoder(img)量子化された埋め込み空間のモーメントを取得
58 moments = self.quant_conv(z)ディストリビューションを返す
60 return GaussianDistribution(moments)62 def decode(self, z: torch.Tensor):量子化された表現から埋め込み空間にマッピング
69 z = self.post_quant_conv(z)形状の画像をデコード [batch_size, channels, height, width]
71 return self.decoder(z)74class Encoder(nn.Module):channels
は最初の畳み込み層のチャネル数ですchannel_multipliers
は後続のブロックのチャンネル数の乗数ですn_resnet_blocks
は各解像度での再ネット層の数ですin_channels
は画像内のチャンネル数z_channels
は埋め込みスペースのチャンネル数です79 def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
80 in_channels: int, z_channels: int):89 super().__init__()解像度の異なるブロック数。解像度は、各トップレベルブロックの最後で半分になります
。93 n_resolutions = len(channel_multipliers)画像をマップする最初の畳み込みレイヤー channels
96 self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)各トップレベルブロックのチャンネル数
99 channels_list = [m * channels for m in [1] + channel_multipliers]トップレベルブロックのリスト
102 self.down = nn.ModuleList()トップレベルブロックを作成
104 for i in range(n_resolutions):各トップレベルブロックは複数のResNetブロックとダウンサンプリングで構成されています
106 resnet_blocks = nn.ModuleList()ResNet ブロックを追加
108 for _ in range(n_resnet_blocks):
109 resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
110 channels = channels_list[i + 1]トップレベルブロック
112 down = nn.Module()
113 down.block = resnet_blocks最後を除く各トップレベルブロックの最後でのダウンサンプリング
115 if i != n_resolutions - 1:
116 down.downsample = DownSample(channels)
117 else:
118 down.downsample = nn.Identity()120 self.down.append(down)最後の ResNet ブロックには注意が必要です。
123 self.mid = nn.Module()
124 self.mid.block_1 = ResnetBlock(channels, channels)
125 self.mid.attn_1 = AttnBlock(channels)
126 self.mid.block_2 = ResnetBlock(channels, channels)畳み込みによる埋め込み空間にマッピング
129 self.norm_out = normalization(channels)
130 self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)img
形状のあるイメージテンソルです [batch_size, img_channels, img_height, img_width]
132 def forward(self, img: torch.Tensor):channels
最初の畳み込みでにマッピング
138 x = self.conv_in(img)トップレベルブロック
141 for down in self.down:ResNet ブロック
143 for block in down.block:
144 x = block(x)ダウンサンプリング
146 x = down.downsample(x)最後の ResNet ブロックには注意が必要です。
149 x = self.mid.block_1(x)
150 x = self.mid.attn_1(x)
151 x = self.mid.block_2(x)正規化して埋め込みスペースにマッピング
154 x = self.norm_out(x)
155 x = swish(x)
156 x = self.conv_out(x)159 return x162class Decoder(nn.Module):channels
は最後の畳み込み層のチャネル数ですchannel_multipliers
前のブロックのチャンネル数の乗算係数 (逆順)n_resnet_blocks
は各解像度での再ネット層の数ですout_channels
は画像内のチャンネル数z_channels
は埋め込みスペースのチャンネル数です167 def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
168 out_channels: int, z_channels: int):177 super().__init__()解像度の異なるブロック数。解像度は、各トップレベルブロックの最後で半分になります
。181 num_resolutions = len(channel_multipliers)各最上位ブロック内のチャンネル数 (逆順)
184 channels_list = [m * channels for m in channel_multipliers]最上位ブロック内のチャンネル数
187 channels = channels_list[-1]埋め込みスペースをマッピングする最初の畳み込みレイヤー channels
190 self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)ResNet ブロックには注意が必要
193 self.mid = nn.Module()
194 self.mid.block_1 = ResnetBlock(channels, channels)
195 self.mid.attn_1 = AttnBlock(channels)
196 self.mid.block_2 = ResnetBlock(channels, channels)トップレベルブロックのリスト
199 self.up = nn.ModuleList()トップレベルブロックを作成
201 for i in reversed(range(num_resolutions)):各トップレベルブロックは複数のResNetブロックとアップサンプリングで構成されています
203 resnet_blocks = nn.ModuleList()ResNet ブロックを追加
205 for _ in range(n_resnet_blocks + 1):
206 resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
207 channels = channels_list[i]トップレベルブロック
209 up = nn.Module()
210 up.block = resnet_blocks最初のブロックを除く各トップレベルブロックの最後でのアップサンプリング
212 if i != 0:
213 up.upsample = UpSample(channels)
214 else:
215 up.upsample = nn.Identity()チェックポイントと一致するようにプリペンドを付ける
217 self.up.insert(0, up)畳み込みによる画像空間にマッピング
220 self.norm_out = normalization(channels)
221 self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)z
形状付きの埋め込みテンソルです [batch_size, z_channels, z_height, z_height]
223 def forward(self, z: torch.Tensor):channels
最初の畳み込みでにマッピング
229 h = self.conv_in(z)ResNet ブロックには注意が必要
232 h = self.mid.block_1(h)
233 h = self.mid.attn_1(h)
234 h = self.mid.block_2(h)トップレベルブロック
237 for up in reversed(self.up):ResNet ブロック
239 for block in up.block:
240 h = block(h)アップサンプリング
242 h = up.upsample(h)正規化して画像空間にマッピング
245 h = self.norm_out(h)
246 h = swish(h)
247 img = self.conv_out(h)250 return img253class GaussianDistribution:parameters
は図形の埋め込みの平均と分散の対数です [batch_size, z_channels * 2, z_height, z_height]
258 def __init__(self, parameters: torch.Tensor):分割平均と分散対数
264 self.mean, log_var = torch.chunk(parameters, 2, dim=1)差異の対数をクランプ
266 self.log_var = torch.clamp(log_var, -30.0, 20.0)標準偏差の計算
268 self.std = torch.exp(0.5 * self.log_var)270 def sample(self):ディストリビューションからのサンプル
272 return self.mean + self.std * torch.randn_like(self.std)275class AttnBlock(nn.Module):channels
はチャネル数280 def __init__(self, channels: int):284 super().__init__()グループ正規化
286 self.norm = normalization(channels)クエリ、キー、値のマッピング
288 self.q = nn.Conv2d(channels, channels, 1)
289 self.k = nn.Conv2d(channels, channels, 1)
290 self.v = nn.Conv2d(channels, channels, 1)最終畳み込み層
292 self.proj_out = nn.Conv2d(channels, channels, 1)アテンションスケーリングファクター
294 self.scale = channels ** -0.5x
形状のテンソルです [batch_size, channels, height, width]
296 def forward(self, x: torch.Tensor):ノーマライズ x
301 x_norm = self.norm(x)クエリ、キー、ベクターの埋め込みを取得
303 q = self.q(x_norm)
304 k = self.k(x_norm)
305 v = self.v(x_norm)形状を変えてクエリ、キー、ベクターの埋め込みをからへ [batch_size, channels, height, width]
[batch_size, channels, height * width]
310 b, c, h, w = q.shape
311 q = q.view(b, c, h * w)
312 k = k.view(b, c, h * w)
313 v = v.view(b, c, h * w)コンピュート
316 attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
317 attn = F.softmax(attn, dim=2)コンピュート
320 out = torch.einsum('bij,bcj->bci', attn, v)形状を変えて元に戻す [batch_size, channels, height, width]
323 out = out.view(b, c, h, w)最終畳み込み層
325 out = self.proj_out(out)残余接続を追加
328 return x + out331class UpSample(nn.Module):channels
はチャネル数335 def __init__(self, channels: int):339 super().__init__()コンボリューションマッピング
341 self.conv = nn.Conv2d(channels, channels, 3, padding=1)x
形状付きの入力フィーチャマップです [batch_size, channels, height, width]
343 def forward(self, x: torch.Tensor):次の倍までのアップサンプリング
348 x = F.interpolate(x, scale_factor=2.0, mode="nearest")コンボリューションを適用
350 return self.conv(x)353class DownSample(nn.Module):channels
はチャネル数357 def __init__(self, channels: int):361 super().__init__()ストライドの長さがの畳み込みから、の係数でダウンサンプリングします
363 self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)x
形状付きの入力フィーチャマップです [batch_size, channels, height, width]
365 def forward(self, x: torch.Tensor):パディングを追加
370 x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)コンボリューションを適用
372 return self.conv(x)375class ResnetBlock(nn.Module):in_channels
は入力のチャンネル数out_channels
は出力のチャンネル数379 def __init__(self, in_channels: int, out_channels: int):384 super().__init__()最初の正規化と畳み込み層
386 self.norm1 = normalization(in_channels)
387 self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)2 番目の正規化と畳み込み層
389 self.norm2 = normalization(out_channels)
390 self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)in_channels
out_channels
残留接続用のマッピングレイヤへ
392 if in_channels != out_channels:
393 self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
394 else:
395 self.nin_shortcut = nn.Identity()x
形状付きの入力フィーチャマップです [batch_size, channels, height, width]
397 def forward(self, x: torch.Tensor):402 h = x最初の正規化と畳み込み層
405 h = self.norm1(h)
406 h = swish(h)
407 h = self.conv1(h)2 番目の正規化と畳み込み層
410 h = self.norm2(h)
411 h = swish(h)
412 h = self.conv2(h)残差をマッピングして追加
415 return self.nin_shortcut(x) + h418def swish(x: torch.Tensor):424 return x * torch.sigmoid(x)427def normalization(channels: int):433 return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)