16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
30class NeoXModule(nn.Module):
31 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32 pass
35class Embedding(NeoXModule):
n_vocab
是词汇量的大小n_hidden
是嵌入的大小42 def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
47 super().__init__()
48
49 self.emb = nn.Embedding(n_vocab, n_hidden)
x
是形状的令牌 ID[batch_size, seq_len]
51 def forward(self, x: torch.Tensor):
55 return self.emb(x)
加载检查点的代码
57 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
61 with monit.section('Load embedding layer'):
62 checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
65class RoPE(nn.Module):
d_rope
是 RoPe 嵌入的要素数量base
是的基础,默认为75 def __init__(self, d_rope: int, base: float = 10_000.):
80 super().__init__()
为要素存储
83 self.theta = None
缓存和
85 self.cos_cached = None
86 self.sin_cached = None
基地
89 self.base = base
ROPE 的要素数量
91 self.d_rope = d_rope
93 @staticmethod
94 def rotate_half(x: torch.Tensor):
100 x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101 return torch.cat((-x2, x1), dim=-1)
x
有形状[..., seq, n_heads, d_k]
offset
是的起始位置x
。这是我们缓存先前位置的键和查询的时候103 def forward(self, x: torch.Tensor, offset: int = 0):
获取实际序列长度
111 seq_len = x.shape[-3] + offset
初始化
114 if self.theta is None:
116 theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117 self.theta = theta.to(x.device).to(x.dtype)
初始化并缓存
120 if (
121 self.cos_cached is None or
122 seq_len > self.cos_cached.shape[1] or
123 self.cos_cached.device != x.device or
124 self.cos_cached.dtype != x.dtype
125 ):
获取头寸指数
127 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
129 idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
133 idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
在 fp32 中计算和
136 with autocast(enabled=False):
137 idx_theta2 = idx_theta2.float()
添加头部尺寸
139 self.cos_cached = idx_theta2.cos()[:, None, :]
140 self.sin_cached = idx_theta2.sin()[:, None, :]
缓存它们
143 self.cos_cached = self.cos_cached.to(x.dtype)
144 self.sin_cached = self.sin_cached.to(x.dtype)
拆分要素。我们仅将 RoPe 应用于要d_rope
素
147 x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
从缓存中获取 sin 和 cos 值
150 cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
162 x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
连接未获得 RoPe 嵌入的功能
165 return torch.cat((x_rope, x_pass), dim=-1)
168class AttentionLayer(nn.Module):
n_hidden
嵌入中的特征数量n_heads
注意力头的数量rope_percentage
添加 RoPE 嵌入的功能百分比mask_fill
掩盖注意力矩阵的填充值is_flash_attention
指定是否使用 FlashAttention173 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
183 super().__init__()
184
185 self.n_heads = n_heads
186 self.mask_fill = mask_fill
用于查询、键和值的线性图层
189 self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
最后的线性层
191 self.output = nn.Linear(n_hidden, n_hidden)
每头特征数
194 d_k = n_hidden // n_heads
绳索嵌入模块
196 self.rope = RoPE(int(d_k * rope_percentage))
注意力缩放系数
199 self.scale = 1 / math.sqrt(d_k)
缓存因果掩码
202 self.causal_mask = None
注意 softmax 模块
205 self.softmax = nn.Softmax(dim=-2)
208 if is_flash_attention:
209 try:
210 from flash_attn.flash_attention import FlashAttention
211 self.flash_attention = FlashAttention()
212 except ImportError:
213 logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214 'Falling back to normal attention', Text.warning)
215 self.flash_attention = None
216 else:
217 self.flash_attention = None
219 def _get_mask(self, attn: torch.Tensor):
查询和密钥长度
227 nq, nk = attn.shape[1:3]
创建遮罩
230 if (
231 self.causal_mask is None or
232 self.causal_mask.shape[0] != nq or
233 self.causal_mask.shape[1] != nk or
234 self.causal_mask.device != attn.device
235 ):
236 self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
从缓存中返回
239 return self.causal_mask[None, :, :, None]
x
有形状[batch_size, seq_len, n_hidden]
241 def forward(self, x: torch.Tensor):
获取查询、键和值嵌入(全部串联)。最后一个维度大小将从 n_hidden 更改为->3 x n_hidden
247 qkv = self.qkv_lin(x)
通过将形状改为分成头部[batch_size, seq_len, n_heads, 3 * d_k]
250 qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
分为查询、键和值各形状[batch_size, seq_len, n_heads, 3 * d_k]
252 q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
如果我们正在缓存之前令牌的状态
255 if get_cache().get('use_cache', False):
获取状态 ID。我们用它来检索以前的状态并存储下一个状态
257 prev_state_id, next_state_id = get_cache().get('state_ids')
如果有缓存
259 if prev_state_id is not None:
获取过去的键和值。这些会有形状[batch_size, prev_seq_len, n_heads, d_k]
261 k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
当前嵌入的偏移量
263 offset = k_past.shape[1]
添加绳索嵌入
266 q = self.rope(q, offset=offset)
267 k = self.rope(k, offset=offset)
串联过去
270 k = torch.cat([k_past, k], dim=1)
271 v = torch.cat([v_past, v], dim=1)
272 else:
添加绳索嵌入
274 q = self.rope(q)
275 k = self.rope(k)
保存当前状态
278 get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279 else:
没有缓存-只需添加 RoPe 嵌入即可
281 q = self.rope(q)
282 k = self.rope(k)
使用闪光灯注意力
285 if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286 output = self.compute_flash_attention(q, k, v)
否则,请正常注意
288 else:
289 output = self.compute_attention(q, k, v)
从[batch_size, seq_len, n_heads, d_k] to
batch_size、seq_len、n_hidden 进行重塑 `
292 output = output.reshape(*x.shape)
最后的线性层
295 return self.output(output)
297 def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
将它们堆叠成形状[batch_size, seq_len, 3, n_heads, d_k]
299 qkv = torch.stack((q, k, v), dim=2)
300 d_k = qkv.shape[-1]
301 if d_k <= 32:
302 pad = 32 - d_k
303 elif d_k <= 64:
304 pad = 64 - d_k
305 elif d_k <= 128:
306 pad = 128 - d_k
307 else:
308 raise ValueError(f'Head size {d_k} too large for flash attention')
309
310 if pad > 0:
311 qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313 output, _ = self.flash_attention(qkv, causal=True)
输出的形状是这样的[batch_size, seq_len, n_heads, d_k + padding]
315 output = output[:, :, :, :d_k]
316
317 return output
319 def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
禁用自动投射到 fp16 以进行注意力计算
321 with autocast(enabled=False):
322 if q.dtype == torch.float16:
如果当前数据类型为 fp16,则转换为 fp32
324 attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325 else:
不要为 bfloat 进行投射
327 attn = torch.einsum('bihk,bjhk->bijh', q, k)
缩放注意力
330 attn = attn * self.scale
获得因果口罩
333 mask = self._get_mask(attn)
涂抹面膜
335 attn.masked_fill_(mask, self.mask_fill)
注意 softmax
338 attn = self.softmax(attn)
获取注意力加权值
341 output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343 return output
346class FFNLayer(nn.Module):
n_hidden
是嵌入的大小351 def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
355 super().__init__()
356
357 if not d_ff:
358 d_ff = n_hidden * 4
扩展线性层
361 self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
GELU 激活
363 self.activation = nn.GELU()
收缩线性层
365 self.dense_h4_h = nn.Linear(d_ff, n_hidden)
x
有形状[batch_size, seq_len, n_hidden]
367 def forward(self, x: torch.Tensor):
371 x = self.dense_h_h4(x)
372 x = self.activation(x)
373 x = self.dense_h4_h(x)
374
375 return x
378class TransformerLayer(NeoXModule):
383 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
392 super().__init__()
注意之前的图层规范化
395 self.pre_ln_attn = nn.LayerNorm(n_hidden)
FFN 之前的层标准化
397 self.pre_ln_ffn = nn.LayerNorm(n_hidden)
注意层
400 self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
FFN 层
402 self.ffn = FFNLayer(n_hidden)
x
是形状的嵌入[batch_size, seq_len, n_hidden]
404 def forward(self, x: torch.Tensor):
剩余连接
410 residual = x
NeoX 并行运行注意力和前馈网络
412 attn = self.attention(self.pre_ln_attn(x))
413 ffn = self.ffn(self.pre_ln_ffn(x))
添加它们和剩余的连接
415 return attn + ffn + residual
加载检查点的代码
417 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421 with monit.section('Load transformer layer'):
注意力输出变换
423 checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424 checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
注意力查询、关键和价值转换
427 checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428 checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
注意之前先进行分层规范
431 checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432 checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
FFN 第二次变换
435 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
FFN 首次改造
439 checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440 checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
FFN 之前的分层规范
443 checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444 checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
447class FinalNorm(NeoXModule):
n_hidden
是嵌入的大小452 def __init__(self, n_hidden: int = 6_144):
456 super().__init__()
457
458 self.ln = nn.LayerNorm(n_hidden)
x
是形状的嵌入[batch_size, seq_len, n_hidden]
460 def forward(self, x: torch.Tensor):
464 return self.ln(x)
加载检查点的代码
466 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
470 with monit.section('Load final normalization layer'):
471 checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472 checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
读出层
475class ReadoutLayer(NeoXModule):
n_hidden
是嵌入的大小n_vocab
是词汇量的大小480 def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
485 super().__init__()
486
487 self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
x
是形状的嵌入[batch_size, seq_len, n_hidden]
489 def forward(self, x: torch.Tensor):
493 return self.linear(x)
加载检查点的代码
495 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
499 with monit.section('Load final linear layer'):
500 checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504 pre_created_layers: Dict[Any, Optional[NeoXModule]]
图层的生成顺序与检查点的生成顺序相同。
它在图层不可用None
时给出;我们将图层索引用作 NeoX,并且在实现中不需要两个转换层。
n_vocab
是词汇表中的代币数量n_hidden
是嵌入中的特征数量n_layers
是变压器层数n_heads
是注意力头的数量filter_layers
是要使用的图层集。如果没有,则将使用所有图层。这用于测试层数较少的模型的较小版本is_clone_layers
指定是否克隆变压器层(快一点)dtype
是模型的数据类型device
是模型的设备is_llm_int8
指定是否使用 int8 量化llm_int8_threshold
是用于分离异常特征的阈值is_flash_attention
指定是否使用 FlashAttention506 def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507 n_layers: int = 44, n_heads: int = 64,
508 filter_layers: Optional[Set] = None,
509 is_clone_layers: bool = True,
510 dtype: torch.dtype = torch.float,
511 device: torch.device = torch.device('cpu'),
512 is_llm_int8: bool = False,
513 llm_int8_threshold: float = 6.0,
514 is_flash_attention: bool = False
515 ):
538 if filter_layers is None:
539 filter_layers = set(range(n_layers + 3))
540
541 self.n_vocab = n_vocab
542 self.n_hidden = n_hidden
543 self.n_layers = n_layers
544 self.n_heads = n_heads
545 self.filter_layers = filter_layers
546 self.is_clone_layers = is_clone_layers
547 self.dtype = dtype
548 self.device = device
549 self.is_llm_int8 = is_llm_int8
550 self.llm_int8_threshold = llm_int8_threshold
551 self.is_flash_attention = is_flash_attention
552
553 self.pre_created_layers = dict(
554 transformer_layer=None,
555 )
557 def _prepare_layer(self, layer: NeoXModule):
566 return layer.to(self.device, self.dtype)
此函数在加载检查点后实现层转换。
目前,它仅应用 int8 量化。
layer
是要准备的图层is_llm_int8
指定是否使用 int8 量化device
是该型号的设备llm_int8_threshold
是用于分隔异常值要素的阈值返回准备好的图层
568 @torch.no_grad()
569 def post_load_prepare(self, layer: NeoXModule, *,
570 is_llm_int8: bool = None,
571 device: torch.device = None,
572 llm_int8_threshold: float = None,
573 ):
如果未指定,则获取默认值
591 if is_llm_int8 is None:
592 is_llm_int8 = self.is_llm_int8
593 if device is None:
594 device = self.device
595 if llm_int8_threshold is None:
596 llm_int8_threshold = self.llm_int8_threshold
如果不使用 int8 量化则跳过
599 if not is_llm_int8:
600 return layer
仅转换变压器层中的线性层
603 if not isinstance(layer, TransformerLayer):
604 return layer
607 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
转换线性图层
610 with monit.section('Convert to int8'):
611 layer.attention.output = make_llm_int8_linear(layer.attention.output,
612 device=device,
613 threshold=llm_int8_threshold)
614 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615 device=device,
616 threshold=llm_int8_threshold)
617 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618 device=device,
619 threshold=llm_int8_threshold)
620 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621 device=device,
622 threshold=llm_int8_threshold)
624 return layer
626 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638 if not self.is_clone_layers:
639 return self._prepare_layer(creator())
640
641 if self.pre_created_layers[name] is None:
642 self.pre_created_layers[name] = self._prepare_layer(creator())
643
644 layer = copy.deepcopy(self.pre_created_layers[name])
645 return layer
647 def _create_transformer_layer(self):
648 return self._create_and_cache_layer(
649 'transformer_layer',
650 lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651 )
653 def _create_embedding_layer(self):
654 return Embedding(self.n_vocab, self.n_hidden)
656 def _create_final_norm_layer(self):
657 return FinalNorm(self.n_hidden)
659 def _create_readout_layer(self):
660 return ReadoutLayer(self.n_hidden, self.n_vocab)
662 @torch.no_grad()
663 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
嵌入层
668 if 0 in self.filter_layers:
669 with monit.section('Embedding layer'):
670 layer = self._prepare_layer(self._create_embedding_layer())
671 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
变压器层
674 for i in range(self.n_layers):
变压器层
676 if i + 1 in self.filter_layers:
677 with monit.section(f'Transformer Layer {i}'):
678 yield self._create_transformer_layer(), \
679 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680 f'layer_{i + 2 :02d}-model_01-model_states.pt')
最终归一化层
683 if self.n_layers + 1 in self.filter_layers:
684 with monit.section('Final norm layer'):
685 layer = self._prepare_layer(self._create_final_norm_layer())
686 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
读出层
689 if self.n_layers + 2 in self.filter_layers:
690 with monit.section('Readout layer'):
691 layer = self._prepare_layer(self._create_readout_layer())
692 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694 for k in self.pre_created_layers.keys():
695 self.pre_created_layers[k] = None
697 @property
698 def total_layers(self):
702 return self.n_layers + 3
704 @torch.no_grad()
705 def load(self) -> Generator[NeoXModule, None, None]:
709 with monit.section("Layers"):
710 for i, (layer, files) in enumerate(self.get_layers()):
711 if files is not None:
712 layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714 layer = self.post_load_prepare(layer)
715
716 monit.progress(min(0.99, (i + 1) / self.total_layers))
717 yield layer