U-Net for Stable Diffusion

This implements the U-Net that gives $\epsilon_\text{cond}(x_t, c)$, the conditioned noise prediction.

We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
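
For reference, here is a minimal sketch of instantiating the model. The hyper-parameter values are assumptions matching the commonly published Stable Diffusion v1 configuration; check the checkpoint's config file before relying on them.

# A sketch; these hyper-parameters are assumed to match the
# Stable Diffusion v1 checkpoints (verify against the actual config)
unet = UNetModel(
    in_channels=4,                 # latent channels from the autoencoder
    out_channels=4,                # predicted noise has the same shape
    channels=320,                  # base channel count
    n_res_blocks=2,                # residual blocks per level
    attention_levels=[0, 1, 2],    # no attention at the lowest resolution
    channel_multipliers=[1, 2, 4, 4],
    n_heads=8,                     # attention heads in the transformers
    tf_layers=1,                   # transformer depth
    d_cond=768,                    # CLIP text embedding size
)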

import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer

U-Net model

class UNetModel(nn.Module):
  • in_channels is the number of channels in the input feature map
  • out_channels is the number of channels in the output feature map
  • channels is the base channel count for the model
  • n_res_blocks is the number of residual blocks at each level
  • attention_levels are the levels at which attention should be performed
  • channel_multipliers are the multiplicative factors for the number of channels at each level
  • n_heads is the number of attention heads in the transformers
  • tf_layers is the number of transformer layers in the transformers
  • d_cond is the size of the conditional embedding in the transformers
    def __init__(
            self, *,
            in_channels: int,
            out_channels: int,
            channels: int,
            n_res_blocks: int,
            attention_levels: List[int],
            channel_multipliers: List[int],
            n_heads: int,
            tf_layers: int = 1,
            d_cond: int = 768):
        super().__init__()
        self.channels = channels

Number of levels

        levels = len(channel_multipliers)

Size of the time embeddings

        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )

Input half of the U-Net

        self.input_blocks = nn.ModuleList()

Initial $3 \times 3$ convolution that maps the input to channels. The blocks are wrapped in the TimestepEmbedSequential module because different modules have different forward function signatures; for example, a convolution only accepts the feature map while a residual block accepts the feature map and the time embedding. TimestepEmbedSequential calls each with its matching signature.

        self.input_blocks.append(TimestepEmbedSequential(
            nn.Conv2d(in_channels, channels, 3, padding=1)))

Number of channels at each block in the input half of U-Net

        input_block_channels = [channels]

Number of channels at each level

        channels_list = [channels * m for m in channel_multipliers]

Prepare levels

        for i in range(levels):

Add the residual blocks and attentions

            for _ in range(n_res_blocks):

Residual block maps from previous number of channels to the number of channels in the current level

                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]

Add transformer

                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

Add them to the input half of the U-Net and keep track of the number of channels of its output

                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_channels.append(channels)

Down-sample at all levels except the last

            if i != levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
                input_block_channels.append(channels)

The middle of the U-Net

        self.middle_block = TimestepEmbedSequential(
            ResBlock(channels, d_time_emb),
            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
            ResBlock(channels, d_time_emb),
        )

Second half of the U-Net

        self.output_blocks = nn.ModuleList([])

Prepare levels in reverse order

        for i in reversed(range(levels)):

Add the residual blocks and attentions

            for j in range(n_res_blocks + 1):

Residual block maps from previous number of channels plus the skip connections from the input half of U-Net to the number of channels in the current level.

                layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]

Add transformer

                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

Up-sample after the last residual block at every level except the last one. Note that we are iterating in reverse; i.e. i == 0 is the last level.

                if i != 0 and j == n_res_blocks:
                    layers.append(UpSample(channels))

Add to the output half of the U-Net

                self.output_blocks.append(TimestepEmbedSequential(*layers))

Final normalization and convolution

        self.out = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

Create sinusoidal time step embeddings

  • time_steps are the time steps of shape [batch_size]
  • max_period controls the minimum frequency of the embeddings.
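
Written out, the embedding computed below is

$$\omega_i = \frac{1}{10000^{\frac{2i}{c}}}, \qquad e(t) = \Big[\cos(t\,\omega_0), \dots, \cos\big(t\,\omega_{\frac{c}{2}-1}\big),\ \sin(t\,\omega_0), \dots, \sin\big(t\,\omega_{\frac{c}{2}-1}\big)\Big]$$

where $c$ is self.channels and $t$ is the time step.
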
    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):

$\frac{c}{2}$; half the channels are $\sin$ and the other half are $\cos$

        half = self.channels // 2

        frequencies = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=time_steps.device)

        args = time_steps[:, None].float() * frequencies[None]

$\cos\Big(\frac{t}{10000^{\frac{2i}{c}}}\Big)$ and $\sin\Big(\frac{t}{10000^{\frac{2i}{c}}}\Big)$

        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
  • x is the input feature map of shape [batch_size, channels, width, height]
  • time_steps are the time steps of shape [batch_size]
  • cond conditioning of shape [batch_size, n_cond, d_cond]
    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):

To store the input half outputs for skip connections

        x_input_block = []

Get time step embeddings

        t_emb = self.time_step_embedding(time_steps)
        t_emb = self.time_embed(t_emb)

Input half of the U-Net

        for module in self.input_blocks:
            x = module(x, t_emb, cond)
            x_input_block.append(x)

Middle of the U-Net

        x = self.middle_block(x, t_emb, cond)

Output half of the U-Net

        for module in self.output_blocks:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb, cond)

Final normalization and convolution

        return self.out(x)
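
A minimal smoke test of the forward pass. This is a sketch; the small hyper-parameters are arbitrary, and the 77-token, 768-dimensional conditioning shape is an assumption matching a typical CLIP text encoder output.

# A small model so the test runs quickly; sizes are arbitrary
unet = UNetModel(in_channels=4, out_channels=4, channels=64, n_res_blocks=1,
                 attention_levels=[0], channel_multipliers=[1, 2],
                 n_heads=4, tf_layers=1, d_cond=768)
x = torch.randn(2, 4, 64, 64)        # latent feature maps
t = torch.randint(0, 1000, (2,))     # diffusion time steps
cond = torch.randn(2, 77, 768)       # conditioning embeddings (assumed CLIP-like)
eps = unet(x, t, cond)
assert eps.shape == x.shape          # noise prediction matches the input shape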

Sequential block for modules with different inputs

This sequential module can compose different modules such as ResBlock, nn.Conv2d and SpatialTransformer, and calls each with its matching signature.

class TimestepEmbedSequential(nn.Sequential):
    def forward(self, x, t_emb, cond=None):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
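
For example, a block mixing a plain convolution and a residual block can be called through one signature; a sketch with arbitrary sizes:

block = TimestepEmbedSequential(
    nn.Conv2d(64, 64, 3, padding=1),  # called as layer(x)
    ResBlock(64, 256),                # called as layer(x, t_emb)
)
x = block(torch.randn(1, 64, 16, 16), torch.randn(1, 256))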

Up-sampling layer

class UpSample(nn.Module):
  • channels is the number of channels
    def __init__(self, channels: int):
        super().__init__()

$3 \times 3$ convolution mapping

        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
  • x is the input feature map with shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Up-sample by a factor of $2$

        x = F.interpolate(x, scale_factor=2, mode="nearest")

Apply convolution

        return self.conv(x)

Down-sampling layer

class DownSample(nn.Module):
  • channels is the number of channels
    def __init__(self, channels: int):
        super().__init__()

$3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$

        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
  • x is the input feature map with shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Apply convolution

        return self.op(x)
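
A quick shape check of the two layers (a sketch; sizes are arbitrary):

x = torch.randn(1, 64, 32, 32)
assert UpSample(64)(x).shape == (1, 64, 64, 64)    # spatial size doubled
assert DownSample(64)(x).shape == (1, 64, 16, 16)  # spatial size halved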

ResNet Block

class ResBlock(nn.Module):
  • channels the number of input channels
  • d_t_emb the size of timestep embeddings
  • out_channels is the number of output channels; defaults to channels
    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
        super().__init__()

out_channels not specified

        if out_channels is None:
            out_channels = channels

First normalization and convolution

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

Time step embeddings

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_t_emb, out_channels),
        )

Final convolution layer

        self.out_layers = nn.Sequential(
            normalization(out_channels),
            nn.SiLU(),
            nn.Dropout(0.),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

channels to out_channels mapping layer for residual connection

        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, 1)
  • x is the input feature map with shape [batch_size, channels, height, width]
  • t_emb is the time step embeddings of shape [batch_size, d_t_emb]
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):

Initial convolution

        h = self.in_layers(x)

Time step embeddings

        t_emb = self.emb_layers(t_emb).type(h.dtype)

Add time step embeddings

        h = h + t_emb[:, :, None, None]

Final convolution

        h = self.out_layers(h)

Add skip connection

        return self.skip_connection(x) + h
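
For instance, when the block changes the channel count, the $1 \times 1$ skip convolution maps the residual to match; a sketch with arbitrary sizes:

res = ResBlock(64, 128, out_channels=96)
h = res(torch.randn(1, 64, 16, 16), torch.randn(1, 128))
assert h.shape == (1, 96, 16, 16)  # channels mapped 64 -> 96, spatial size kept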

Group normalization with float32 casting

class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

Group normalization

This is a helper function, with a fixed number of groups.

def normalization(channels):
    return GroupNorm32(32, channels)
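
The float32 cast keeps the normalization statistics numerically stable when the model runs in half precision; a quick sketch:

norm = normalization(64)
x = torch.randn(1, 64, 8, 8, dtype=torch.float16)
y = norm(x)                      # statistics are computed in float32
assert y.dtype == torch.float16  # output is cast back to the input dtype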

Test sinusoidal time step embeddings

def _test_time_embeddings():
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
                  channel_multipliers=[],
                  n_heads=1, tf_layers=1, d_cond=1)
    te = m.time_step_embedding(torch.arange(0, 1000))
    plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
    plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
    plt.title("Time embeddings")
    plt.show()

if __name__ == '__main__':
    _test_time_embeddings()