GPT-NeoX Model

Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.

The load_state method of each layer loads the checkpoint of that layer. The checkpoint loading helpers are in checkpoint.py.

16import copy
17import math
18from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple
19
20import torch
21from torch import nn
22from torch.cuda.amp import autocast
23
24from labml import monit, logger
25from labml.logger import Text
26from labml_nn.neox import checkpoint
27from labml_nn.neox.utils.cache import get_cache
class NeoXModule(nn.Module):
    """
    Base class for the GPT-NeoX layers in this file that can load their own weights
    from a checkpoint.
    """

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Load the parameters of this layer from a checkpoint.

        The base implementation does nothing; subclasses override it.

        :param p1: is the first dictionary of checkpoint tensors, keyed by parameter name
        :param p2: is the second dictionary of checkpoint tensors, keyed by parameter name
        """
        pass

Embedding layer

This is a standard embedding layer, with code to load its weights from the checkpoint.

class Embedding(NeoXModule):
    """
    ## Embedding layer

    A standard token embedding layer with code to load its weights from the checkpoint.
    """

    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        """
        :param n_vocab: is the size of the vocabulary
        :param n_hidden: is the size of the embeddings
        """
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the token ids of shape `[batch_size, seq_len]`
        """
        return self.emb(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)

Rotary Positional Embeddings

GPT-NeoX uses rotary positional embeddings (RoPE).

We have an annotated implementation of RoPE here, with more notes on the theory.

class RoPE(nn.Module):
    """
    ## Rotary Positional Embeddings

    GPT-NeoX uses rotary positional embeddings (RoPE).

    The rotation is applied to the first `d_rope` features of each head only; the
    remaining features are passed through unchanged.
    """

    def __init__(self, d_rope: int, base: float = 10_000.):
        """
        :param d_rope: is the number of features for RoPE embeddings
        :param base: is the base for `theta_i = base ** (-2i / d_rope)`, which defaults to `10,000`
        """
        super().__init__()

        # To store `theta_i` for the features
        self.theta = None
        # Cache `cos(m * theta_i)` and `sin(m * theta_i)`
        self.cos_cached = None
        self.sin_cached = None

        # Base for `theta_i`
        self.base = base
        # Number of features for RoPE
        self.d_rope = d_rope

    @staticmethod
    def rotate_half(x: torch.Tensor):
        """
        ### Rotate the features

        Splits the last dimension into halves `(x1, x2)` and returns `(-x2, x1)`.
        """
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor, offset: int = 0):
        """
        :param x: has shape `[..., seq, n_heads, d_k]`
        :param offset: is the starting position of `x`. This is used when we have
            cached the keys and queries of previous positions
        """

        # Get the actual sequence length
        seq_len = x.shape[-3] + offset

        # Initialize `theta`
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)

        # Initialize the `cos` and `sin` cache.
        # The cache has shape `[seq_len, 1, d_rope]`, so the cached length is dimension `0`.
        # (Comparing against dimension `1`, which is always `1`, would rebuild the cache on
        # every call with more than one position.)
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[0] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
            # Get position indexes `m`
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
            # `m * theta_i`
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
            # Concatenate so that each row `m` holds the angles twice,
            # once for each half of the features
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)

            # Calculate `cos` and `sin` in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
                # Add head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]

            # Cache them in the data type of `x`
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)

        # Split the features. We apply RoPE to only `d_rope` features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]

        # Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]

        # RoPE embeddings: `x * cos(m * theta) + rotate_half(x) * sin(m * theta)`
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)

        # Concatenate with features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)

Attention layer

class AttentionLayer(nn.Module):
    """
    ## Attention layer

    Causal multi-head self-attention with rotary positional embeddings, optional
    key/value caching and optional FlashAttention.
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        """
        :param n_hidden: the number of features in embeddings
        :param n_heads: the number of attention heads
        :param rope_percentage: percentage of features to add RoPE embeddings
        :param mask_fill: masking fill value for attention matrix
        :param is_flash_attention: specifies whether to use FlashAttention
        """
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill

        # Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
        # Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)

        # Number of features per head
        d_k = n_hidden // n_heads
        # RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))

        # Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)

        # To cache causal mask
        self.causal_mask = None

        # Attention softmax module. The attention matrix has shape
        # `[batch_size, query_len, key_len, n_heads]`, so the keys are on dimension `-2`.
        self.softmax = nn.Softmax(dim=-2)

        # FlashAttention (optional)
        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                # Pass the scale explicitly. `compute_flash_attention` zero-pads the head
                # size to 32/64/128, and the default scale would be derived from the
                # padded size instead of the real `d_k`.
                self.flash_attention = FlashAttention(softmax_scale=self.scale)
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None

    def _get_mask(self, attn: torch.Tensor):
        """
        #### Calculate the causal mask

        :param attn: is the attention matrix of shape `[batch_size, query_seq_len, key_seq_len, n_heads]`
        :return: a boolean mask of shape `[1, query_seq_len, key_seq_len, 1]` that is
            `True` at the positions to be filled with `mask_fill`
        """

        # Query and key lengths
        nq, nk = attn.shape[1:3]

        # Create mask (or re-create it if the cached one doesn't match)
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            # The diagonal is shifted by `nk - nq` so that, with cached keys,
            # the queries line up with the last `nq` key positions
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)

        # Return from cache
        return self.causal_mask[None, :, :, None]

    def forward(self, x: torch.Tensor):
        """
        :param x: has shape `[batch_size, seq_len, n_hidden]`
        """

        # Get query, key and value embeddings (all concatenated).
        # The last dimension size will change from `n_hidden` -> `3 x n_hidden`
        qkv = self.qkv_lin(x)

        # Split into heads by changing the shape to `[batch_size, seq_len, n_heads, 3 * d_k]`
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
        # Split into query, key and value each of shape `[batch_size, seq_len, n_heads, d_k]`
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

        # If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
            # Get the state ids. We use them to retrieve previous states and store the next states
            prev_state_id, next_state_id = get_cache().get('state_ids')
            # If there's cache
            if prev_state_id is not None:
                # Get the past keys and values. These will have shape `[batch_size, prev_seq_len, n_heads, d_k]`
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
                # Offset of the current embeddings
                offset = k_past.shape[1]

                # Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)

                # Concatenate the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
                # Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)

            # Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
            # No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)

        # Use flash attention when it is available, the query and key lengths match
        # (i.e. no cached past) and the head size is supported
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)
        # Otherwise, use normal attention
        else:
            output = self.compute_attention(q, k, v)

        # Reshape from `[batch_size, seq_len, n_heads, d_k]` to `[batch_size, seq_len, n_hidden]`
        output = output.reshape(*x.shape)

        # Final linear layer
        return self.output(output)

    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Compute causal attention with FlashAttention.

        :param q: are the queries of shape `[batch_size, seq_len, n_heads, d_k]`
        :param k: are the keys of shape `[batch_size, seq_len, n_heads, d_k]`
        :param v: are the values of shape `[batch_size, seq_len, n_heads, d_k]`
        :raises ValueError: if the head size is larger than `128`
        """
        # Stack them into shape `[batch_size, seq_len, 3, n_heads, d_k]`
        qkv = torch.stack((q, k, v), dim=2)
        d_k = qkv.shape[-1]
        # Pad the head size up to the next of 32, 64 or 128
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)
        # The output is of shape `[batch_size, seq_len, n_heads, d_k + padding]`;
        # drop the padding
        output = output[:, :, :, :d_k]

        return output

    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Compute causal attention with plain matrix operations.

        :param q: are the queries of shape `[batch_size, query_len, n_heads, d_k]`
        :param k: are the keys of shape `[batch_size, key_len, n_heads, d_k]`
        :param v: are the values of shape `[batch_size, key_len, n_heads, d_k]`
        :return: the attention output of shape `[batch_size, query_len, n_heads, d_k]`
        """
        # Disable auto-casting to fp16 for attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
                # Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
                # Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)

            # Scale attention
            attn = attn * self.scale

            # Get causal mask
            mask = self._get_mask(attn)
            # Apply mask
            attn.masked_fill_(mask, self.mask_fill)

            # Attention softmax
            attn = self.softmax(attn)

        # Get attention weighted values
        output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

        return output

Feedforward Network

class FFNLayer(nn.Module):
    """
    ## Feedforward Network

    Expansion linear layer, GELU activation and contraction linear layer.
    """

    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        """
        :param n_hidden: is the embedding size
        :param d_ff: is the size of the hidden (expanded) layer.
            A falsy value (the default `0`) selects `4 * n_hidden`
        """
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4

        # Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
        # GELU activation
        self.activation = nn.GELU()
        # Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: has shape `[batch_size, seq_len, n_hidden]`
        """
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x

Transformer Layer

class TransformerLayer(NeoXModule):
    """
    ## Transformer Layer

    Attention and feedforward network are applied to the same input (each with its
    own layer normalization) and their outputs are added to the residual.
    """

    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        """
        :param n_hidden: is the embedding size
        :param n_heads: is the number of heads
        :param is_flash_attention: specifies whether to use FlashAttention

        *Our implementation doesn't include dropout*.
        """
        super().__init__()

        # Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
        # Layer normalization before FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)

        # Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
        # FFN layer
        self.ffn = FFNLayer(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """

        # Residual connection
        residual = x
        # NeoX runs attention and feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
        # Add them and the residual connection
        return attn + ffn + residual

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load transformer layer'):
            # Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)

            # Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)

            # Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)

            # FFN first transform (expansion, `n_hidden` -> `d_ff`)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)

            # FFN second transform (contraction, `d_ff` -> `n_hidden`)
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)

            # Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)

Final normalization layer

class FinalNorm(NeoXModule):
    """
    ## Final normalization layer

    A layer normalization over the embedding features, with code to load its
    parameters from the checkpoint.
    """

    def __init__(self, n_hidden: int = 6_144):
        """
        :param n_hidden: is the embedding size
        """
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.ln(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)

Readout layer

class ReadoutLayer(NeoXModule):
    """
    ## Readout layer

    A bias-free linear layer that maps the embeddings to one output per vocabulary token.
    """

    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        """
        :param n_hidden: is the embedding size
        :param n_vocab: is the size of the vocabulary
        """
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)

    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[batch_size, seq_len, n_hidden]`
        """
        return self.linear(x)

    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        """
        Code to load the checkpoint
        """
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
503class LayerGenerator:
504    pre_created_layers: Dict[Any, Optional[NeoXModule]]

Generator to create layers

The layers are generated in the same order as checkpoints.

It gives None when a layer is not available; we use the layer indices as NeoX and there are two transformation layers we don't need in our implementation.

  • n_vocab is the number of tokens in the vocabulary
  • n_hidden is the number of features in the embeddings
  • n_layers is the number of transformer layers
  • n_heads is the number of attention heads
  • filter_layers are the set of layers to be used. All layers will be used if None. This is used to test smaller versions of the model with fewer layers
  • is_clone_layers specifies whether to clone the transformer layers (a bit faster)
  • dtype is the data type of the model
  • device is the device of the model
  • is_llm_int8 specifies whether to use int8 quantization
  • llm_int8_threshold is the threshold used to separate outlier features
  • is_flash_attention specifies whether to use FlashAttention
506    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
507                 n_layers: int = 44, n_heads: int = 64,
508                 filter_layers: Optional[Set] = None,
509                 is_clone_layers: bool = True,
510                 dtype: torch.dtype = torch.float,
511                 device: torch.device = torch.device('cpu'),
512                 is_llm_int8: bool = False,
513                 llm_int8_threshold: float = 6.0,
514                 is_flash_attention: bool = False
515                 ):
538        if filter_layers is None:
539            filter_layers = set(range(n_layers + 3))
540
541        self.n_vocab = n_vocab
542        self.n_hidden = n_hidden
543        self.n_layers = n_layers
544        self.n_heads = n_heads
545        self.filter_layers = filter_layers
546        self.is_clone_layers = is_clone_layers
547        self.dtype = dtype
548        self.device = device
549        self.is_llm_int8 = is_llm_int8
550        self.llm_int8_threshold = llm_int8_threshold
551        self.is_flash_attention = is_flash_attention
552
553        self.pre_created_layers = dict(
554            transformer_layer=None,
555        )

Prepares the layer for usage

We move the layer to the device and convert it to the correct data type

  • layer is the layer to prepare
  • Returns the prepared layer

557    def _prepare_layer(self, layer: NeoXModule):
566        return layer.to(self.device, self.dtype)

Layer transformations after loading the checkpoint

This function implements layer transformations after loading the checkpoint.

Currently, it only applies the int8 quantization.

  • layer is the layer to prepare
  • is_llm_int8 specifies whether to use int8 quantization
  • device is the device of the model
  • llm_int8_threshold is the threshold used to separate outlier features
  • Returns the prepared layer

568    @torch.no_grad()
569    def post_load_prepare(self, layer: NeoXModule, *,
570                          is_llm_int8: bool = None,
571                          device: torch.device = None,
572                          llm_int8_threshold: float = None,
573                          ):

Get default values if not specified

591        if is_llm_int8 is None:
592            is_llm_int8 = self.is_llm_int8
593        if device is None:
594            device = self.device
595        if llm_int8_threshold is None:
596            llm_int8_threshold = self.llm_int8_threshold

Skip if not using int8 quantization

599        if not is_llm_int8:
600            return layer

Only convert the linear layers in the transformer layers

603        if not isinstance(layer, TransformerLayer):
604            return layer

Use make_llm_int8_linear defined in utilities.

607        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear

Convert the linear layers

610        with monit.section('Convert to int8'):
611            layer.attention.output = make_llm_int8_linear(layer.attention.output,
612                                                          device=device,
613                                                          threshold=llm_int8_threshold)
614            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
615                                                           device=device,
616                                                           threshold=llm_int8_threshold)
617            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
618                                                        device=device,
619                                                        threshold=llm_int8_threshold)
620            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
621                                                        device=device,
622                                                        threshold=llm_int8_threshold)

624        return layer

Creates and caches a layer

Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.

  • name is the name of the layer
  • creator is the function to create the layer
  • Returns the created layer or a copy of the cached layer

626    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
638        if not self.is_clone_layers:
639            return self._prepare_layer(creator())
640
641        if self.pre_created_layers[name] is None:
642            self.pre_created_layers[name] = self._prepare_layer(creator())
643
644        layer = copy.deepcopy(self.pre_created_layers[name])
645        return layer
647    def _create_transformer_layer(self):
648        return self._create_and_cache_layer(
649            'transformer_layer',
650            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
651        )
653    def _create_embedding_layer(self):
654        return Embedding(self.n_vocab, self.n_hidden)
656    def _create_final_norm_layer(self):
657        return FinalNorm(self.n_hidden)
659    def _create_readout_layer(self):
660        return ReadoutLayer(self.n_hidden, self.n_vocab)

Generator to get layers

662    @torch.no_grad()
663    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:

Embedding layer

668        if 0 in self.filter_layers:
669            with monit.section('Embedding layer'):
670                layer = self._prepare_layer(self._create_embedding_layer())
671            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')

Transformer layers

674        for i in range(self.n_layers):

Transformer layer

676            if i + 1 in self.filter_layers:
677                with monit.section(f'Transformer Layer {i}'):
678                    yield self._create_transformer_layer(), \
679                          (f'layer_{i + 2 :02d}-model_00-model_states.pt',
680                           f'layer_{i + 2 :02d}-model_01-model_states.pt')

Final normalization layer

683        if self.n_layers + 1 in self.filter_layers:
684            with monit.section('Final norm layer'):
685                layer = self._prepare_layer(self._create_final_norm_layer())
686            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')

Readout layer

689        if self.n_layers + 2 in self.filter_layers:
690            with monit.section('Readout layer'):
691                layer = self._prepare_layer(self._create_readout_layer())
692            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')
693
694        for k in self.pre_created_layers.keys():
695            self.pre_created_layers[k] = None

Returns the total number of layers

697    @property
698    def total_layers(self):
702        return self.n_layers + 3

Generator to load layers

704    @torch.no_grad()
705    def load(self) -> Generator[NeoXModule, None, None]:
709        with monit.section("Layers"):
710            for i, (layer, files) in enumerate(self.get_layers()):
711                if files is not None:
712                    layer.load_state(*checkpoint.load_checkpoint_files(files))
713
714                layer = self.post_load_prepare(layer)
715
716                monit.progress(min(0.99, (i + 1) / self.total_layers))
717                yield layer