Transformer for Stable Diffusion U-Net

This implements the transformer module used in the U-Net. It lets the U-Net condition on context embeddings (for example, CLIP text embeddings) through cross-attention.

We keep the model definitions and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

Spatial Transformer

class SpatialTransformer(nn.Module):
  • channels is the number of channels in the feature map
  • n_heads is the number of attention heads
  • n_layers is the number of transformer layers
  • d_cond is the size of the conditional embedding

    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
        super().__init__()

Initial group normalization

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

Initial 1×1 convolution

        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

Transformer layers

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
        )

Final 1×1 convolution

        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

  • x is the feature map of shape [batch_size, channels, height, width]
  • cond is the conditional embeddings of shape [batch_size, n_cond, d_cond]

    def forward(self, x: torch.Tensor, cond: torch.Tensor):

Get the shape [batch_size, channels, height, width]

        b, c, h, w = x.shape

For the residual connection

        x_in = x

Normalize

        x = self.norm(x)

Initial 1×1 convolution

        x = self.proj_in(x)

Transpose and reshape from [batch_size, channels, height, width] to [batch_size, height * width, channels]

        x = x.permute(0, 2, 3, 1).view(b, h * w, c)

Apply the transformer layers

        for block in self.transformer_blocks:
            x = block(x, cond)

Reshape and transpose from [batch_size, height * width, channels] back to [batch_size, channels, height, width]

        x = x.view(b, h, w, c).permute(0, 3, 1, 2)

Final 1×1 convolution

        x = self.proj_out(x)

Add the residual

        return x + x_in
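
A minimal usage sketch (not part of the original code; the sizes below are illustrative, e.g. a CLIP-like 77-token, 768-dimensional conditioning sequence):

transformer = SpatialTransformer(channels=320, n_heads=8, n_layers=1, d_cond=768)
x = torch.randn(2, 320, 32, 32)  # [batch_size, channels, height, width]
cond = torch.randn(2, 77, 768)  # [batch_size, n_cond, d_cond]
out = transformer(x, cond)  # same shape as x: [2, 320, 32, 32]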

Transformer Layer

class BasicTransformerBlock(nn.Module):
  • d_model is the input embedding size
  • n_heads is the number of attention heads
  • d_head is the size of an attention head
  • d_cond is the size of the conditional embedding

    def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
        super().__init__()

Self-attention layer and pre-norm layer

        self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
        self.norm1 = nn.LayerNorm(d_model)

Cross-attention layer and pre-norm layer

        self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
        self.norm2 = nn.LayerNorm(d_model)

Feed-forward network and pre-norm layer

        self.ff = FeedForward(d_model)
        self.norm3 = nn.LayerNorm(d_model)

  • x is the input embeddings of shape [batch_size, height * width, d_model]
  • cond is the conditional embeddings of shape [batch_size, n_cond, d_cond]

    def forward(self, x: torch.Tensor, cond: torch.Tensor):

Self-attention

        x = self.attn1(self.norm1(x)) + x

Cross-attention with conditioning

        x = self.attn2(self.norm2(x), cond=cond) + x

Feed-forward network

        x = self.ff(self.norm3(x)) + x

        return x
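
A minimal usage sketch (not from the original code; sizes are illustrative) showing one block applied to a flattened 32 * 32 feature map:

block = BasicTransformerBlock(d_model=320, n_heads=8, d_head=40, d_cond=768)
x = torch.randn(2, 32 * 32, 320)  # [batch_size, height * width, d_model]
cond = torch.randn(2, 77, 768)  # [batch_size, n_cond, d_cond]
out = block(x, cond)  # same shape as x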

Cross Attention Layer

This falls back to self-attention when the conditional embeddings are not specified.

class CrossAttention(nn.Module):
    use_flash_attention: bool = False
  • d_model is the input embedding size
  • n_heads is the number of attention heads
  • d_head is the size of an attention head
  • d_cond is the size of the conditional embedding
  • is_inplace specifies whether to perform the attention softmax computation in-place to save memory

    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head

Attention scaling factor

        self.scale = d_head ** -0.5

Query, key and value mappings

        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)

Final linear layer

        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

Set up flash attention. Flash attention is only used if it is installed and CrossAttention.use_flash_attention is set to True.

        try:

You can install flash attention by cloning their GitHub repository https://github.com/HazyResearch/flash-attention and then running python setup.py install

            from flash_attn.flash_attention import FlashAttention
            self.flash = FlashAttention()

Set the scale for scaled dot-product attention.

            self.flash.softmax_scale = self.scale

Set it to None if flash attention is not installed

        except ImportError:
            self.flash = None

  • x is the input embeddings of shape [batch_size, height * width, d_model]
  • cond is the conditional embeddings of shape [batch_size, n_cond, d_cond]

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):

If cond is None we perform self-attention

        has_cond = cond is not None
        if not has_cond:
            cond = x

Get the query, key and value vectors

        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)

Use flash attention if it is available and the head size is less than or equal to 128

        if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
            return self.flash_attention(q, k, v)

Otherwise, fall back to normal attention

        else:
            return self.normal_attention(q, k, v)
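
An illustrative sketch of the self-attention fallback (not from the original code): when d_cond equals d_model the layer can be called without cond, and the keys and values are computed from x itself. The last line shows the opt-in flash attention flag; even when set, flash attention is only used for self-attention with head sizes up to 128 and falls back if the library is not installed.

self_attn = CrossAttention(d_model=320, d_cond=320, n_heads=8, d_head=40)
x = torch.randn(2, 1024, 320)  # [batch_size, height * width, d_model]
out = self_attn(x)  # [2, 1024, 320]; keys and values come from x
CrossAttention.use_flash_attention = True  # opt-in flag for flash attention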

Flash Attention

  • q are the query vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • k are the key vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • v are the value vectors before splitting heads, of shape [batch_size, seq, d_attn]

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Get the batch size and the number of elements along the sequence axis (width * height)

        batch_size, seq_len, _ = q.shape

Stack the q, k and v vectors for flash attention to get a single tensor of shape [batch_size, seq_len, 3, n_heads * d_head]

        qkv = torch.stack((q, k, v), dim=2)

Split the heads

        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

Flash attention works for head sizes 32, 64 and 128, so we have to pad the heads to fit one of these sizes.

        if self.d_head <= 32:
            pad = 32 - self.d_head
        elif self.d_head <= 64:
            pad = 64 - self.d_head
        elif self.d_head <= 128:
            pad = 128 - self.d_head
        else:
            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')

Pad the heads

        if pad:
            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)

Compute attention. This gives a tensor of shape [batch_size, seq_len, n_heads, d_padded]

        out, _ = self.flash(qkv)

Truncate the extra head size

        out = out[:, :, :, :self.d_head]

Reshape to [batch_size, seq_len, n_heads * d_head]

        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

Map to [batch_size, height * width, d_model] with a linear layer

        return self.to_out(out)

Normal Attention

  • q are the query vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • k are the key vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • v are the value vectors before splitting heads, of shape [batch_size, seq, d_attn]

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Split them into heads of shape [batch_size, seq_len, n_heads, d_head]

        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)

Calculate the scaled attention scores

        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

Compute softmax over the key dimension. When is_inplace is set, the softmax is computed in-place on the two halves of the batch to save memory.

        if self.is_inplace:
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

Compute the attention output

        out = torch.einsum('bhij,bjhd->bihd', attn, v)

Reshape to [batch_size, height * width, n_heads * d_head]

        out = out.reshape(*out.shape[:2], -1)

Map to [batch_size, height * width, d_model] with a linear layer

        return self.to_out(out)
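
A standalone sketch (an assumption: PyTorch 2.x is available; this is not part of the original code): the einsum-based attention above matches torch.nn.functional.scaled_dot_product_attention with the heads moved to a batch dimension, since its default scale is also 1/sqrt(d_head). Sizes are illustrative.

q = torch.randn(2, 1024, 8, 40)  # [batch_size, seq_q, n_heads, d_head]
k = torch.randn(2, 77, 8, 40)  # [batch_size, seq_k, n_heads, d_head]
v = torch.randn(2, 77, 8, 40)  # [batch_size, seq_k, n_heads, d_head]
out = F.scaled_dot_product_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3))
out = out.permute(0, 2, 1, 3).reshape(2, 1024, 8 * 40)  # [batch_size, seq_q, n_heads * d_head]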

Feed-Forward Network

class FeedForward(nn.Module):
  • d_model is the input embedding size
  • d_mult is the multiplicative factor for the hidden layer size

    def __init__(self, d_model: int, d_mult: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            GeGLU(d_model, d_model * d_mult),
            nn.Dropout(0.),
            nn.Linear(d_model * d_mult, d_model)
        )

    def forward(self, x: torch.Tensor):
        return self.net(x)
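
A small usage sketch (illustrative sizes, not from the original code): the network expands to d_model * d_mult features inside GeGLU and projects back to d_model, so the output shape matches the input.

ff = FeedForward(d_model=320)
x = torch.randn(2, 1024, 320)  # [batch_size, height * width, d_model]
out = ff(x)  # [2, 1024, 320]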

GeGLU Activation

GeGLU(x) = (xW + b) * GELU(xV + c)

class GeGLU(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()

Combined linear projections xW + b and xV + c

        self.proj = nn.Linear(d_in, d_out * 2)

    def forward(self, x: torch.Tensor):

Get xW + b and the gate xV + c from the combined projection

        x, gate = self.proj(x).chunk(2, dim=-1)

        return x * F.gelu(gate)
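
A quick check of the GeGLU computation (a sketch with hypothetical sizes, not from the original code): the single projection is split into a value half and a gate half, and the value is scaled by GELU of the gate.

geglu = GeGLU(d_in=320, d_out=1280)
x = torch.randn(2, 1024, 320)
out = geglu(x)  # [2, 1024, 1280]
xw, gate = geglu.proj(x).chunk(2, dim=-1)  # same split as in forward
assert torch.allclose(out, xw * F.gelu(gate))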