GPT-NeoX Model

Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.

The load_state method in each layer loads the checkpoint of that layer. The checkpoint loading helpers are in checkpoint.py.

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
30class NeoXModule(nn.Module):
31    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32        pass

Embedding layer

This is a standard embedding layer, with code to load the checkpoint.

35class Embedding(NeoXModule):
  • n_vocab is the size of the vocabulary
  • n_hidden is the size of the embeddings
42    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
47        super().__init__()
48
49        self.emb = nn.Embedding(n_vocab, n_hidden)
  • x are the token ids, of shape [batch_size, seq_len]
51    def forward(self, x: torch.Tensor):
55        return self.emb(x)

Code to load the checkpoint

57    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
61        with monit.section('Load embedding layer'):
62            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
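The merge_params_* helpers in checkpoint.py combine the two model-parallel partitions p1 and p2 of the 20B checkpoint into a single parameter. As a rough conceptual sketch of what such a merge does (assumed behaviour for illustration only, not the actual checkpoint.py code):

import torch

def merge_params_dim_0_sketch(param: torch.nn.Parameter, key: str, p1: dict, p2: dict):
    # Parameters split across the two model-parallel shards along dimension 0
    # (such as the embedding table) are concatenated back together.
    with torch.no_grad():
        param.copy_(torch.cat([p1[key], p2[key]], dim=0))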

Rotary Positional Embeddings

GPT-NeoX uses rotary positional embeddings (RoPE).

We have an annotated implementation of RoPE here, with more notes on the theory.

65class RoPE(nn.Module):
  • d_rope is the number of features for the RoPE embeddings
  • base is the base for $\theta_i$, which defaults to $10000$
75    def __init__(self, d_rope: int, base: float = 10_000.):
80        super().__init__()

To store $\theta_i$ for the features

83        self.theta = None

Cache for the $\cos m\theta_i$ and $\sin m\theta_i$ values

85        self.cos_cached = None
86        self.sin_cached = None

The base

89        self.base = base

Number of features for RoPE

91        self.d_rope = d_rope

Rotate the features

93    @staticmethod
94    def rotate_half(x: torch.Tensor):
100        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101        return torch.cat((-x2, x1), dim=-1)
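For intuition, here is a tiny standalone sketch (not part of the model code) of what rotate_half does to the last dimension:

import torch

x = torch.tensor([[1., 2., 3., 4.]])
x1, x2 = x[..., :2], x[..., 2:]
rotated = torch.cat((-x2, x1), dim=-1)   # [[-3., -4., 1., 2.]]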
  • x has shape [..., seq, n_heads, d_k]
  • offset is the starting position of x. This is greater than zero when we have cached the keys and queries of previous positions
103    def forward(self, x: torch.Tensor, offset: int = 0):

Get the actual sequence length

111        seq_len = x.shape[-3] + offset

Initialize $\theta_i$

114        if self.theta is None:

116            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117            self.theta = theta.to(x.device).to(x.dtype)
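As a quick standalone illustration (assuming d_rope = 8), the resulting frequencies decay geometrically with the feature index:

import torch

d_rope, base = 8, 10_000.
theta = 1.0 / (base ** (torch.arange(0, d_rope, 2).float() / d_rope))
# theta is approximately [1.0, 0.1, 0.01, 0.001]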

Initialize the $\cos$ and $\sin$ cache

120        if (
121                self.cos_cached is None or
122                seq_len > self.cos_cached.shape[1] or
123                self.cos_cached.device != x.device or
124                self.cos_cached.dtype != x.dtype
125        ):

Get the position indexes

127            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

129            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, \dots, m\theta_{d/2-1}, m\theta_0, m\theta_1, \dots, m\theta_{d/2-1}]$

133            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

Calculate $\cos m\theta_i$ and $\sin m\theta_i$ in fp32

136            with autocast(enabled=False):
137                idx_theta2 = idx_theta2.float()

Add the head dimension

139                self.cos_cached = idx_theta2.cos()[:, None, :]
140                self.sin_cached = idx_theta2.sin()[:, None, :]

Cache them

143            self.cos_cached = self.cos_cached.to(x.dtype)
144            self.sin_cached = self.sin_cached.to(x.dtype)

Split the features. We apply RoPE only to the first d_rope features

147        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

Get the sin and cos values from the cache

150        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

RoPE embeddings

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

for $i \in \{1, 2, \dots, \frac{d}{2}\}$

162        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

Concatenate with the features that didn't get RoPE embeddings

165        return torch.cat((x_rope, x_pass), dim=-1)
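A minimal usage sketch of the module above (the shapes and d_rope = 16 are illustrative choices, not the 20B configuration):

import torch

rope = RoPE(d_rope=16)
q = torch.randn(2, 10, 4, 64)   # [batch_size, seq_len, n_heads, d_k]
q_rot = rope(q)
assert q_rot.shape == q.shape
# Features beyond d_rope pass through unchanged
assert torch.allclose(q_rot[..., 16:], q[..., 16:])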

Attention layer

168class AttentionLayer(nn.Module):
  • n_hidden is the number of features in the embeddings
  • n_heads is the number of attention heads
  • rope_percentage is the percentage of features to add RoPE embeddings to
  • mask_fill is the fill value for masking the attention matrix
  • is_flash_attention specifies whether to use FlashAttention
173    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
183        super().__init__()
184
185        self.n_heads = n_heads
186        self.mask_fill = mask_fill

Linear layer for query, key and value

189        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)

Final linear layer

191        self.output = nn.Linear(n_hidden, n_hidden)

Number of features per head

194        d_k = n_hidden // n_heads

RoPE embedding module

196        self.rope = RoPE(int(d_k * rope_percentage))

Attention scaling factor

199        self.scale = 1 / math.sqrt(d_k)

To cache the causal mask

202        self.causal_mask = None

Attention softmax module (softmax over the key dimension, dim -2)

205        self.softmax = nn.Softmax(dim=-2)
208        if is_flash_attention:
209            try:
210                from flash_attn.flash_attention import FlashAttention
211                self.flash_attention = FlashAttention()
212            except ImportError:
213                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214                           'Falling back to normal attention', Text.warning)
215                self.flash_attention = None
216        else:
217            self.flash_attention = None

Calculate the causal mask

219    def _get_mask(self, attn: torch.Tensor):

Query and key lengths

227        nq, nk = attn.shape[1:3]

Create the mask

230        if (
231                self.causal_mask is None or
232                self.causal_mask.shape[0] != nq or
233                self.causal_mask.shape[1] != nk or
234                self.causal_mask.device != attn.device
235        ):
236            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

Return from the cache

239        return self.causal_mask[None, :, :, None]
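As a standalone illustration of the masking convention, the 1 + nk - nq offset keeps the mask causal even when cached keys make nk larger than nq (True marks positions that will be filled with mask_fill):

import torch

nq, nk = 3, 5   # 3 new queries attend over 5 keys (2 of them cached)
mask = torch.triu(torch.ones(nq, nk, dtype=torch.bool), 1 + nk - nq)
# [[False, False, False, True,  True ],
#  [False, False, False, False, True ],
#  [False, False, False, False, False]]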
  • x has shape [batch_size, seq_len, n_hidden]
241    def forward(self, x: torch.Tensor):

Get the query, key and value embeddings (all concatenated). The size of the last dimension changes from n_hidden to 3 * n_hidden.

247        qkv = self.qkv_lin(x)

Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]

250        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]

252        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

If we are caching the states of previous tokens

255        if get_cache().get('use_cache', False):

Get the state ids. We use them to retrieve the previous state and to store the next state

257            prev_state_id, next_state_id = get_cache().get('state_ids')

If there is a cache

259            if prev_state_id is not None:

Get the past keys and values. These have shape [batch_size, prev_seq_len, n_heads, d_k]

261                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

Offset of the current embeddings

263                offset = k_past.shape[1]

Add RoPE embeddings

266                q = self.rope(q, offset=offset)
267                k = self.rope(k, offset=offset)

Concatenate with the past

270                k = torch.cat([k_past, k], dim=1)
271                v = torch.cat([v_past, v], dim=1)
272            else:

Add RoPE embeddings

274                q = self.rope(q)
275                k = self.rope(k)

Save the current state

278            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279        else:

No caching - just add RoPE embeddings

281            q = self.rope(q)
282            k = self.rope(k)

Use flash attention

285        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286            output = self.compute_flash_attention(q, k, v)

Otherwise, use normal attention

288        else:
289            output = self.compute_attention(q, k, v)

Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]

292        output = output.reshape(*x.shape)

Final linear layer

295        return self.output(output)
297    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Stack them into shape [batch_size, seq_len, 3, n_heads, d_k]

299        qkv = torch.stack((q, k, v), dim=2)
300        d_k = qkv.shape[-1]
301        if d_k <= 32:
302            pad = 32 - d_k
303        elif d_k <= 64:
304            pad = 64 - d_k
305        elif d_k <= 128:
306            pad = 128 - d_k
307        else:
308            raise ValueError(f'Head size {d_k} too large for flash attention')
309
310        if pad > 0:
311            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313        output, _ = self.flash_attention(qkv, causal=True)

The output has shape [batch_size, seq_len, n_heads, d_k + padding]; strip the padding

315        output = output[:, :, :, :d_k]
316
317        return output
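For the default sizes in this file (n_hidden = 6144, n_heads = 64) each head has 96 features, so the rule above pads the heads to 128 features before calling FlashAttention. A trivial check of that arithmetic:

d_k = 6_144 // 64     # 96 features per head
pad = 128 - d_k       # rounded up to the next supported size (32, 64 or 128)
assert d_k + pad == 128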
319    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Disable auto-casting to fp16 for the attention computation

321        with autocast(enabled=False):
322            if q.dtype == torch.float16:

Convert to fp32 if the current dtype is fp16

324                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325            else:

Do not cast for bfloat

327                attn = torch.einsum('bihk,bjhk->bijh', q, k)

Scale the attention

330            attn = attn * self.scale

Get the causal mask

333            mask = self._get_mask(attn)

Apply the mask

335            attn.masked_fill_(mask, self.mask_fill)

Attention softmax

338            attn = self.softmax(attn)

Get the attention-weighted values

341        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343        return output
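The einsum 'bihk,bjhk->bijh' above is a per-head dot product between query position i and key position j. A small sanity sketch (illustrative shapes) showing it matches an explicit matmul:

import torch

q = torch.randn(2, 4, 3, 8)   # [batch_size, seq_q, n_heads, d_k]
k = torch.randn(2, 5, 3, 8)   # [batch_size, seq_k, n_heads, d_k]
attn = torch.einsum('bihk,bjhk->bijh', q, k)

# Same values via matmul with the head dimension moved forward
ref = (q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)).permute(0, 2, 3, 1)
assert torch.allclose(attn, ref, atol=1e-5)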

Feedforward Network

346class FFNLayer(nn.Module):
  • n_hidden is the size of the embeddings
351    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
355        super().__init__()
356
357        if not d_ff:
358            d_ff = n_hidden * 4

Expansion linear layer

361        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)

GELU activation

363        self.activation = nn.GELU()

Contraction linear layer

365        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
  • x has shape [batch_size, seq_len, n_hidden]
367    def forward(self, x: torch.Tensor):
371        x = self.dense_h_h4(x)
372        x = self.activation(x)
373        x = self.dense_h4_h(x)
374
375        return x

Transformer Layer

378class TransformerLayer(NeoXModule):
  • n_hidden is the embedding size
  • n_heads is the number of heads
  • is_flash_attention specifies whether to use FlashAttention

Our implementation does not include dropout.

383    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
392        super().__init__()

Layer normalization before attention

395        self.pre_ln_attn = nn.LayerNorm(n_hidden)

Layer normalization before the FFN

397        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

Attention layer

400        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)

FFN layer

402        self.ffn = FFNLayer(n_hidden)
  • x are the embeddings, of shape [batch_size, seq_len, n_hidden]
404    def forward(self, x: torch.Tensor):

Residual connection

410        residual = x

NeoX runs attention and the feedforward network in parallel

412        attn = self.attention(self.pre_ln_attn(x))
413        ffn = self.ffn(self.pre_ln_ffn(x))

Add them and the residual connection

415        return attn + ffn + residual
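So the block computes y = x + Attn(LN1(x)) + FFN(LN2(x)) rather than applying attention and the FFN sequentially. A minimal shape-check sketch with a randomly initialized layer (illustrative sizes, not the 20B configuration):

import torch

layer = TransformerLayer(n_hidden=64, n_heads=4)
x = torch.randn(2, 10, 64)   # [batch_size, seq_len, n_hidden]
y = layer(x)
assert y.shape == x.shape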

Code to load the checkpoint

417    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421        with monit.section('Load transformer layer'):

Attention output transform

423            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

Attention query, key and value transform

427            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

Layer norm before attention

431            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

FFN first transform (expansion)

435            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

FFN second transform (contraction)

439            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

Layer norm before the FFN

443            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)

Final normalization layer

447class FinalNorm(NeoXModule):
  • n_hidden is the size of the embeddings
452    def __init__(self, n_hidden: int = 6_144):
456        super().__init__()
457
458        self.ln = nn.LayerNorm(n_hidden)
  • x are the embeddings, of shape [batch_size, seq_len, n_hidden]
460    def forward(self, x: torch.Tensor):
464        return self.ln(x)

Code to load the checkpoint

466    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
470        with monit.section('Load final normalization layer'):
471            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

Readout layer

475class ReadoutLayer(NeoXModule):
  • n_hidden is the size of the embeddings
  • n_vocab is the size of the vocabulary
480    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
485        super().__init__()
486
487        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
  • x are the embeddings, of shape [batch_size, seq_len, n_hidden]
489    def forward(self, x: torch.Tensor):
493        return self.linear(x)

Code to load the checkpoint

495    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
499        with monit.section('Load final linear layer'):
500            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504    pre_created_layers: Dict[Any, Optional[NeoXModule]]

Generator to create the layers

The layers are generated in the same order as the checkpoints.

It gives None when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers that we don't need in our implementation.

  • n_vocab is the number of tokens in the vocabulary
  • n_hidden is the number of features in the embeddings
  • n_layers is the number of transformer layers
  • n_heads is the number of attention heads
  • filter_layers is the set of layers to be used. If None, all layers will be used. This is used to test smaller versions of the model with fewer layers
  • is_clone_layers specifies whether to clone the transformer layers (a bit faster)
  • dtype is the data type of the model
  • device is the device of the model
  • is_llm_int8 specifies whether to use int8 quantization
  • llm_int8_threshold is the threshold used to separate outlier features
  • is_flash_attention specifies whether to use FlashAttention
506    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507                 n_layers: int = 44, n_heads: int = 64,
508                 filter_layers: Optional[Set] = None,
509                 is_clone_layers: bool = True,
510                 dtype: torch.dtype = torch.float,
511                 device: torch.device = torch.device('cpu'),
512                 is_llm_int8: bool = False,
513                 llm_int8_threshold: float = 6.0,
514                 is_flash_attention: bool = False
515                 ):
538        if filter_layers is None:
539            filter_layers = set(range(n_layers + 3))
540
541        self.n_vocab = n_vocab
542        self.n_hidden = n_hidden
543        self.n_layers = n_layers
544        self.n_heads = n_heads
545        self.filter_layers = filter_layers
546        self.is_clone_layers = is_clone_layers
547        self.dtype = dtype
548        self.device = device
549        self.is_llm_int8 = is_llm_int8
550        self.llm_int8_threshold = llm_int8_threshold
551        self.is_flash_attention = is_flash_attention
552
553        self.pre_created_layers = dict(
554            transformer_layer=None,
555        )
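For example (a hypothetical configuration; get_layers below indexes the layers as 0 for the embedding, 1 through n_layers for the transformer layers, n_layers + 1 for the final norm and n_layers + 2 for the readout), restricting the generator to a few layers for testing might look like:

generator = LayerGenerator(filter_layers={0, 1, 2, 46}, is_clone_layers=False)
# get_layers() will then yield only the embedding, the first two transformer
# layers and the readout layer of the default 44-layer model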

Prepare the layer for use

We move the layer to the device and convert it to the correct data type.

  • layer is the layer to prepare
  • Returns the prepared layer

557    def _prepare_layer(self, layer: NeoXModule):
566        return layer.to(self.device, self.dtype)

Layer transformations after loading the checkpoint

This function applies transformations to a layer after its checkpoint has been loaded.

Currently, it only applies int8 quantization.

  • layer is the layer to prepare
  • is_llm_int8 specifies whether to use int8 quantization
  • device is the device of the model
  • llm_int8_threshold is the threshold used to separate outlier features
  • Returns the prepared layer

568    @torch.no_grad()
569    def post_load_prepare(self, layer: NeoXModule, *,
570                          is_llm_int8: bool = None,
571                          device: torch.device = None,
572                          llm_int8_threshold: float = None,
573                          ):

Get the default values if not specified

591        if is_llm_int8 is None:
592            is_llm_int8 = self.is_llm_int8
593        if device is None:
594            device = self.device
595        if llm_int8_threshold is None:
596            llm_int8_threshold = self.llm_int8_threshold

Skip if not using int8 quantization

599        if not is_llm_int8:
600            return layer

Only convert the linear layers in the transformer layers

603        if not isinstance(layer, TransformerLayer):
604            return layer

Use make_llm_int8_linear defined in the utilities.

607        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

Convert the linear layers

610        with monit.section('Convert to int8'):
611            layer.attention.output = make_llm_int8_linear(layer.attention.output,
612                                                          device=device,
613                                                          threshold=llm_int8_threshold)
614            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615                                                           device=device,
616                                                           threshold=llm_int8_threshold)
617            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618                                                        device=device,
619                                                        threshold=llm_int8_threshold)
620            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621                                                        device=device,
622                                                        threshold=llm_int8_threshold)

624        return layer

Create and cache a layer

Copying a cached layer is faster than initializing a new layer because initializing the parameters takes time.

  • name is the name of the layer
  • creator is the function that creates the layer
  • Returns the created layer or a copy of the cached layer

626    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638        if not self.is_clone_layers:
639            return self._prepare_layer(creator())
640
641        if self.pre_created_layers[name] is None:
642            self.pre_created_layers[name] = self._prepare_layer(creator())
643
644        layer = copy.deepcopy(self.pre_created_layers[name])
645        return layer
647    def _create_transformer_layer(self):
648        return self._create_and_cache_layer(
649            'transformer_layer',
650            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651        )
653    def _create_embedding_layer(self):
654        return Embedding(self.n_vocab, self.n_hidden)
656    def _create_final_norm_layer(self):
657        return FinalNorm(self.n_hidden)
659    def _create_readout_layer(self):
660        return ReadoutLayer(self.n_hidden, self.n_vocab)

Generator to get the layers

662    @torch.no_grad()
663    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

Embedding layer

668        if 0 in self.filter_layers:
669            with monit.section('Embedding layer'):
670                layer = self._prepare_layer(self._create_embedding_layer())
671            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

Transformer layers

674        for i in range(self.n_layers):

Transformer layer

676            if i + 1 in self.filter_layers:
677                with monit.section(f'Transformer Layer {i}'):
678                    yield self._create_transformer_layer(), \
679                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

Final normalization layer

683        if self.n_layers + 1 in self.filter_layers:
684            with monit.section('Final norm layer'):
685                layer = self._prepare_layer(self._create_final_norm_layer())
686            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

Readout layer

689        if self.n_layers + 2 in self.filter_layers:
690            with monit.section('Readout layer'):
691                layer = self._prepare_layer(self._create_readout_layer())
692            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694        for k in self.pre_created_layers.keys():
695            self.pre_created_layers[k] = None

Returns the total number of layers

697    @property
698    def total_layers(self):
702        return self.n_layers + 3

Generator to load the layers

704    @torch.no_grad()
705    def load(self) -> Generator[NeoXModule, None, None]:
709        with monit.section("Layers"):
710            for i, (layer, files) in enumerate(self.get_layers()):
711                if files is not None:
712                    layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714                layer = self.post_load_prepare(layer)
715
716                monit.progress(min(0.99, (i + 1) / self.total_layers))
717                yield layer
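Finally, a hedged end-to-end sketch of how the generator might be used; it assumes the 20B checkpoint files are already available where checkpoint.load_checkpoint_files expects them:

import torch

generator = LayerGenerator(
    is_clone_layers=True,
    dtype=torch.float16,
    device=torch.device('cpu'),
)
layers = list(generator.load())   # n_layers + 3 = 47 modules, in checkpoint order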