Autoencoder for Stable Diffusion

This implements the auto-encoder model used to map between image space and latent space.

We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.

18from typing import List
19
20import torch
21import torch.nn.functional as F
22from torch import nn

Autoencoder

This consists of the encoder and decoder modules.

25class Autoencoder(nn.Module):
  • encoder is the encoder
  • decoder is the decoder
  • emb_channels is the number of dimensions in the quantized embedding space
  • z_channels is the number of channels in the embedding space
32    def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
39        super().__init__()
40        self.encoder = encoder
41        self.decoder = decoder

Convolution to map from embedding space to quantized embedding space moments (mean and log variance)

44        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)

Convolution to map from quantized embedding space back to embedding space

47        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

Encode images to latent representation

  • img is the image tensor with shape [batch_size, img_channels, img_height, img_width]
49    def encode(self, img: torch.Tensor) -> 'GaussianDistribution':

Get embeddings with shape [batch_size, z_channels * 2, z_height, z_width]

56        z = self.encoder(img)

Get the moments in the quantized embedding space

58        moments = self.quant_conv(z)

Return the distribution

60        return GaussianDistribution(moments)

Decode images from latent representation

  • z is the latent representation with shape [batch_size, emb_channels, z_height, z_width]
62    def decode(self, z: torch.Tensor):

Map to embedding space from the quantized representation

69        z = self.post_quant_conv(z)

Decode the image of shape [batch_size, channels, height, width]

71        return self.decoder(z)
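As a usage sketch (not part of the original file), the autoencoder can be assembled and run end to end using the classes defined in this file. The hyper-parameters below are, to the best of my knowledge, the ones used by the Stable Diffusion v1 checkpoints; adjust them if your configuration differs.

encoder = Encoder(channels=128, channel_multipliers=[1, 2, 4, 4], n_resnet_blocks=2,
                  in_channels=3, z_channels=4)
decoder = Decoder(channels=128, channel_multipliers=[1, 2, 4, 4], n_resnet_blocks=2,
                  out_channels=3, z_channels=4)
autoencoder = Autoencoder(encoder=encoder, decoder=decoder, emb_channels=4, z_channels=4)

img = torch.randn(1, 3, 256, 256)        # dummy image batch
dist = autoencoder.encode(img)           # GaussianDistribution over the latents
z = dist.sample()                        # [1, 4, 32, 32]
reconstruction = autoencoder.decode(z)   # [1, 3, 256, 256]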

Encoder module

74class Encoder(nn.Module):
  • channels is the number of channels in the first convolution layer
  • channel_multipliers are the multiplicative factors for the number of channels in the subsequent blocks
  • n_resnet_blocks is the number of resnet layers at each resolution
  • in_channels is the number of channels in the image
  • z_channels is the number of channels in the embedding space
79    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
80                 in_channels: int, z_channels: int):
89        super().__init__()

Number of blocks of different resolutions. The resolution is halved at the end of each top-level block

93        n_resolutions = len(channel_multipliers)

Initial convolution layer that maps the image to channels

96        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

Number of channels in each top level block

99        channels_list = [m * channels for m in [1] + channel_multipliers]
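For illustration (my example, not in the original code), with channels = 128 and channel_multipliers = [1, 2, 4, 4] this gives:

channels_list = [m * 128 for m in [1, 1, 2, 4, 4]]   # [128, 128, 256, 512, 512]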

List of top-level blocks

102        self.down = nn.ModuleList()

Create top-level blocks

104        for i in range(n_resolutions):

Each top level block consists of multiple ResNet Blocks and down-sampling

106            resnet_blocks = nn.ModuleList()

Add ResNet Blocks

108            for _ in range(n_resnet_blocks):
109                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
110                channels = channels_list[i + 1]

Top-level block

112            down = nn.Module()
113            down.block = resnet_blocks

Down-sampling at the end of each top level block except the last

115            if i != n_resolutions - 1:
116                down.downsample = DownSample(channels)
117            else:
118                down.downsample = nn.Identity()

120            self.down.append(down)

Final ResNet blocks with attention

123        self.mid = nn.Module()
124        self.mid.block_1 = ResnetBlock(channels, channels)
125        self.mid.attn_1 = AttnBlock(channels)
126        self.mid.block_2 = ResnetBlock(channels, channels)

Map to embedding space with a convolution

129        self.norm_out = normalization(channels)
130        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
  • img is the image tensor with shape [batch_size, img_channels, img_height, img_width]
132    def forward(self, img: torch.Tensor):

Map to channels with the initial convolution

138        x = self.conv_in(img)

Top-level blocks

141        for down in self.down:

ResNet Blocks

143            for block in down.block:
144                x = block(x)

Down-sampling

146            x = down.downsample(x)

Final ResNet blocks with attention

149        x = self.mid.block_1(x)
150        x = self.mid.attn_1(x)
151        x = self.mid.block_2(x)

Normalize and map to embedding space

154        x = self.norm_out(x)
155        x = swish(x)
156        x = self.conv_out(x)

159        return x
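A quick shape check (my sketch, reusing the encoder from the usage example above): with four resolutions the spatial size is halved three times and the output has 2 * z_channels channels.

x = torch.randn(1, 3, 256, 256)
moments = encoder(x)
print(moments.shape)   # torch.Size([1, 8, 32, 32])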

Decoder module

162class Decoder(nn.Module):
  • channels is the number of channels in the final convolution layer
  • channel_multipliers are the multiplicative factors for the number of channels in the previous blocks, in reverse order
  • n_resnet_blocks is the number of resnet layers at each resolution
  • out_channels is the number of channels in the image
  • z_channels is the number of channels in the embedding space
167    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
168                 out_channels: int, z_channels: int):
177        super().__init__()

Number of blocks of different resolutions. The resolution is doubled at the end of each top-level block

181        num_resolutions = len(channel_multipliers)

Number of channels in each top level block, in the reverse order

184        channels_list = [m * channels for m in channel_multipliers]

Number of channels in the first top-level block of the decoder, at the lowest resolution

187        channels = channels_list[-1]

Initial convolution layer that maps the embedding space to channels

190        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

ResNet blocks with attention

193        self.mid = nn.Module()
194        self.mid.block_1 = ResnetBlock(channels, channels)
195        self.mid.attn_1 = AttnBlock(channels)
196        self.mid.block_2 = ResnetBlock(channels, channels)

List of top-level blocks

199        self.up = nn.ModuleList()

Create top-level blocks

201        for i in reversed(range(num_resolutions)):

Each top level block consists of multiple ResNet Blocks and up-sampling

203            resnet_blocks = nn.ModuleList()

Add ResNet Blocks

205            for _ in range(n_resnet_blocks + 1):
206                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
207                channels = channels_list[i]

Top-level block

209            up = nn.Module()
210            up.block = resnet_blocks

Up-sampling at the end of each top level block except the first

212            if i != 0:
213                up.upsample = UpSample(channels)
214            else:
215                up.upsample = nn.Identity()

Prepend to be consistent with the checkpoint

217            self.up.insert(0, up)

Map to image space with a convolution

220        self.norm_out = normalization(channels)
221        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
  • z is the embedding tensor with shape [batch_size, z_channels, z_height, z_width]
223    def forward(self, z: torch.Tensor):

Map to channels with the initial convolution

229        h = self.conv_in(z)

ResNet blocks with attention

232        h = self.mid.block_1(h)
233        h = self.mid.attn_1(h)
234        h = self.mid.block_2(h)

Top-level blocks

237        for up in reversed(self.up):

ResNet Blocks

239            for block in up.block:
240                h = block(h)

Up-sampling

242            h = up.upsample(h)

Normalize and map to image space

245        h = self.norm_out(h)
246        h = swish(h)
247        img = self.conv_out(h)

250        return img
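And the other way round (same sketch, reusing the decoder constructed above): a 32×32 latent with z_channels = 4 maps back to a 256×256 image.

z = torch.randn(1, 4, 32, 32)
img = decoder(z)
print(img.shape)       # torch.Size([1, 3, 256, 256])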

Gaussian Distribution

253class GaussianDistribution:
  • parameters are the means and log of variances of the embedding of shape [batch_size, z_channels * 2, z_height, z_width]
258    def __init__(self, parameters: torch.Tensor):

Split mean and log of variance

264        self.mean, log_var = torch.chunk(parameters, 2, dim=1)

Clamp the log of variances

266        self.log_var = torch.clamp(log_var, -30.0, 20.0)

Calculate standard deviation

268        self.std = torch.exp(0.5 * self.log_var)
270    def sample(self):

Sample from the distribution

272        return self.mean + self.std * torch.randn_like(self.std)
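sample() is the standard reparameterization trick: z = mean + std * eps with eps drawn from a unit Gaussian, so the sample stays differentiable with respect to mean and std. A minimal sketch (the moments tensor here is made up):

moments = torch.randn(1, 8, 32, 32)   # concatenated mean and log-variance, 2 * z_channels
dist = GaussianDistribution(moments)
z = dist.sample()                     # [1, 4, 32, 32]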

Attention block

275class AttnBlock(nn.Module):
  • channels is the number of channels
280    def __init__(self, channels: int):
284        super().__init__()

Group normalization

286        self.norm = normalization(channels)

Query, key and value mappings

288        self.q = nn.Conv2d(channels, channels, 1)
289        self.k = nn.Conv2d(channels, channels, 1)
290        self.v = nn.Conv2d(channels, channels, 1)

Final convolution layer

292        self.proj_out = nn.Conv2d(channels, channels, 1)

Attention scaling factor

294        self.scale = channels ** -0.5
  • x is the tensor of shape [batch_size, channels, height, width]
296    def forward(self, x: torch.Tensor):

Normalize x

301        x_norm = self.norm(x)

Get query, key and value embeddings

303        q = self.q(x_norm)
304        k = self.k(x_norm)
305        v = self.v(x_norm)

Reshape the query, key and value embeddings from [batch_size, channels, height, width] to [batch_size, channels, height * width]

310        b, c, h, w = q.shape
311        q = q.view(b, c, h * w)
312        k = k.view(b, c, h * w)
313        v = v.view(b, c, h * w)

Compute the attention weights softmax(Q^T K / sqrt(channels)), with the softmax taken over the key positions

316        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
317        attn = F.softmax(attn, dim=2)

Compute the output as the attention-weighted sum of the value vectors

320        out = torch.einsum('bij,bcj->bci', attn, v)

Reshape back to [batch_size, channels, height, width]

323        out = out.view(b, c, h, w)

Final convolution layer

325        out = self.proj_out(out)

Add residual connection

328        return x + out
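The two einsum calls above implement standard scaled dot-product attention over the height * width positions. The following sketch (my own equivalence check, not part of the model) spells the same computation out with torch.bmm:

b, c, hw = 2, 64, 16 * 16
q = torch.randn(b, c, hw)
k = torch.randn(b, c, hw)
v = torch.randn(b, c, hw)
scale = c ** -0.5
attn = torch.bmm(q.transpose(1, 2), k) * scale   # attn[b, i, j] = q_i . k_j / sqrt(c)
attn = F.softmax(attn, dim=2)                    # softmax over key positions j
out = torch.bmm(v, attn.transpose(1, 2))         # weighted sum of values, [b, c, hw]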

Up-sampling layer

331class UpSample(nn.Module):
  • channels is the number of channels
335    def __init__(self, channels: int):
339        super().__init__()

3×3 convolution mapping

341        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
  • x is the input feature map with shape [batch_size, channels, height, width]
343    def forward(self, x: torch.Tensor):

Up-sample by a factor of 2

348        x = F.interpolate(x, scale_factor=2.0, mode="nearest")

Apply convolution

350        return self.conv(x)
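A small shape check (my sketch): the nearest-neighbour interpolation doubles the spatial size and the convolution keeps it unchanged.

up = UpSample(64)
x = torch.randn(1, 64, 16, 16)
print(up(x).shape)     # torch.Size([1, 64, 32, 32])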

Down-sampling layer

353class DownSample(nn.Module):
  • channels is the number of channels
357    def __init__(self, channels: int):
361        super().__init__()

3×3 convolution with a stride length of 2 to down-sample by a factor of 2

363        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
  • x is the input feature map with shape [batch_size, channels, height, width]
365    def forward(self, x: torch.Tensor):

Add padding

370        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)

Apply convolution

372        return self.conv(x)
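The asymmetric (0, 1, 0, 1) padding pads only the right and bottom edges, so together with the stride-2, padding-0 convolution an even spatial size is exactly halved. A small shape check (my sketch):

down = DownSample(64)
x = torch.randn(1, 64, 32, 32)
print(down(x).shape)   # torch.Size([1, 64, 16, 16])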

ResNet Block

375class ResnetBlock(nn.Module):
  • in_channels is the number of channels in the input
  • out_channels is the number of channels in the output
379    def __init__(self, in_channels: int, out_channels: int):
384        super().__init__()

First normalization and convolution layer

386        self.norm1 = normalization(in_channels)
387        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)

Second normalization and convolution layer

389        self.norm2 = normalization(out_channels)
390        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)

in_channels to out_channels mapping layer for residual connection

392        if in_channels != out_channels:
393            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
394        else:
395            self.nin_shortcut = nn.Identity()
  • x is the input feature map with shape [batch_size, channels, height, width]
397    def forward(self, x: torch.Tensor):
402        h = x

First normalization and convolution layer

405        h = self.norm1(h)
406        h = swish(h)
407        h = self.conv1(h)

Second normalization and convolution layer

410        h = self.norm2(h)
411        h = swish(h)
412        h = self.conv2(h)

Map and add residual

415        return self.nin_shortcut(x) + h
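When the channel counts differ, the 1×1 nin_shortcut projects the input so the residual addition is valid; otherwise it is an identity. A small shape check (my sketch):

block = ResnetBlock(128, 256)
x = torch.randn(1, 128, 32, 32)
print(block(x).shape)  # torch.Size([1, 256, 32, 32])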

Swish activation

418def swish(x: torch.Tensor):
424    return x * torch.sigmoid(x)

Group normalization

This is a helper function with a fixed number of groups (32) and eps (1e-6).

427def normalization(channels: int):
433    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)