U-Net for Stable Diffusion

This implements the U-Net that gives $\epsilon_\text{cond}(x_t, c)$, the predicted noise conditioned on the time step and the conditioning context $c$.

We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
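As a reference for the constructor parameters below, here is a minimal sketch of how the UNetModel defined in this file might be instantiated. The hyperparameter values are assumptions matching the commonly used Stable Diffusion v1 configuration, not something defined in this file:

model = UNetModel(
    in_channels=4,                    # autoencoder latent space has 4 channels (assumed)
    out_channels=4,                   # noise is predicted in the same latent space
    channels=320,                     # base channel count
    attention_levels=[0, 1, 2],       # apply attention at the first three levels
    n_res_blocks=2,
    channel_multipliers=[1, 2, 4, 4],
    n_heads=8,
    tf_layers=1,
    d_cond=768,                       # CLIP text embedding size (assumed)
)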

import math
from typing import List

import numpy as np
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer

U-Net model

class UNetModel(nn.Module):
  • in_channels is the number of channels in the input feature map
  • out_channels is the number of channels in the output feature map
  • channels is the base channel count for the model
  • n_res_blocks is the number of residual blocks at each level
  • attention_levels are the levels at which attention should be performed
  • channel_multipliers are the multiplicative factors for the number of channels at each level
  • n_heads is the number of attention heads in the transformers
  • tf_layers is the number of transformer layers in the transformers
  • d_cond is the size of the conditional embedding in the transformers
    def __init__(
            self, *,
            in_channels: int,
            out_channels: int,
            channels: int,
            n_res_blocks: int,
            attention_levels: List[int],
            channel_multipliers: List[int],
            n_heads: int,
            tf_layers: int = 1,
            d_cond: int = 768):
        super().__init__()
        self.channels = channels

Number of levels

        levels = len(channel_multipliers)

Size of the time embeddings

        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )

Input half of the U-Net

        self.input_blocks = nn.ModuleList()

Initial $3 \times 3$ convolution that maps the input to channels. The blocks are wrapped in a TimestepEmbedSequential module because different modules have different forward function signatures; for example, a convolution only accepts the feature map while residual blocks accept the feature map and the time embedding. TimestepEmbedSequential calls them accordingly.

        self.input_blocks.append(TimestepEmbedSequential(
            nn.Conv2d(in_channels, channels, 3, padding=1)))

Number of channels at each block in the input half of U-Net

        input_block_channels = [channels]

Number of channels at each level

        channels_list = [channels * m for m in channel_multipliers]
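For example (illustrative values only), with channels = 320 and channel_multipliers = [1, 2, 4, 4] this gives:

channels_list = [320 * m for m in [1, 2, 4, 4]]
assert channels_list == [320, 640, 1280, 1280]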

Prepare levels

        for i in range(levels):

Add the residual blocks and attentions

            for _ in range(n_res_blocks):

Residual block maps from previous number of channels to the number of channels in the current level

                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]

Add transformer

                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

Add them to the input half of the U-Net and keep track of the number of channels of its output

                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_channels.append(channels)

Down-sample at all levels except the last

            if i != levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
                input_block_channels.append(channels)
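To make the bookkeeping concrete, here is a small sketch that traces input_block_channels for a hypothetical configuration (channels = 320, channel_multipliers = [1, 2, 4, 4], n_res_blocks = 2); the output half later pops these values in reverse order for the skip connections:

channels, multipliers, n_res_blocks = 320, [1, 2, 4, 4], 2
trace = [channels]                            # initial convolution
for i, m in enumerate(multipliers):
    trace += [channels * m] * n_res_blocks    # one entry per residual block
    if i != len(multipliers) - 1:
        trace += [channels * m]               # one entry per down-sampling block
assert trace == [320, 320, 320, 320, 640, 640, 640,
                 1280, 1280, 1280, 1280, 1280]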

The middle of the U-Net

        self.middle_block = TimestepEmbedSequential(
            ResBlock(channels, d_time_emb),
            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
            ResBlock(channels, d_time_emb),
        )

Second half of the U-Net

        self.output_blocks = nn.ModuleList([])

Prepare levels in reverse order

        for i in reversed(range(levels)):

Add the residual blocks and attentions

            for j in range(n_res_blocks + 1):

The residual block maps from the previous number of channels plus the skip connection from the input half of the U-Net to the number of channels in the current level.

                layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]

Add transformer

                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

Up-sample after the last residual block at every level except the last one. Note that we are iterating in reverse; i.e. i == 0 is the last level.

                if i != 0 and j == n_res_blocks:
                    layers.append(UpSample(channels))

Add to the output half of the U-Net

                self.output_blocks.append(TimestepEmbedSequential(*layers))

Final normalization and convolution

        self.out = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

Create sinusoidal time step embeddings

  • time_steps are the time steps of shape [batch_size]
  • max_period controls the minimum frequency of the embeddings.
    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):

$\frac{c}{2}$; half the channels are $\sin$ and the other half are $\cos$

        half = self.channels // 2

Frequencies $\frac{1}{10000^{\frac{2i}{c}}}$

        frequencies = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=time_steps.device)

Create the arguments $\frac{t}{10000^{\frac{2i}{c}}}$

        args = time_steps[:, None].float() * frequencies[None]

Return $\cos\Big(\frac{t}{10000^{\frac{2i}{c}}}\Big)$ and $\sin\Big(\frac{t}{10000^{\frac{2i}{c}}}\Big)$

        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
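As a quick sanity check, a sketch of calling this method; the small configuration is hypothetical and only serves to construct an instance:

m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1,
              attention_levels=[], channel_multipliers=[1], n_heads=1)
emb = m.time_step_embedding(torch.tensor([0, 10, 100]))
assert emb.shape == (3, 320)
# At t = 0 the cos half is all ones and the sin half is all zeros
assert torch.allclose(emb[0, :160], torch.ones(160))
assert torch.allclose(emb[0, 160:], torch.zeros(160))
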
  • x is the input feature map of shape [batch_size, channels, width, height]
  • time_steps are the time steps of shape [batch_size]
  • cond conditioning of shape [batch_size, n_cond, d_cond]
    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):

To store the input half outputs for skip connections

        x_input_block = []

Get time step embeddings

        t_emb = self.time_step_embedding(time_steps)
        t_emb = self.time_embed(t_emb)

Input half of the U-Net

        for module in self.input_blocks:
            x = module(x, t_emb, cond)
            x_input_block.append(x)

Middle of the U-Net

        x = self.middle_block(x, t_emb, cond)

Output half of the U-Net

        for module in self.output_blocks:
            x = th.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb, cond)

Final normalization and convolution

        return self.out(x)
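A sketch of a full denoising forward pass with dummy tensors; the sizes (4 latent channels, a 64×64 latent resolution, 77 conditioning tokens with d_cond = 768) are assumptions matching typical Stable Diffusion usage:

model = UNetModel(in_channels=4, out_channels=4, channels=320, n_res_blocks=2,
                  attention_levels=[0, 1, 2], channel_multipliers=[1, 2, 4, 4],
                  n_heads=8, tf_layers=1, d_cond=768)
x = torch.randn(2, 4, 64, 64)       # latent feature map
t = torch.randint(0, 1000, (2,))    # diffusion time steps
cond = torch.randn(2, 77, 768)      # conditioning embeddings
eps = model(x, t, cond)
assert eps.shape == x.shape         # noise prediction in the same latent space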

Sequential block for modules with different inputs

This sequential module can compose different modules such as ResBlock, nn.Conv and SpatialTransformer and calls them with the matching signatures.

class TimestepEmbedSequential(nn.Sequential):
    def forward(self, x, t_emb, cond=None):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
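For example, a block mixing a plain convolution and a residual block can still be called with a single signature (a sketch with illustrative sizes):

block = TimestepEmbedSequential(
    nn.Conv2d(320, 320, 3, padding=1),   # called as layer(x)
    ResBlock(320, 1280),                 # called as layer(x, t_emb)
)
x = block(torch.randn(1, 320, 16, 16), torch.randn(1, 1280))
assert x.shape == (1, 320, 16, 16)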

Up-sampling layer

class UpSample(nn.Module):
  • channels is the number of channels
    def __init__(self, channels: int):
        super().__init__()

$3 \times 3$ convolution mapping

        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
  • x is the input feature map with shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Up-sample by a factor of $2$

        x = F.interpolate(x, scale_factor=2, mode="nearest")

Apply convolution

        return self.conv(x)

Down-sampling layer

class DownSample(nn.Module):
  • channels is the number of channels
    def __init__(self, channels: int):
        super().__init__()

$3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$

        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
  • x is the input feature map with shape [batch_size, channels, height, width]
    def forward(self, x: torch.Tensor):

Apply convolution

        return self.op(x)
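A small sketch showing how the two layers change the spatial dimensions (sizes are illustrative):

x = torch.randn(1, 64, 32, 32)
assert DownSample(64)(x).shape == (1, 64, 16, 16)  # stride-2 convolution halves H and W
assert UpSample(64)(x).shape == (1, 64, 64, 64)    # nearest-neighbor up-sampling doubles H and W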

ResNet Block

class ResBlock(nn.Module):
  • channels is the number of input channels
  • d_t_emb is the size of the timestep embeddings
  • out_channels is the number of output channels; defaults to channels
    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
        super().__init__()

out_channels not specified

        if out_channels is None:
            out_channels = channels

First normalization and convolution

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )

Time step embeddings

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_t_emb, out_channels),
        )

Final convolution layer

        self.out_layers = nn.Sequential(
            normalization(out_channels),
            nn.SiLU(),
            nn.Dropout(0.),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

channels to out_channels mapping layer for residual connection

        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, 1)
  • x is the input feature map with shape [batch_size, channels, height, width]
  • t_emb is the time step embeddings of shape [batch_size, d_t_emb]
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):

Initial convolution

        h = self.in_layers(x)

Time step embeddings

        t_emb = self.emb_layers(t_emb).type(h.dtype)

Add time step embeddings

        h = h + t_emb[:, :, None, None]

Final convolution

        h = self.out_layers(h)

Add skip connection

        return self.skip_connection(x) + h
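A sketch of using the block to change the channel count (values are illustrative):

res = ResBlock(320, 1280, out_channels=640)
x = torch.randn(2, 320, 16, 16)
t_emb = torch.randn(2, 1280)
# The 1x1 skip_connection convolution matches the residual to the new channel count
assert res(x, t_emb).shape == (2, 640, 16, 16)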

Group normalization with float32 casting

class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

Group normalization

This is a helper function with a fixed number of groups.

def normalization(channels):
    return GroupNorm32(32, channels)
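A sketch of the float32 casting behaviour with half-precision inputs; the normalization statistics are computed in float32 and the result is cast back to the input dtype:

norm = normalization(64)
x = torch.randn(1, 64, 8, 8).half()
y = norm(x)
assert y.dtype == torch.float16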

Test sinusoidal time step embeddings

def _test_time_embeddings():
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
                  channel_multipliers=[],
                  n_heads=1, tf_layers=1, d_cond=1)
    te = m.time_step_embedding(torch.arange(0, 1000))
    plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
    plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
    plt.title("Time embeddings")
    plt.show()

if __name__ == '__main__':
    _test_time_embeddings()