This implements the U-Net that gives the noise estimate epsilon(x_t, c) for a noisy image x_t under conditioning c.
We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer
class UNetModel(nn.Module):
- in_channels is the number of channels in the input feature map
- out_channels is the number of channels in the output feature map
- channels is the base channel count for the model
- n_res_blocks is the number of residual blocks at each level
- attention_levels are the levels at which attention should be performed
- channel_multipliers are the multiplicative factors for the number of channels at each level
- n_heads is the number of attention heads in the transformers
- tf_layers is the number of transformer layers in the transformers
- d_cond is the size of the conditional embedding in the transformers

    def __init__(
            self, *,
            in_channels: int,
            out_channels: int,
            channels: int,
            n_res_blocks: int,
            attention_levels: List[int],
            channel_multipliers: List[int],
            n_heads: int,
            tf_layers: int = 1,
            d_cond: int = 768):
        super().__init__()
        self.channels = channels
Number of levels
        levels = len(channel_multipliers)
Size of the time embeddings
        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )
Input half of the U-Net
        self.input_blocks = nn.ModuleList()
Initial convolution that maps the input to channels. The blocks are wrapped in TimestepEmbedSequential because different modules have different forward signatures; for example, a convolution accepts only the feature map, while a residual block accepts both the feature map and the time embedding. TimestepEmbedSequential calls each module with the matching signature.
        self.input_blocks.append(TimestepEmbedSequential(
            nn.Conv2d(in_channels, channels, 3, padding=1)))
Number of channels at each block in the input half of U-Net
        input_block_channels = [channels]
Number of channels at each level
        channels_list = [channels * m for m in channel_multipliers]
Prepare levels
        for i in range(levels):
Add the residual blocks and attentions
            for _ in range(n_res_blocks):
Residual block maps from previous number of channels to the number of channels in the current level
                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
Add transformer
                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
Add them to the input half of the U-Net and keep track of the number of channels of its output
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_channels.append(channels)
Down-sample at all levels except the last
            if i != levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
                input_block_channels.append(channels)
The middle of the U-Net
        self.middle_block = TimestepEmbedSequential(
            ResBlock(channels, d_time_emb),
            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
            ResBlock(channels, d_time_emb),
        )
Second half of the U-Net
        self.output_blocks = nn.ModuleList([])
Prepare levels in reverse order
        for i in reversed(range(levels)):
Add the residual blocks and attentions
            for j in range(n_res_blocks + 1):
Residual block maps from previous number of channels plus the skip connections from the input half of U-Net to the number of channels in the current level.
                layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
Add transformer
                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
Up-sample after the last residual block at every level except the last. Note that we are iterating in reverse; i.e. i == 0 is the last level.
                if i != 0 and j == n_res_blocks:
                    layers.append(UpSample(channels))
Add to the output half of the U-Net
                self.output_blocks.append(TimestepEmbedSequential(*layers))
Final normalization and convolution
        self.out = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )
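As a rough usage sketch, the constructor can be called with hyperparameters along the lines of Stable Diffusion v1; the exact values below are an assumption for illustration, not taken from this file:

unet = UNetModel(
    in_channels=4,                   # latent channels from the autoencoder (assumed)
    out_channels=4,                  # the predicted noise has the same shape
    channels=320,                    # base channel count
    n_res_blocks=2,
    attention_levels=[0, 1, 2],      # attention at the first three levels (assumed)
    channel_multipliers=[1, 2, 4, 4],
    n_heads=8,
    tf_layers=1,
    d_cond=768,                      # size of the text conditioning embeddings (assumed)
)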
Create sinusoidal time step embeddings

- time_steps are the time steps of shape [batch_size]
- max_period controls the minimum frequency of the embeddings

    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
Half of the channels are sin and the other half are cos

        half = self.channels // 2
Frequencies fall geometrically from 1 down to roughly 1/max_period

        frequencies = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=time_steps.device)
Arguments of the sinusoids: each time step multiplied by each frequency

        args = time_steps[:, None].float() * frequencies[None]
Concatenate the cos and sin embeddings

        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
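The same computation as a self-contained sketch that can be run without constructing the full model (the function name here is ours):

import math
import torch

def sinusoidal_embedding(time_steps: torch.Tensor, channels: int, max_period: int = 10000):
    # Half of the channels carry cos components, the other half sin
    half = channels // 2
    # Frequencies fall geometrically from 1 down to roughly 1/max_period
    frequencies = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    # Each time step multiplied by each frequency
    args = time_steps[:, None].float() * frequencies[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.arange(4), channels=320)
assert emb.shape == (4, 320)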
- x is the input feature map of shape [batch_size, channels, width, height]
- time_steps are the time steps of shape [batch_size]
- cond is the conditioning of shape [batch_size, n_cond, d_cond]

    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):
To store the input half outputs for skip connections
        x_input_block = []
Get time step embeddings
        t_emb = self.time_step_embedding(time_steps)
        t_emb = self.time_embed(t_emb)
Input half of the U-Net
        for module in self.input_blocks:
            x = module(x, t_emb, cond)
            x_input_block.append(x)
Middle of the U-Net
        x = self.middle_block(x, t_emb, cond)
Output half of the U-Net
        for module in self.output_blocks:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb, cond)
Final normalization and convolution
        return self.out(x)
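A minimal shape check of the forward pass, with a deliberately tiny configuration of our choosing (channel counts must be multiples of 32 so the group normalization works):

tiny = UNetModel(in_channels=4, out_channels=4, channels=32,
                 n_res_blocks=1, attention_levels=[0],
                 channel_multipliers=[1, 2], n_heads=2,
                 tf_layers=1, d_cond=64)
x = torch.randn(1, 4, 16, 16)        # input feature map
t = torch.tensor([10])               # one time step per batch element
cond = torch.randn(1, 77, 64)        # conditioning, e.g. text embeddings
with torch.no_grad():
    out = tiny(x, t, cond)
assert out.shape == x.shape          # the U-Net preserves the input shape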
Sequential block for modules with different inputs

This sequential module can be composed of different modules such as ResBlock, nn.Conv2d and SpatialTransformer, and calls each of them with the matching signature.

class TimestepEmbedSequential(nn.Sequential):
    def forward(self, x, t_emb, cond=None):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
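For example (a quick sketch), a block mixing all three kinds of layers routes each argument to the right place:

block = TimestepEmbedSequential(
    ResBlock(64, 128),                   # called as layer(x, t_emb)
    SpatialTransformer(64, 2, 1, 32),    # called as layer(x, cond)
    nn.Conv2d(64, 64, 3, padding=1),     # called as layer(x)
)
y = block(torch.randn(1, 64, 8, 8),      # feature map x
          torch.randn(1, 128),           # time embedding t_emb
          torch.randn(1, 4, 32))         # conditioning cond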
Up-sampling layer

class UpSample(nn.Module):

- channels is the number of channels

    def __init__(self, channels: int):
        super().__init__()
3x3 convolution mapping
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
- x is the input feature map with shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):
Up-sample by a factor of 2
        x = F.interpolate(x, scale_factor=2, mode="nearest")
Apply convolution
        return self.conv(x)
Down-sampling layer

class DownSample(nn.Module):

- channels is the number of channels

    def __init__(self, channels: int):
        super().__init__()
3x3 convolution with a stride of 2 to down-sample by a factor of 2
        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
- x is the input feature map with shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):
Apply convolution
        return self.op(x)
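A quick shape check (sketch): DownSample halves the spatial resolution while UpSample doubles it, and neither changes the channel count:

x = torch.randn(1, 64, 32, 32)
assert DownSample(64)(x).shape == (1, 64, 16, 16)
assert UpSample(64)(x).shape == (1, 64, 64, 64)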
ResNet block

class ResBlock(nn.Module):

- channels is the number of input channels
- d_t_emb is the size of the timestep embeddings
- out_channels is the number of output channels; defaults to channels

    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
        super().__init__()
out_channels defaults to channels if not specified

        if out_channels is None:
            out_channels = channels
First normalization and convolution
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )
Time step embeddings
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_t_emb, out_channels),
        )
Final convolution layer
        self.out_layers = nn.Sequential(
            normalization(out_channels),
            nn.SiLU(),
            nn.Dropout(0.),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )
channels to out_channels mapping layer for the residual connection
        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, 1)
- x is the input feature map with shape [batch_size, channels, height, width]
- t_emb is the time step embeddings of shape [batch_size, d_t_emb]

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
Initial convolution
        h = self.in_layers(x)
Time step embeddings
        t_emb = self.emb_layers(t_emb).type(h.dtype)
Add time step embeddings
        h = h + t_emb[:, :, None, None]
Final convolution
        h = self.out_layers(h)
Add skip connection
        return self.skip_connection(x) + h
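A sketch of the block in isolation: the channel count can change, and the time embedding is broadcast over the spatial dimensions after the first convolution:

res = ResBlock(64, 128, out_channels=96)
x = torch.randn(2, 64, 16, 16)
t_emb = torch.randn(2, 128)
assert res(x, t_emb).shape == (2, 96, 16, 16)   # skip path uses a 1x1 convolution to map 64 -> 96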
Group normalization with float32 casting

class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
Group normalization helper with a fixed number of 32 groups

def normalization(channels):
    return GroupNorm32(32, channels)
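The float32 cast matters under mixed precision: the statistics are computed in float32 for numerical stability, while the output keeps the input dtype. A small sketch:

norm = normalization(64)
h = torch.randn(1, 64, 8, 8, dtype=torch.float16)
out = norm(h)                        # normalization runs in float32 internally
assert out.dtype == torch.float16    # but the result is cast back to the input dtype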
Test sinusoidal time step embeddings
def _test_time_embeddings():
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
                  channel_multipliers=[],
                  n_heads=1, tf_layers=1, d_cond=1)
    te = m.time_step_embedding(torch.arange(0, 1000))
    plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
    plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
    plt.title("Time embeddings")
    plt.show()
if __name__ == '__main__':
    _test_time_embeddings()