これにより、以下の U-Net が実装されます
チェックポイントを直接読み込めるように、CompVis/Stable-Diffusionからモデル定義と命名を変更していません。
18import math
19from typing import List
20
21import numpy as np
22import torch
23import torch.nn as nn
24import torch.nn.functional as F
25
26from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer
29class UNetModel(nn.Module):
in_channels
は、入力フィーチャマップのチャネル数ですout_channels
は出力フィーチャマップのチャネル数です。channels
はモデルのベースチャンネル数n_res_blocks
各レベルの残差ブロック数attention_levels
注意すべきレベルはどれぐらいのレベルかchannel_multipliers
は各レベルのチャンネル数の乗法係数n_heads
は変圧器内のアテンションヘッドの数ですtf_layers
は変圧器内の変圧器層の数です。d_cond
はトランスフォーマー内の条件付き埋め込みのサイズです34 def __init__(
35 self, *,
36 in_channels: int,
37 out_channels: int,
38 channels: int,
39 n_res_blocks: int,
40 attention_levels: List[int],
41 channel_multipliers: List[int],
42 n_heads: int,
43 tf_layers: int = 1,
44 d_cond: int = 768):
56 super().__init__()
57 self.channels = channels
レベル数
60 levels = len(channel_multipliers)
サイズタイム埋め込み
62 d_time_emb = channels * 4
63 self.time_embed = nn.Sequential(
64 nn.Linear(channels, d_time_emb),
65 nn.SiLU(),
66 nn.Linear(d_time_emb, d_time_emb),
67 )
U ネットの半分を入力
70 self.input_blocks = nn.ModuleList()
入力をにマップする初期畳み込み。channels
TimestepEmbedSequential
モジュールが異なればフォワード関数のシグネチャも異なるため、ブロックはモジュールでラップされます。たとえば、畳み込みは特徴マップのみを受け入れ、残差ブロックは特徴マップと時間埋め込みを受け入れます。TimestepEmbedSequential
それに応じて呼び出します。
77 self.input_blocks.append(TimestepEmbedSequential(
78 nn.Conv2d(in_channels, channels, 3, padding=1)))
U-Netの入力半分の各ブロックのチャンネル数
80 input_block_channels = [channels]
各レベルのチャンネル数
82 channels_list = [channels * m for m in channel_multipliers]
レベルを準備
84 for i in range(levels):
残留ブロックとアテンションを追加
86 for _ in range(n_res_blocks):
残差ブロックは、前のチャンネル数から現在のレベルのチャンネル数にマッピングされます
89 layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
90 channels = channels_list[i]
変圧器を追加
92 if i in attention_levels:
93 layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
それらをU-Netの入力半分に追加して、その出力のチャンネル数を記録しておきます。
96 self.input_blocks.append(TimestepEmbedSequential(*layers))
97 input_block_channels.append(channels)
最後のレベルを除くすべてのレベルでダウンサンプル
99 if i != levels - 1:
100 self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
101 input_block_channels.append(channels)
Uネットの真ん中
104 self.middle_block = TimestepEmbedSequential(
105 ResBlock(channels, d_time_emb),
106 SpatialTransformer(channels, n_heads, tf_layers, d_cond),
107 ResBlock(channels, d_time_emb),
108 )
Uネット後半
111 self.output_blocks = nn.ModuleList([])
レベルを逆の順序で準備する
113 for i in reversed(range(levels)):
残留ブロックとアテンションを追加
115 for j in range(n_res_blocks + 1):
前のチャンネル数の残差ブロックマップに U-Net の入力半分からのスキップ接続を加えたものから現在のレベルのチャンネル数までマップされます。
119 layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
120 channels = channels_list[i]
変圧器を追加
122 if i in attention_levels:
123 layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
最後の残差ブロックを除く最後の残差ブロックの後のすべてのレベルでアップサンプリングします。逆に繰り返していることに注意してください。つまりi == 0
、最後です
127 if i != 0 and j == n_res_blocks:
128 layers.append(UpSample(channels))
U-Netの出力半分に追加
130 self.output_blocks.append(TimestepEmbedSequential(*layers))
最終的な正規化と畳み込み
133 self.out = nn.Sequential(
134 normalization(channels),
135 nn.SiLU(),
136 nn.Conv2d(channels, out_channels, 3, padding=1),
137 )
139 def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
; チャネルの半分は罪で、残りの半分はコス
147 half = self.channels // 2
149 frequencies = torch.exp(
150 -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
151 ).to(device=time_steps.device)
153 args = time_steps[:, None].float() * frequencies[None]
と
155 return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
x
はシェイプの入力フィーチャマップです [batch_size, channels, width, height]
time_steps
形状のタイムステップです [batch_size]
cond
形状のコンディショニング [batch_size, n_cond, d_cond]
157 def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):
スキップ接続の入力ハーフ出力を保存するには
164 x_input_block = []
タイムステップの埋め込みを取得
167 t_emb = self.time_step_embedding(time_steps)
168 t_emb = self.time_embed(t_emb)
U ネットの半分を入力
171 for module in self.input_blocks:
172 x = module(x, t_emb, cond)
173 x_input_block.append(x)
U-ネットの真ん中
175 x = self.middle_block(x, t_emb, cond)
U-ネットの出力半分
177 for module in self.output_blocks:
178 x = torch.cat([x, x_input_block.pop()], dim=1)
179 x = module(x, t_emb, cond)
最終的な正規化と畳み込み
182 return self.out(x)
このシーケンシャルモジュールは、、nn.Conv
SpatialTransformer
などのさまざまなモジュールで構成できResBlock
、それらを対応するシグネチャで呼び出すことができます。
185class TimestepEmbedSequential(nn.Sequential):
193 def forward(self, x, t_emb, cond=None):
194 for layer in self:
195 if isinstance(layer, ResBlock):
196 x = layer(x, t_emb)
197 elif isinstance(layer, SpatialTransformer):
198 x = layer(x, cond)
199 else:
200 x = layer(x)
201 return x
204class UpSample(nn.Module):
channels
はチャネル数209 def __init__(self, channels: int):
213 super().__init__()
コンボリューションマッピング
215 self.conv = nn.Conv2d(channels, channels, 3, padding=1)
x
形状付きの入力フィーチャマップです [batch_size, channels, height, width]
217 def forward(self, x: torch.Tensor):
次の倍までのアップサンプリング
222 x = F.interpolate(x, scale_factor=2, mode="nearest")
コンボリューションを適用
224 return self.conv(x)
227class DownSample(nn.Module):
channels
はチャネル数232 def __init__(self, channels: int):
236 super().__init__()
ストライドの長さがの畳み込みから、の係数でダウンサンプリングします
238 self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
x
形状付きの入力フィーチャマップです [batch_size, channels, height, width]
240 def forward(self, x: torch.Tensor):
コンボリューションを適用
245 return self.op(x)
248class ResBlock(nn.Module):
253 def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
259 super().__init__()
out_channels
指定なし
261 if out_channels is None:
262 out_channels = channels
最初の正規化と畳み込み
265 self.in_layers = nn.Sequential(
266 normalization(channels),
267 nn.SiLU(),
268 nn.Conv2d(channels, out_channels, 3, padding=1),
269 )
タイムステップ埋め込み
272 self.emb_layers = nn.Sequential(
273 nn.SiLU(),
274 nn.Linear(d_t_emb, out_channels),
275 )
最終畳み込み層
277 self.out_layers = nn.Sequential(
278 normalization(out_channels),
279 nn.SiLU(),
280 nn.Dropout(0.),
281 nn.Conv2d(out_channels, out_channels, 3, padding=1)
282 )
channels
out_channels
残留接続用のマッピングレイヤへ
285 if out_channels == channels:
286 self.skip_connection = nn.Identity()
287 else:
288 self.skip_connection = nn.Conv2d(channels, out_channels, 1)
x
形状付きの入力フィーチャマップです [batch_size, channels, height, width]
t_emb
形状のタイムステップ埋め込みです [batch_size, d_t_emb]
290 def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
初期コンボリューション
296 h = self.in_layers(x)
タイムステップ埋め込み
298 t_emb = self.emb_layers(t_emb).type(h.dtype)
タイムステップ埋め込みの追加
300 h = h + t_emb[:, :, None, None]
最終畳み込み
302 h = self.out_layers(h)
スキップ接続を追加
304 return self.skip_connection(x) + h
307class GroupNorm32(nn.GroupNorm):
312 def forward(self, x):
313 return super().forward(x.float()).type(x.dtype)
316def normalization(channels):
322 return GroupNorm32(32, channels)
正弦波タイムステップ埋め込みのテスト
325def _test_time_embeddings():
329 import matplotlib.pyplot as plt
330
331 plt.figure(figsize=(15, 5))
332 m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
333 channel_multipliers=[],
334 n_heads=1, tf_layers=1, d_cond=1)
335 te = m.time_step_embedding(torch.arange(0, 1000))
336 plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
337 plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
338 plt.title("Time embeddings")
339 plt.show()
343if __name__ == '__main__':
344 _test_time_embeddings()