GPT ネオックスモデル

これは、GPT-Neoxモデルのレイヤー用のコードと20Bのチェックポイントをロードするコードです。

load_state レイヤー内のメソッドは、そのレイヤーのチェックポイントをロードします。チェックポイントロードヘルパーがオンになっています checkpoint.py

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
30class NeoXModule(nn.Module):
31    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
32        pass

埋め込みレイヤー

これは、チェックポイントをロードするコードを含む標準の埋め込みレイヤーです。

35class Embedding(NeoXModule):
  • n_vocab ボキャブラリーの大きさです
  • n_hidden は埋め込みのサイズです
42    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
47        super().__init__()
48
49        self.emb = nn.Embedding(n_vocab, n_hidden)
  • x 形状のトークンIDです [batch_size, seq_len]
51    def forward(self, x: torch.Tensor):
55        return self.emb(x)

チェックポイントをロードするコード

57    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
61        with monit.section('Load embedding layer'):
62            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

ロータリーポジショナルエンベディング

GPT-Neoxは回転式ポジショナルエンベディング(RoPE)を使用しています。

ここでは、RoPE の実装に注釈を付けて、理論に関する注釈を付けました。

65class RoPE(nn.Module):
  • d_rope RoPE 埋め込みの機能の数です
  • base がの基底で、デフォルトは
75    def __init__(self, d_rope: int, base: float = 10_000.):
80        super().__init__()

機能用に保存するには

83        self.theta = None

キャッシュと

85        self.cos_cached = None
86        self.sin_cached = None

のベース

89        self.base = base

RoPE の機能の数

91        self.d_rope = d_rope

フィーチャをローテーションしてください

93    @staticmethod
94    def rotate_half(x: torch.Tensor):
100        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
101        return torch.cat((-x2, x1), dim=-1)
  • x 形がある [..., seq, n_heads, d_k]
  • offset x の開始位置です。これは、以前のポジションのキーとクエリをキャッシュしたときです
103    def forward(self, x: torch.Tensor, offset: int = 0):

実際のシーケンス長を取得

111        seq_len = x.shape[-3] + offset

[初期化]

114        if self.theta is None:

116            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
117            self.theta = theta.to(x.device).to(x.dtype)

初期化とキャッシュ

120        if (
121                self.cos_cached is None or
122                seq_len > self.cos_cached.shape[1] or
123                self.cos_cached.device != x.device or
124                self.cos_cached.dtype != x.dtype
125        ):

位置インデックスを取得

127            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

129            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)

行が次のようになるように連結します

133            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

計算して fp32 で

136            with autocast(enabled=False):
137                idx_theta2 = idx_theta2.float()

頭部寸法を追加

139                self.cos_cached = idx_theta2.cos()[:, None, :]
140                self.sin_cached = idx_theta2.sin()[:, None, :]

それらをキャッシュする

143            self.cos_cached = self.cos_cached.to(x.dtype)
144            self.sin_cached = self.sin_cached.to(x.dtype)

機能を分割してください。RoPE d_rope は機能にのみ適用されます

147        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

キャッシュから sin と cos の値を取得

150        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

ロープ埋め込み

にとって

162        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

RoPe 埋め込みに対応していなかった機能との連携

165        return torch.cat((x_rope, x_pass), dim=-1)

アテンションレイヤー

168class AttentionLayer(nn.Module):
173    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
174                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
183        super().__init__()
184
185        self.n_heads = n_heads
186        self.mask_fill = mask_fill

クエリ、キー、値の線形レイヤー

189        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)

最終線形レイヤー

191        self.output = nn.Linear(n_hidden, n_hidden)

ヘッドあたりの機能数

194        d_k = n_hidden // n_heads

RoPE 埋め込みモジュール

196        self.rope = RoPE(int(d_k * rope_percentage))

アテンションスケーリングファクター

199        self.scale = 1 / math.sqrt(d_k)

因果マスクをキャッシュするには

202        self.causal_mask = None

アテンションソフトマックスモジュール

205        self.softmax = nn.Softmax(dim=-2)
208        if is_flash_attention:
209            try:
210                from flash_attn.flash_attention import FlashAttention
211                self.flash_attention = FlashAttention()
212            except ImportError:
213                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
214                           'Falling back to normal attention', Text.warning)
215                self.flash_attention = None
216        else:
217            self.flash_attention = None

因果マスクの計算

219    def _get_mask(self, attn: torch.Tensor):

クエリとキーの長さ

227        nq, nk = attn.shape[1:3]

マスク作成

230        if (
231                self.causal_mask is None or
232                self.causal_mask.shape[0] != nq or
233                self.causal_mask.shape[1] != nk or
234                self.causal_mask.device != attn.device
235        ):
236            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

キャッシュから戻る

239        return self.causal_mask[None, :, :, None]
  • x 形がある [batch_size, seq_len, n_hidden]
241    def forward(self, x: torch.Tensor):

クエリ、キー、値の埋め込み (すべて連結) を取得します。最後のディメンションサイズが n_hidden から変更されます

-> 3 x n_hidden
247        qkv = self.qkv_lin(x)

形状を以下のように変更して頭部に分割します [batch_size, seq_len, n_heads, 3 * d_k]

250        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)

形状ごとにクエリ、キー、値に分割 [batch_size, seq_len, n_heads, 3 * d_k]

252        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

以前のトークンの状態をキャッシュする場合

255        if get_cache().get('use_cache', False):

ステート ID を取得します。前のステートを取得したり、次のステートを保存したりするのに使います。

257            prev_state_id, next_state_id = get_cache().get('state_ids')

キャッシュがある場合

259            if prev_state_id is not None:

過去のキーと値を取得します。これらは形になります [batch_size, prev_seq_len, n_heads, d_k]

261                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')

現在の埋め込みのオフセット

263                offset = k_past.shape[1]

RoPe 埋め込みを追加

266                q = self.rope(q, offset=offset)
267                k = self.rope(k, offset=offset)

過去を連結する

270                k = torch.cat([k_past, k], dim=1)
271                v = torch.cat([v_past, v], dim=1)
272            else:

RoPe 埋め込みを追加

274                q = self.rope(q)
275                k = self.rope(k)

現在の状態を保存する

278            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
279        else:

キャッシュなし-RoPE 埋め込みを追加するだけ

281            q = self.rope(q)
282            k = self.rope(k)

フラッシュアテンションを使う

285        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
286            output = self.compute_flash_attention(q, k, v)

それ以外の場合は、通常の注意を払ってください

288        else:
289            output = self.compute_attention(q, k, v)

[batch_size, seq_len, n_heads, d_k] to バッチサイズ、シーケンス番号、n_hidden `から形状を変更

292        output = output.reshape(*x.shape)

最終線形レイヤー

295        return self.output(output)
297    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

それらを積み重ねて形を整える [batch_size, seq_len, 3, n_heads, d_k]

299        qkv = torch.stack((q, k, v), dim=2)
300        d_k = qkv.shape[-1]
301        if d_k <= 32:
302            pad = 32 - d_k
303        elif d_k <= 64:
304            pad = 64 - d_k
305        elif d_k <= 128:
306            pad = 128 - d_k
307        else:
308            raise ValueError(f'Head size {d_k} too large for flash attention')
309
310        if pad > 0:
311            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)
312
313        output, _ = self.flash_attention(qkv, causal=True)

出力は整形しています [batch_size, seq_len, n_heads, d_k + padding]

315        output = output[:, :, :, :d_k]
316
317        return output
319    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

アテンション計算の fp16 への自動キャストを無効にする

321        with autocast(enabled=False):
322            if q.dtype == torch.float16:

現在の dtype が fp16 の場合は fp32 に変換

324                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
325            else:

bfloatにはキャストしないでください

327                attn = torch.einsum('bihk,bjhk->bijh', q, k)

スケールアテンション

330            attn = attn * self.scale

カジュアルマスクをゲット

333            mask = self._get_mask(attn)

マスクを適用

335            attn.masked_fill_(mask, self.mask_fill)

注意ソフトマックス

338            attn = self.softmax(attn)

アテンション加重値を取得

341        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)
342
343        return output

フィードフォワードネットワーク

346class FFNLayer(nn.Module):
  • n_hidden は埋め込みサイズ
351    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
355        super().__init__()
356
357        if not d_ff:
358            d_ff = n_hidden * 4

拡張リニアレイヤー

361        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)

GELU アクティベーション

363        self.activation = nn.GELU()

収縮線状層

365        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
  • x 形がある [batch_size, seq_len, n_hidden]
367    def forward(self, x: torch.Tensor):
371        x = self.dense_h_h4(x)
372        x = self.activation(x)
373        x = self.dense_h4_h(x)
374
375        return x

変圧器層

378class TransformerLayer(NeoXModule):

アウトの実装にはドロップアウトは含まれていません

383    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
392        super().__init__()

注意前のレイヤー正規化

395        self.pre_ln_attn = nn.LayerNorm(n_hidden)

FFN 前のレイヤー正規化

397        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

アテンションレイヤー

400        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)

FFN レイヤー

402        self.ffn = FFNLayer(n_hidden)
  • x 形が埋め込まれているものです [batch_size, seq_len, n_hidden]
404    def forward(self, x: torch.Tensor):

残留接続

410        residual = x

NeoXはアテンションネットワークとフィードフォワードネットワークを並行して実行します

412        attn = self.attention(self.pre_ln_attn(x))
413        ffn = self.ffn(self.pre_ln_ffn(x))

それらと残りの接続を追加します

415        return attn + ffn + residual

チェックポイントをロードするコード

417    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
421        with monit.section('Load transformer layer'):

アテンション出力変換

423            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
424            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

アテンションクエリ、キー、値の変換

427            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
428            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

注目される前のレイヤーノルム

431            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
432            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

FFN 2 番目のトランスフォーム

435            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
436            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

FFN ファーストトランスフォーム

439            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
440            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

FFN 前のレイヤーノルム

443            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
444            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)

最終正規化レイヤー

447class FinalNorm(NeoXModule):
  • n_hidden は埋め込みサイズ
452    def __init__(self, n_hidden: int = 6_144):
456        super().__init__()
457
458        self.ln = nn.LayerNorm(n_hidden)
  • x 形が埋め込まれているものです [batch_size, seq_len, n_hidden]
460    def forward(self, x: torch.Tensor):
464        return self.ln(x)

チェックポイントをロードするコード

466    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
470        with monit.section('Load final normalization layer'):
471            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
472            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

読み出し層

475class ReadoutLayer(NeoXModule):
  • n_hidden は埋め込みサイズ
  • n_vocab ボキャブラリーの大きさです
480    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
485        super().__init__()
486
487        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
  • x 形が埋め込まれているものです [batch_size, seq_len, n_hidden]
489    def forward(self, x: torch.Tensor):
493        return self.linear(x)

チェックポイントをロードするコード

495    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
499        with monit.section('Load final linear layer'):
500            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504    pre_created_layers: Dict[Any, Optional[NeoXModule]]

レイヤーを作成するためのジェネレーター

レイヤーはチェックポイントと同じ順序で生成されます。

None レイヤーが使用できない場合に返されます。レイヤーインデックスをNeoXとして使用し、実装には必要のない変換レイヤーが2つあります。

  • n_vocab ボキャブラリ内のトークンの数です
  • n_hidden は埋め込み内のフィーチャの数です
  • n_layers 変圧器層の数です
  • n_heads アテンション・ヘッドの数です
  • filter_layers 使用するレイヤーのセットです。None の場合はすべてのレイヤーが使用されます。これは、レイヤー数の少ないモデルの小さいバージョンをテストする場合に使用します
  • is_clone_layers トランスフォーマーレイヤーのクローンを作成するかどうかを指定します (少し速くなります)
  • dtype モデルのデータ型です
  • device モデルのデバイスです
  • is_llm_int8 int8 量子化を使用するかどうかを指定します
  • llm_int8_threshold 外れ値の特徴を分離するための閾値です
  • is_flash_attention フラッシュアテンションを使用するかどうかを指定します
506    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507                 n_layers: int = 44, n_heads: int = 64,
508                 filter_layers: Optional[Set] = None,
509                 is_clone_layers: bool = True,
510                 dtype: torch.dtype = torch.float,
511                 device: torch.device = torch.device('cpu'),
512                 is_llm_int8: bool = False,
513                 llm_int8_threshold: float = 6.0,
514                 is_flash_attention: bool = False
515                 ):
538        if filter_layers is None:
539            filter_layers = set(range(n_layers + 3))
540
541        self.n_vocab = n_vocab
542        self.n_hidden = n_hidden
543        self.n_layers = n_layers
544        self.n_heads = n_heads
545        self.filter_layers = filter_layers
546        self.is_clone_layers = is_clone_layers
547        self.dtype = dtype
548        self.device = device
549        self.is_llm_int8 = is_llm_int8
550        self.llm_int8_threshold = llm_int8_threshold
551        self.is_flash_attention = is_flash_attention
552
553        self.pre_created_layers = dict(
554            transformer_layer=None,
555        )

レイヤーを使用できるように準備します

レイヤーをデバイスに移動し、正しいデータ型に変換します。

  • layer 準備するレイヤーです
  • 準備したレイヤーを返します

557    def _prepare_layer(self, layer: NeoXModule):
566        return layer.to(self.device, self.dtype)

チェックポイントをロードした後のレイヤー変換

この関数は、チェックポイントを読み込んだ後にレイヤー変換を実装します。

現在、適用されるのは int8 量子化のみです。

  • layer 準備するレイヤーです
  • is_llm_int8 int8 量子化を使用するかどうかを指定します
  • device モデルのデバイスです
  • llm_int8_threshold 外れ値の特徴を分離するための閾値です
  • 準備したレイヤーを返します

568    @torch.no_grad()
569    def post_load_prepare(self, layer: NeoXModule, *,
570                          is_llm_int8: bool = None,
571                          device: torch.device = None,
572                          llm_int8_threshold: float = None,
573                          ):

指定しない場合はデフォルト値を取得

591        if is_llm_int8 is None:
592            is_llm_int8 = self.is_llm_int8
593        if device is None:
594            device = self.device
595        if llm_int8_threshold is None:
596            llm_int8_threshold = self.llm_int8_threshold

int8 量子化を使用しない場合はスキップ

599        if not is_llm_int8:
600            return layer

トランスレイヤーの線形レイヤーのみを変換します

603        if not isinstance(layer, TransformerLayer):
604            return layer
607        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

線形レイヤーの変換

610        with monit.section('Convert to int8'):
611            layer.attention.output = make_llm_int8_linear(layer.attention.output,
612                                                          device=device,
613                                                          threshold=llm_int8_threshold)
614            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615                                                           device=device,
616                                                           threshold=llm_int8_threshold)
617            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618                                                        device=device,
619                                                        threshold=llm_int8_threshold)
620            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621                                                        device=device,
622                                                        threshold=llm_int8_threshold)

624        return layer

レイヤーを作成してキャッシュします

キャッシュされたレイヤーのコピーは、パラメーターの初期化に時間がかかるため、新しいレイヤーを初期化するよりも高速です。

  • name レイヤーの名前です
  • creator レイヤーを作成する関数です
  • 作成されたレイヤーまたはキャッシュされたレイヤーのコピーを返します

626    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638        if not self.is_clone_layers:
639            return self._prepare_layer(creator())
640
641        if self.pre_created_layers[name] is None:
642            self.pre_created_layers[name] = self._prepare_layer(creator())
643
644        layer = copy.deepcopy(self.pre_created_layers[name])
645        return layer
647    def _create_transformer_layer(self):
648        return self._create_and_cache_layer(
649            'transformer_layer',
650            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651        )
653    def _create_embedding_layer(self):
654        return Embedding(self.n_vocab, self.n_hidden)
656    def _create_final_norm_layer(self):
657        return FinalNorm(self.n_hidden)
659    def _create_readout_layer(self):
660        return ReadoutLayer(self.n_hidden, self.n_vocab)

レイヤーを取得するためのジェネレーター

662    @torch.no_grad()
663    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

埋め込みレイヤー

668        if 0 in self.filter_layers:
669            with monit.section('Embedding layer'):
670                layer = self._prepare_layer(self._create_embedding_layer())
671            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

トランスフォーマー層

674        for i in range(self.n_layers):

変圧器層

676            if i + 1 in self.filter_layers:
677                with monit.section(f'Transformer Layer {i}'):
678                    yield self._create_transformer_layer(), \
679                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

最終正規化レイヤー

683        if self.n_layers + 1 in self.filter_layers:
684            with monit.section('Final norm layer'):
685                layer = self._prepare_layer(self._create_final_norm_layer())
686            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

読み出し層

689        if self.n_layers + 2 in self.filter_layers:
690            with monit.section('Readout layer'):
691                layer = self._prepare_layer(self._create_readout_layer())
692            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694        for k in self.pre_created_layers.keys():
695            self.pre_created_layers[k] = None

レイヤーの総数を返します

697    @property
698    def total_layers(self):
702        return self.n_layers + 3

レイヤーをロードするジェネレーター

704    @torch.no_grad()
705    def load(self) -> Generator[NeoXModule, None, None]:
709        with monit.section("Layers"):
710            for i, (layer, files) in enumerate(self.get_layers()):
711                if files is not None:
712                    layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714                layer = self.post_load_prepare(layer)
715
716                monit.progress(min(0.99, (i + 1) / self.total_layers))
717                yield layer