安定した拡散を実現するUネット

これにより、以下の U-Net が実装されます

チェックポイントを直接読み込めるように、CompVis/Stable-Diffusionからモデル定義と命名を変更していません

18import math
19from typing import List
20
21import numpy as np
22import torch
23import torch.nn as nn
24import torch.nn.functional as F
25
26from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer

U-ネットモデル

29class UNetModel(nn.Module):
  • in_channels は、入力フィーチャマップのチャネル数です
  • out_channels は出力フィーチャマップのチャネル数です。
  • channels はモデルのベースチャンネル数
  • n_res_blocks 各レベルの残差ブロック数
  • attention_levels 注意すべきレベルはどれぐらいのレベルか
  • channel_multipliers は各レベルのチャンネル数の乗法係数
  • n_heads は変圧器内のアテンションヘッドの数です
  • tf_layers は変圧器内の変圧器層の数です。
  • d_cond はトランスフォーマー内の条件付き埋め込みのサイズです
34    def __init__(
35            self, *,
36            in_channels: int,
37            out_channels: int,
38            channels: int,
39            n_res_blocks: int,
40            attention_levels: List[int],
41            channel_multipliers: List[int],
42            n_heads: int,
43            tf_layers: int = 1,
44            d_cond: int = 768):
56        super().__init__()
57        self.channels = channels

レベル数

60        levels = len(channel_multipliers)

サイズタイム埋め込み

62        d_time_emb = channels * 4
63        self.time_embed = nn.Sequential(
64            nn.Linear(channels, d_time_emb),
65            nn.SiLU(),
66            nn.Linear(d_time_emb, d_time_emb),
67        )

U ネットの半分を入力

70        self.input_blocks = nn.ModuleList()

入力をにマップする初期畳み込み。channels TimestepEmbedSequential モジュールが異なればフォワード関数のシグネチャも異なるため、ブロックはモジュールでラップされます。たとえば、畳み込みは特徴マップのみを受け入れ、残差ブロックは特徴マップと時間埋め込みを受け入れます。TimestepEmbedSequential それに応じて呼び出します。

77        self.input_blocks.append(TimestepEmbedSequential(
78            nn.Conv2d(in_channels, channels, 3, padding=1)))

U-Netの入力半分の各ブロックのチャンネル数

80        input_block_channels = [channels]

各レベルのチャンネル数

82        channels_list = [channels * m for m in channel_multipliers]

レベルを準備

84        for i in range(levels):

残留ブロックとアテンションを追加

86            for _ in range(n_res_blocks):

残差ブロックは、前のチャンネル数から現在のレベルのチャンネル数にマッピングされます

89                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
90                channels = channels_list[i]

変圧器を追加

92                if i in attention_levels:
93                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

それらをU-Netの入力半分に追加して、その出力のチャンネル数を記録しておきます。

96                self.input_blocks.append(TimestepEmbedSequential(*layers))
97                input_block_channels.append(channels)

最後のレベルを除くすべてのレベルでダウンサンプル

99            if i != levels - 1:
100                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
101                input_block_channels.append(channels)

Uネットの真ん中

104        self.middle_block = TimestepEmbedSequential(
105            ResBlock(channels, d_time_emb),
106            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
107            ResBlock(channels, d_time_emb),
108        )

Uネット後半

111        self.output_blocks = nn.ModuleList([])

レベルを逆の順序で準備する

113        for i in reversed(range(levels)):

残留ブロックとアテンションを追加

115            for j in range(n_res_blocks + 1):

前のチャンネル数の残差ブロックマップに U-Net の入力半分からのスキップ接続を加えたものから現在のレベルのチャンネル数までマップされます。

119                layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
120                channels = channels_list[i]

変圧器を追加

122                if i in attention_levels:
123                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

最後の残差ブロックを除く最後の残差ブロックの後のすべてのレベルでアップサンプリングします。逆に繰り返していることに注意してください。つまりi == 0 、最後です

127                if i != 0 and j == n_res_blocks:
128                    layers.append(UpSample(channels))

U-Netの出力半分に追加

130                self.output_blocks.append(TimestepEmbedSequential(*layers))

最終的な正規化と畳み込み

133        self.out = nn.Sequential(
134            normalization(channels),
135            nn.SiLU(),
136            nn.Conv2d(channels, out_channels, 3, padding=1),
137        )

正弦波タイムステップ埋め込みの作成

  • time_steps 形状のタイムステップです [batch_size]
  • max_period 埋め込みの最小頻度を制御します。
  • 139    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):

    ; チャネルの半分は罪で、残りの半分はコス

    147        half = self.channels // 2

    149        frequencies = torch.exp(
    150            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    151        ).to(device=time_steps.device)

    153        args = time_steps[:, None].float() * frequencies[None]

    155        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    • x はシェイプの入力フィーチャマップです [batch_size, channels, width, height]
    • time_steps 形状のタイムステップです [batch_size]
    • cond 形状のコンディショニング [batch_size, n_cond, d_cond]
    157    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):

    スキップ接続の入力ハーフ出力を保存するには

    164        x_input_block = []

    タイムステップの埋め込みを取得

    167        t_emb = self.time_step_embedding(time_steps)
    168        t_emb = self.time_embed(t_emb)

    U ネットの半分を入力

    171        for module in self.input_blocks:
    172            x = module(x, t_emb, cond)
    173            x_input_block.append(x)

    U-ネットの真ん中

    175        x = self.middle_block(x, t_emb, cond)

    U-ネットの出力半分

    177        for module in self.output_blocks:
    178            x = torch.cat([x, x_input_block.pop()], dim=1)
    179            x = module(x, t_emb, cond)

    最終的な正規化と畳み込み

    182        return self.out(x)

    入力の異なるモジュール用のシーケンシャルブロック

    このシーケンシャルモジュールは、、nn.Conv SpatialTransformer などのさまざまなモジュールで構成できResBlock 、それらを対応するシグネチャで呼び出すことができます。

    185class TimestepEmbedSequential(nn.Sequential):
    193    def forward(self, x, t_emb, cond=None):
    194        for layer in self:
    195            if isinstance(layer, ResBlock):
    196                x = layer(x, t_emb)
    197            elif isinstance(layer, SpatialTransformer):
    198                x = layer(x, cond)
    199            else:
    200                x = layer(x)
    201        return x

    アップサンプリングレイヤー

    204class UpSample(nn.Module):
    • channels はチャネル数
    209    def __init__(self, channels: int):
    213        super().__init__()

    コンボリューションマッピング

    215        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
    • x 形状付きの入力フィーチャマップです [batch_size, channels, height, width]
    217    def forward(self, x: torch.Tensor):

    次の倍までのアップサンプリング

    222        x = F.interpolate(x, scale_factor=2, mode="nearest")

    コンボリューションを適用

    224        return self.conv(x)

    ダウンサンプリングレイヤー

    227class DownSample(nn.Module):
    • channels はチャネル数
    232    def __init__(self, channels: int):
    236        super().__init__()

    ストライドの長さがの畳み込みから、の係数でダウンサンプリングします

    238        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
    • x 形状付きの入力フィーチャマップです [batch_size, channels, height, width]
    240    def forward(self, x: torch.Tensor):

    コンボリューションを適用

    245        return self.op(x)

    リネットブロック

    248class ResBlock(nn.Module):
    • channels 入力チャンネル数
    • d_t_emb タイムステップ埋め込みのサイズ
  • out_channels は出力チャンネルの数です。デフォルトは `channelsです。
  • 253    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
    259        super().__init__()

    out_channels 指定なし

    261        if out_channels is None:
    262            out_channels = channels

    最初の正規化と畳み込み

    265        self.in_layers = nn.Sequential(
    266            normalization(channels),
    267            nn.SiLU(),
    268            nn.Conv2d(channels, out_channels, 3, padding=1),
    269        )

    タイムステップ埋め込み

    272        self.emb_layers = nn.Sequential(
    273            nn.SiLU(),
    274            nn.Linear(d_t_emb, out_channels),
    275        )

    最終畳み込み層

    277        self.out_layers = nn.Sequential(
    278            normalization(out_channels),
    279            nn.SiLU(),
    280            nn.Dropout(0.),
    281            nn.Conv2d(out_channels, out_channels, 3, padding=1)
    282        )

    channels out_channels 残留接続用のマッピングレイヤへ

    285        if out_channels == channels:
    286            self.skip_connection = nn.Identity()
    287        else:
    288            self.skip_connection = nn.Conv2d(channels, out_channels, 1)
    • x 形状付きの入力フィーチャマップです [batch_size, channels, height, width]
    • t_emb 形状のタイムステップ埋め込みです [batch_size, d_t_emb]
    290    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):

    初期コンボリューション

    296        h = self.in_layers(x)

    タイムステップ埋め込み

    298        t_emb = self.emb_layers(t_emb).type(h.dtype)

    タイムステップ埋め込みの追加

    300        h = h + t_emb[:, :, None, None]

    最終畳み込み

    302        h = self.out_layers(h)

    スキップ接続を追加

    304        return self.skip_connection(x) + h

    float32 キャスティングによるグループ正規化

    307class GroupNorm32(nn.GroupNorm):
    312    def forward(self, x):
    313        return super().forward(x.float()).type(x.dtype)

    グループ正規化

    これはグループ数が固定されたヘルパー関数です。

    316def normalization(channels):
    322    return GroupNorm32(32, channels)

    正弦波タイムステップ埋め込みのテスト

    325def _test_time_embeddings():
    329    import matplotlib.pyplot as plt
    330
    331    plt.figure(figsize=(15, 5))
    332    m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
    333                  channel_multipliers=[],
    334                  n_heads=1, tf_layers=1, d_cond=1)
    335    te = m.time_step_embedding(torch.arange(0, 1000))
    336    plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
    337    plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
    338    plt.title("Time embeddings")
    339    plt.show()

    343if __name__ == '__main__':
    344    _test_time_embeddings()