これは、GPT-Neoxモデルのレイヤー用のコードと20Bのチェックポイントをロードするコードです。
load_state
レイヤー内のメソッドは、そのレイヤーのチェックポイントをロードします。チェックポイントロードヘルパーがオンになっています checkpoint.py
16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
30class NeoXModule(nn.Module):
31 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32 pass
35class Embedding(NeoXModule):
n_vocab
ボキャブラリーの大きさですn_hidden
は埋め込みのサイズです42 def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
47 super().__init__()
48
49 self.emb = nn.Embedding(n_vocab, n_hidden)
x
形状のトークンIDです [batch_size, seq_len]
51 def forward(self, x: torch.Tensor):
55 return self.emb(x)
チェックポイントをロードするコード
57 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
61 with monit.section('Load embedding layer'):
62 checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
65class RoPE(nn.Module):
d_rope
RoPE 埋め込みの機能の数ですbase
がの基底で、デフォルトは 75 def __init__(self, d_rope: int, base: float = 10_000.):
80 super().__init__()
機能用に保存するには
83 self.theta = None
キャッシュと
85 self.cos_cached = None
86 self.sin_cached = None
のベース
89 self.base = base
RoPE の機能の数
91 self.d_rope = d_rope
93 @staticmethod
94 def rotate_half(x: torch.Tensor):
100 x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101 return torch.cat((-x2, x1), dim=-1)
x
形がある [..., seq, n_heads, d_k]
offset
x
の開始位置です。これは、以前のポジションのキーとクエリをキャッシュしたときです103 def forward(self, x: torch.Tensor, offset: int = 0):
実際のシーケンス長を取得
111 seq_len = x.shape[-3] + offset
[初期化]
114 if self.theta is None:
116 theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117 self.theta = theta.to(x.device).to(x.dtype)
初期化とキャッシュ
120 if (
121 self.cos_cached is None or
122 seq_len > self.cos_cached.shape[1] or
123 self.cos_cached.device != x.device or
124 self.cos_cached.dtype != x.dtype
125 ):
位置インデックスを取得
127 seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
129 idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
133 idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
計算して fp32 で
136 with autocast(enabled=False):
137 idx_theta2 = idx_theta2.float()
頭部寸法を追加
139 self.cos_cached = idx_theta2.cos()[:, None, :]
140 self.sin_cached = idx_theta2.sin()[:, None, :]
それらをキャッシュする
143 self.cos_cached = self.cos_cached.to(x.dtype)
144 self.sin_cached = self.sin_cached.to(x.dtype)
機能を分割してください。RoPE d_rope
は機能にのみ適用されます
147 x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
キャッシュから sin と cos の値を取得
150 cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
162 x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
RoPe 埋め込みに対応していなかった機能との連携
165 return torch.cat((x_rope, x_pass), dim=-1)
168class AttentionLayer(nn.Module):
n_hidden
埋め込みに含まれる機能の数n_heads
アテンション・ヘッドの数rope_percentage
RoPe 埋め込みを追加する機能の割合mask_fill
アテンション・マトリックスのマスキング・フィル値is_flash_attention
フラッシュアテンションを使用するかどうかを指定します173 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
183 super().__init__()
184
185 self.n_heads = n_heads
186 self.mask_fill = mask_fill
クエリ、キー、値の線形レイヤー
189 self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
最終線形レイヤー
191 self.output = nn.Linear(n_hidden, n_hidden)
ヘッドあたりの機能数
194 d_k = n_hidden // n_heads
RoPE 埋め込みモジュール
196 self.rope = RoPE(int(d_k * rope_percentage))
アテンションスケーリングファクター
199 self.scale = 1 / math.sqrt(d_k)
因果マスクをキャッシュするには
202 self.causal_mask = None
アテンションソフトマックスモジュール
205 self.softmax = nn.Softmax(dim=-2)
208 if is_flash_attention:
209 try:
210 from flash_attn.flash_attention import FlashAttention
211 self.flash_attention = FlashAttention()
212 except ImportError:
213 logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214 'Falling back to normal attention', Text.warning)
215 self.flash_attention = None
216 else:
217 self.flash_attention = None
219 def _get_mask(self, attn: torch.Tensor):
クエリとキーの長さ
227 nq, nk = attn.shape[1:3]
マスク作成
230 if (
231 self.causal_mask is None or
232 self.causal_mask.shape[0] != nq or
233 self.causal_mask.shape[1] != nk or
234 self.causal_mask.device != attn.device
235 ):
236 self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
キャッシュから戻る
239 return self.causal_mask[None, :, :, None]
x
形がある [batch_size, seq_len, n_hidden]
241 def forward(self, x: torch.Tensor):
247 qkv = self.qkv_lin(x)
形状を以下のように変更して頭部に分割します [batch_size, seq_len, n_heads, 3 * d_k]
250 qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
形状ごとにクエリ、キー、値に分割 [batch_size, seq_len, n_heads, 3 * d_k]
252 q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
以前のトークンの状態をキャッシュする場合
255 if get_cache().get('use_cache', False):
ステート ID を取得します。前のステートを取得したり、次のステートを保存したりするのに使います。
257 prev_state_id, next_state_id = get_cache().get('state_ids')
キャッシュがある場合
259 if prev_state_id is not None:
過去のキーと値を取得します。これらは形になります [batch_size, prev_seq_len, n_heads, d_k]
261 k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
現在の埋め込みのオフセット
263 offset = k_past.shape[1]
RoPe 埋め込みを追加
266 q = self.rope(q, offset=offset)
267 k = self.rope(k, offset=offset)
過去を連結する
270 k = torch.cat([k_past, k], dim=1)
271 v = torch.cat([v_past, v], dim=1)
272 else:
RoPe 埋め込みを追加
274 q = self.rope(q)
275 k = self.rope(k)
現在の状態を保存する
278 get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279 else:
キャッシュなし-RoPE 埋め込みを追加するだけ
281 q = self.rope(q)
282 k = self.rope(k)
フラッシュアテンションを使う
285 if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286 output = self.compute_flash_attention(q, k, v)
それ以外の場合は、通常の注意を払ってください
288 else:
289 output = self.compute_attention(q, k, v)
[batch_size, seq_len, n_heads, d_k] to
バッチサイズ、シーケンス番号、n_hidden `から形状を変更
292 output = output.reshape(*x.shape)
最終線形レイヤー
295 return self.output(output)
297 def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
それらを積み重ねて形を整える [batch_size, seq_len, 3, n_heads, d_k]
299 qkv = torch.stack((q, k, v), dim=2)
300 d_k = qkv.shape[-1]
301 if d_k <= 32:
302 pad = 32 - d_k
303 elif d_k <= 64:
304 pad = 64 - d_k
305 elif d_k <= 128:
306 pad = 128 - d_k
307 else:
308 raise ValueError(f'Head size {d_k} too large for flash attention')
309
310 if pad > 0:
311 qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313 output, _ = self.flash_attention(qkv, causal=True)
出力は整形しています [batch_size, seq_len, n_heads, d_k + padding]
315 output = output[:, :, :, :d_k]
316
317 return output
319 def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
アテンション計算の fp16 への自動キャストを無効にする
321 with autocast(enabled=False):
322 if q.dtype == torch.float16:
現在の dtype が fp16 の場合は fp32 に変換
324 attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325 else:
bfloatにはキャストしないでください
327 attn = torch.einsum('bihk,bjhk->bijh', q, k)
スケールアテンション
330 attn = attn * self.scale
カジュアルマスクをゲット
333 mask = self._get_mask(attn)
マスクを適用
335 attn.masked_fill_(mask, self.mask_fill)
注意ソフトマックス
338 attn = self.softmax(attn)
アテンション加重値を取得
341 output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343 return output
346class FFNLayer(nn.Module):
n_hidden
は埋め込みサイズ351 def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
355 super().__init__()
356
357 if not d_ff:
358 d_ff = n_hidden * 4
拡張リニアレイヤー
361 self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
GELU アクティベーション
363 self.activation = nn.GELU()
収縮線状層
365 self.dense_h4_h = nn.Linear(d_ff, n_hidden)
x
形がある [batch_size, seq_len, n_hidden]
367 def forward(self, x: torch.Tensor):
371 x = self.dense_h_h4(x)
372 x = self.activation(x)
373 x = self.dense_h4_h(x)
374
375 return x
378class TransformerLayer(NeoXModule):
n_hidden
は埋め込みサイズn_heads
は頭の数ですis_flash_attention
フラッシュアテンションを使用するかどうかを指定しますアウトの実装にはドロップアウトは含まれていません。
383 def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
392 super().__init__()
注意前のレイヤー正規化
395 self.pre_ln_attn = nn.LayerNorm(n_hidden)
FFN 前のレイヤー正規化
397 self.pre_ln_ffn = nn.LayerNorm(n_hidden)
アテンションレイヤー
400 self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
FFN レイヤー
402 self.ffn = FFNLayer(n_hidden)
x
形が埋め込まれているものです [batch_size, seq_len, n_hidden]
404 def forward(self, x: torch.Tensor):
残留接続
410 residual = x
NeoXはアテンションネットワークとフィードフォワードネットワークを並行して実行します
412 attn = self.attention(self.pre_ln_attn(x))
413 ffn = self.ffn(self.pre_ln_ffn(x))
それらと残りの接続を追加します
415 return attn + ffn + residual
チェックポイントをロードするコード
417 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421 with monit.section('Load transformer layer'):
アテンション出力変換
423 checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424 checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
アテンションクエリ、キー、値の変換
427 checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428 checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
注目される前のレイヤーノルム
431 checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432 checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
FFN 2 番目のトランスフォーム
435 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436 checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
FFN ファーストトランスフォーム
439 checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440 checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
FFN 前のレイヤーノルム
443 checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444 checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
447class FinalNorm(NeoXModule):
n_hidden
は埋め込みサイズ452 def __init__(self, n_hidden: int = 6_144):
456 super().__init__()
457
458 self.ln = nn.LayerNorm(n_hidden)
x
形が埋め込まれているものです [batch_size, seq_len, n_hidden]
460 def forward(self, x: torch.Tensor):
464 return self.ln(x)
チェックポイントをロードするコード
466 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
470 with monit.section('Load final normalization layer'):
471 checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472 checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
読み出し層
475class ReadoutLayer(NeoXModule):
n_hidden
は埋め込みサイズn_vocab
ボキャブラリーの大きさです480 def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
485 super().__init__()
486
487 self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
x
形が埋め込まれているものです [batch_size, seq_len, n_hidden]
489 def forward(self, x: torch.Tensor):
493 return self.linear(x)
チェックポイントをロードするコード
495 def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
499 with monit.section('Load final linear layer'):
500 checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504 pre_created_layers: Dict[Any, Optional[NeoXModule]]
レイヤーはチェックポイントと同じ順序で生成されます。
None
レイヤーが使用できない場合に返されます。レイヤーインデックスをNeoXとして使用し、実装には必要のない変換レイヤーが2つあります。
n_vocab
ボキャブラリ内のトークンの数ですn_hidden
は埋め込み内のフィーチャの数ですn_layers
変圧器層の数ですn_heads
アテンション・ヘッドの数ですfilter_layers
使用するレイヤーのセットです。None の場合はすべてのレイヤーが使用されます。これは、レイヤー数の少ないモデルの小さいバージョンをテストする場合に使用しますis_clone_layers
トランスフォーマーレイヤーのクローンを作成するかどうかを指定します (少し速くなります)dtype
モデルのデータ型ですdevice
モデルのデバイスですis_llm_int8
int8 量子化を使用するかどうかを指定しますllm_int8_threshold
外れ値の特徴を分離するための閾値ですis_flash_attention
フラッシュアテンションを使用するかどうかを指定します506 def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507 n_layers: int = 44, n_heads: int = 64,
508 filter_layers: Optional[Set] = None,
509 is_clone_layers: bool = True,
510 dtype: torch.dtype = torch.float,
511 device: torch.device = torch.device('cpu'),
512 is_llm_int8: bool = False,
513 llm_int8_threshold: float = 6.0,
514 is_flash_attention: bool = False
515 ):
538 if filter_layers is None:
539 filter_layers = set(range(n_layers + 3))
540
541 self.n_vocab = n_vocab
542 self.n_hidden = n_hidden
543 self.n_layers = n_layers
544 self.n_heads = n_heads
545 self.filter_layers = filter_layers
546 self.is_clone_layers = is_clone_layers
547 self.dtype = dtype
548 self.device = device
549 self.is_llm_int8 = is_llm_int8
550 self.llm_int8_threshold = llm_int8_threshold
551 self.is_flash_attention = is_flash_attention
552
553 self.pre_created_layers = dict(
554 transformer_layer=None,
555 )
557 def _prepare_layer(self, layer: NeoXModule):
566 return layer.to(self.device, self.dtype)
この関数は、チェックポイントを読み込んだ後にレイヤー変換を実装します。
現在、適用されるのは int8 量子化のみです。
layer
準備するレイヤーですis_llm_int8
int8 量子化を使用するかどうかを指定しますdevice
モデルのデバイスですllm_int8_threshold
外れ値の特徴を分離するための閾値です準備したレイヤーを返します
568 @torch.no_grad()
569 def post_load_prepare(self, layer: NeoXModule, *,
570 is_llm_int8: bool = None,
571 device: torch.device = None,
572 llm_int8_threshold: float = None,
573 ):
指定しない場合はデフォルト値を取得
591 if is_llm_int8 is None:
592 is_llm_int8 = self.is_llm_int8
593 if device is None:
594 device = self.device
595 if llm_int8_threshold is None:
596 llm_int8_threshold = self.llm_int8_threshold
int8 量子化を使用しない場合はスキップ
599 if not is_llm_int8:
600 return layer
トランスレイヤーの線形レイヤーのみを変換します
603 if not isinstance(layer, TransformerLayer):
604 return layer
make_llm_int8_linear
ユーティリティで定義されている用途。
607 from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
線形レイヤーの変換
610 with monit.section('Convert to int8'):
611 layer.attention.output = make_llm_int8_linear(layer.attention.output,
612 device=device,
613 threshold=llm_int8_threshold)
614 layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615 device=device,
616 threshold=llm_int8_threshold)
617 layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618 device=device,
619 threshold=llm_int8_threshold)
620 layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621 device=device,
622 threshold=llm_int8_threshold)
624 return layer
キャッシュされたレイヤーのコピーは、パラメーターの初期化に時間がかかるため、新しいレイヤーを初期化するよりも高速です。
name
レイヤーの名前ですcreator
レイヤーを作成する関数です作成されたレイヤーまたはキャッシュされたレイヤーのコピーを返します
626 def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638 if not self.is_clone_layers:
639 return self._prepare_layer(creator())
640
641 if self.pre_created_layers[name] is None:
642 self.pre_created_layers[name] = self._prepare_layer(creator())
643
644 layer = copy.deepcopy(self.pre_created_layers[name])
645 return layer
647 def _create_transformer_layer(self):
648 return self._create_and_cache_layer(
649 'transformer_layer',
650 lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651 )
653 def _create_embedding_layer(self):
654 return Embedding(self.n_vocab, self.n_hidden)
656 def _create_final_norm_layer(self):
657 return FinalNorm(self.n_hidden)
659 def _create_readout_layer(self):
660 return ReadoutLayer(self.n_hidden, self.n_vocab)
662 @torch.no_grad()
663 def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
埋め込みレイヤー
668 if 0 in self.filter_layers:
669 with monit.section('Embedding layer'):
670 layer = self._prepare_layer(self._create_embedding_layer())
671 yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
トランスフォーマー層
674 for i in range(self.n_layers):
変圧器層
676 if i + 1 in self.filter_layers:
677 with monit.section(f'Transformer Layer {i}'):
678 yield self._create_transformer_layer(), \
679 (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680 f'layer_{i + 2 :02d}-model_01-model_states.pt')
最終正規化レイヤー
683 if self.n_layers + 1 in self.filter_layers:
684 with monit.section('Final norm layer'):
685 layer = self._prepare_layer(self._create_final_norm_layer())
686 yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
読み出し層
689 if self.n_layers + 2 in self.filter_layers:
690 with monit.section('Readout layer'):
691 layer = self._prepare_layer(self._create_readout_layer())
692 yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694 for k in self.pre_created_layers.keys():
695 self.pre_created_layers[k] = None
697 @property
698 def total_layers(self):
702 return self.n_layers + 3
704 @torch.no_grad()
705 def load(self) -> Generator[NeoXModule, None, None]:
709 with monit.section("Layers"):
710 for i, (layer, files) in enumerate(self.get_layers()):
711 if files is not None:
712 layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714 layer = self.post_load_prepare(layer)
715
716 monit.progress(min(0.99, (i + 1) / self.total_layers))
717 yield layer