Transformer for Stable Diffusion U-Net

This implements the transformer module used in the U-Net that gives $\epsilon_\text{cond}(x_t, c)$.

We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.

19from typing import Optional
20
21import torch
22import torch.nn.functional as F
23from torch import nn

Spatial Transformer

class SpatialTransformer(nn.Module):
    """
    Spatial transformer.

    Flattens a feature map into a sequence of per-pixel embeddings, runs it
    through a stack of `BasicTransformerBlock`s (which also attend to `cond`),
    folds it back into a feature map and adds the input as a residual.
    """

    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
        """
        * `channels` is the number of channels in the feature map
        * `n_heads` is the number of attention heads
        * `n_layers` is the number of transformer layers
        * `d_cond` is the size of the conditional embedding
        """
        super().__init__()
        # Group normalization applied to the input feature map
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        # 1x1 convolution applied before the transformer layers
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
        # Transformer layers; every head gets `channels // n_heads` dimensions
        head_size = channels // n_heads
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(channels, n_heads, head_size, d_cond=d_cond) for _ in range(n_layers)]
        )
        # 1x1 convolution applied after the transformer layers
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        * `x` is the feature map of shape `[batch_size, channels, height, width]`
        * `cond` is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """
        batch_size, channels, height, width = x.shape
        # Keep the input for the residual connection
        skip = x

        # Normalize, then apply the input projection
        hidden = self.proj_in(self.norm(x))
        # `[batch_size, channels, height, width]` -> `[batch_size, height * width, channels]`
        hidden = hidden.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        for layer in self.transformer_blocks:
            hidden = layer(hidden, cond)

        # `[batch_size, height * width, channels]` -> `[batch_size, channels, height, width]`
        hidden = hidden.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
        # Output projection plus the residual
        return self.proj_out(hidden) + skip

Transformer Layer

class BasicTransformerBlock(nn.Module):
    """
    Transformer layer.

    Three pre-norm sub-layers, each wrapped in a residual connection:
    self-attention, cross-attention on the conditioning, and a feed-forward network.
    """

    def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
        """
        * `d_model` is the input embedding size
        * `n_heads` is the number of attention heads
        * `d_head` is the size of an attention head
        * `d_cond` is the size of the conditional embeddings
        """
        super().__init__()
        # Self-attention (keys/values are projected from `d_model`-sized inputs) and its pre-norm
        self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
        self.norm1 = nn.LayerNorm(d_model)
        # Cross-attention (keys/values are projected from `d_cond`-sized conditioning) and its pre-norm
        self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
        self.norm2 = nn.LayerNorm(d_model)
        # Feed-forward network and its pre-norm
        self.ff = FeedForward(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        * `x` are the input embeddings of shape `[batch_size, height * width, d_model]`
        * `cond` is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """
        # Self-attention (no `cond` is passed, so `attn1` attends to its own input)
        attended = self.attn1(self.norm1(x))
        x = attended + x
        # Cross-attention with the conditioning
        attended = self.attn2(self.norm2(x), cond=cond)
        x = attended + x
        # Feed-forward network
        transformed = self.ff(self.norm3(x))
        return transformed + x

Cross Attention Layer

This falls back to self-attention when conditional embeddings are not specified.

class CrossAttention(nn.Module):
    """
    Cross attention layer.

    Queries come from `x`; keys and values come from `cond`. This falls back
    to self-attention when `cond` is not given.
    """

    # Set to `True` to use flash attention when it is installed.
    # It is only used for self-attention with `d_head <= 128`.
    use_flash_attention: bool = False

    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
        """
        * `d_model` is the input embedding size
        * `d_cond` is the size of the conditional embeddings
        * `n_heads` is the number of attention heads
        * `d_head` is the size of an attention head
        * `is_inplace` specifies whether to perform the attention softmax computation inplace to save memory
        """
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head

        # Attention scaling factor 1 / sqrt(d_head)
        self.scale = d_head ** -0.5

        # Query, key and value mappings
        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)

        # Final linear layer; wrapping it in `nn.Sequential` gives parameter names `to_out.0.*`
        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

        # Set up flash attention. It is only used if it's installed and
        # `CrossAttention.use_flash_attention` is `True`. You can install it by cloning
        # https://github.com/HazyResearch/flash-attention and running `python setup.py install`.
        try:
            from flash_attn.flash_attention import FlashAttention
        except ImportError:
            # Not installed: always use `normal_attention`
            self.flash = None
        else:
            self.flash = FlashAttention()
            # Use the same scale as normal attention
            self.flash.softmax_scale = self.scale

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        """
        * `x` are the input embeddings of shape `[batch_size, height * width, d_model]`
        * `cond` is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`;
          if `None`, self-attention over `x` is performed
        """
        # If `cond` is `None`, keys and values are computed from `x` (self-attention)
        has_cond = cond is not None
        if not has_cond:
            cond = x

        # Get query, key and value vectors
        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)

        # Flash attention stacks q, k and v into one tensor, so it needs them to share a
        # sequence length. That only holds for self-attention, and it requires a head size <= 128.
        if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
            return self.flash_attention(q, k, v)
        # Otherwise, fall back to normal attention
        else:
            return self.normal_attention(q, k, v)

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Flash attention.

        * `q` are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        * `k` are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        * `v` are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`

        Raises `ValueError` if `d_head` is larger than 128.
        """
        # Batch size and number of elements along the sequence axis (`width * height`)
        batch_size, seq_len, _ = q.shape

        # Stack q, k, v into a single tensor of shape `[batch_size, seq_len, 3, n_heads * d_head]`
        qkv = torch.stack((q, k, v), dim=2)
        # Split the heads
        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

        # Flash attention works for head sizes 32, 64 and 128, so pad heads up to the next of those
        if self.d_head <= 32:
            pad = 32 - self.d_head
        elif self.d_head <= 64:
            pad = 64 - self.d_head
        elif self.d_head <= 128:
            pad = 128 - self.d_head
        else:
            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')

        # Zero-pad the heads
        if pad:
            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)

        # Compute attention; gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
        out, _ = self.flash(qkv)
        # Drop the padding
        out = out[:, :, :, :self.d_head]
        # Reshape to `[batch_size, seq_len, n_heads * d_head]`
        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Normal (non-flash) attention.

        * `q` are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        * `k` are the key vectors before splitting heads, of shape `[batch_size, seq_cond, d_attn]`
        * `v` are the value vectors before splitting heads, of shape `[batch_size, seq_cond, d_attn]`
        """
        # Split into heads of shape `[batch_size, seq_len, n_heads, d_head]`
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)

        # Scaled dot-product scores of shape `[batch_size, n_heads, seq_q, seq_k]`
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

        # Softmax over the key axis
        if self.is_inplace:
            # Softmax each half of the batch separately and write it back into `attn`,
            # so only half-batch-sized temporaries exist at a time (the memory-saving mode)
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        # Weighted sum of values: `[batch_size, seq_q, n_heads, d_head]`
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, height * width, n_heads * d_head]`
        out = out.reshape(*out.shape[:2], -1)
        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)

Feed-Forward Network

class FeedForward(nn.Module):
    """
    Feed-forward network.

    A GeGLU expansion to `d_model * d_mult` hidden units, a zero-probability
    dropout, and a linear projection back to `d_model`.
    """

    def __init__(self, d_model: int, d_mult: int = 4):
        """
        * `d_model` is the input embedding size
        * `d_mult` is the multiplicative factor for the hidden layer size
        """
        super().__init__()
        d_hidden = d_model * d_mult
        # Index order inside `net` (0: GeGLU, 1: Dropout, 2: Linear) determines the
        # parameter names `net.0.*` and `net.2.*`
        layers = [
            GeGLU(d_model, d_hidden),
            nn.Dropout(0.),
            nn.Linear(d_hidden, d_model),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        # Run the GeGLU, dropout and output projection in sequence
        return self.net(x)

GeGLU Activation: $\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$

class GeGLU(nn.Module):
    """
    GeGLU activation: GeGLU(x) = (xW + b) * GELU(xV + c).

    * `d_in` is the input feature size
    * `d_out` is the output feature size
    """

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        # A single linear layer computes both projections xW + b and xV + c,
        # concatenated along the last dimension
        self.proj = nn.Linear(d_in, d_out * 2)

    def forward(self, x: torch.Tensor):
        projected = self.proj(x)
        # First half is xW + b (the value), second half is xV + c (the gate)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)