Here is the code for the layers of the GPT-NeoX model and the code to load the 20B checkpoint.
The load_state method in each layer loads the checkpoint weights for that layer. The checkpoint loading helpers are in checkpoint.py.
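For orientation, here is a hedged sketch of the loading flow (the helpers live in checkpoint.py, which isn't shown here): each layer's load_state receives the two model-parallel partitions of that layer's checkpoint and merges them into the layer's parameters. This is what LayerGenerator.load, at the bottom of this file, drives for every layer.

    # Illustrative only: `files` would be a pair of partition files for one layer,
    # e.g. ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
    layer = Embedding()
    p1, p2 = checkpoint.load_checkpoint_files(files)
    layer.load_state(p1, p2)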
import copy
import math
from typing import Dict, Optional, Set, Callable, Any, Generator, Tuple

import torch
from torch import nn
from torch.cuda.amp import autocast

from labml import monit, logger
from labml.logger import Text
from labml_nn.neox import checkpoint
from labml_nn.neox.utils.cache import get_cache
class NeoXModule(nn.Module):
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        pass
class Embedding(NeoXModule):
n_vocab is the size of the vocabulary
n_hidden is the size of the embeddings
    def __init__(self, n_vocab: int = 50_432, n_hidden: int = 6_144):
        super().__init__()

        self.emb = nn.Embedding(n_vocab, n_hidden)
x are the token ids of shape [batch_size, seq_len]
    def forward(self, x: torch.Tensor):
        return self.emb(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load embedding layer'):
            checkpoint.merge_params_dim_0(self.emb.weight, 'word_embeddings.weight', p1, p2)
GPT-NeoX uses rotary positional embeddings (RoPE).
We have an annotated implementation of RoPE here with more notes on the theory.
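As a quick reference, RoPE pairs feature i with feature i + d/2 and rotates each pair by a position-dependent angle, which is exactly what x_rope * cos + rotate_half(x_rope) * sin computes below:

$$
\begin{pmatrix} x^{(i)}_m \\ x^{(i+d/2)}_m \end{pmatrix}
\mapsto
\begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x^{(i)}_m \\ x^{(i+d/2)}_m \end{pmatrix},
\qquad \theta_i = \text{base}^{-2i/d}
$$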
class RoPE(nn.Module):
d_rope is the number of features for RoPE embeddings
base is the base for θ_i, which defaults to 10,000
    def __init__(self, d_rope: int, base: float = 10_000.):
        super().__init__()
To store θ_i for the features
        self.theta = None
Cache cos(mθ_i) and sin(mθ_i)
        self.cos_cached = None
        self.sin_cached = None
Base for θ_i
        self.base = base
Number of features for RoPE
        self.d_rope = d_rope
    @staticmethod
    def rotate_half(x: torch.Tensor):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)
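A quick sanity check (illustrative, not part of the source): rotate_half maps the two halves (x1, x2) of the feature vector to (-x2, x1).

    x = torch.tensor([1., 2., 3., 4.])
    print(RoPE.rotate_half(x))  # tensor([-3., -4.,  1.,  2.])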
x has shape [..., seq, n_heads, d_k]
offset is the starting position of x. This is greater than zero when we have cached the keys and queries of previous positions
    def forward(self, x: torch.Tensor, offset: int = 0):
Get the actual sequence length
        seq_len = x.shape[-3] + offset
Initialize θ_i
        if self.theta is None:
            theta = 1.0 / (self.base ** (torch.arange(0, self.d_rope, 2).float() / self.d_rope))
            self.theta = theta.to(x.device).to(x.dtype)
Initialize the cos(mθ_i) and sin(mθ_i) cache
        if (
                self.cos_cached is None or
                seq_len > self.cos_cached.shape[1] or
                self.cos_cached.device != x.device or
                self.cos_cached.dtype != x.dtype
        ):
Get position indexes
            seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
Calculate mθ_i for every position
            idx_theta = torch.einsum("s,d->sd", seq_idx, self.theta)
Concatenate so that both halves of the rotated features use the same θ values
            idx_theta2 = torch.cat((idx_theta, idx_theta), dim=-1).to(x.device)
Calculate cos(mθ_i) and sin(mθ_i) in fp32
            with autocast(enabled=False):
                idx_theta2 = idx_theta2.float()
Add head dimension
                self.cos_cached = idx_theta2.cos()[:, None, :]
                self.sin_cached = idx_theta2.sin()[:, None, :]
Cache them
            self.cos_cached = self.cos_cached.to(x.dtype)
            self.sin_cached = self.sin_cached.to(x.dtype)
Split the features. We apply RoPE to only the first d_rope features
        x_rope, x_pass = x[..., :self.d_rope], x[..., self.d_rope:]
Get the sin and cos values from the cache
        cos, sin = self.cos_cached[offset: seq_len], self.sin_cached[offset: seq_len]
Apply the rotation: x_rope * cos(mθ) + rotate_half(x_rope) * sin(mθ)
        x_rope = (x_rope * cos) + (self.rotate_half(x_rope) * sin)
Concatenate with features that didn't get RoPE embeddings
        return torch.cat((x_rope, x_pass), dim=-1)
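A minimal usage sketch (illustrative only; small sizes chosen for the example). Only the first d_rope features of each head are rotated; the output shape matches the input.

    rope = RoPE(d_rope=16)
    x = torch.randn(1, 10, 4, 64)   # [batch_size, seq_len, n_heads, d_k]
    y = rope(x)                     # same shape; features 16..63 pass through unchanged
    assert y.shape == x.shape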
class AttentionLayer(nn.Module):
n_hidden is the number of features in the embeddings
n_heads is the number of attention heads
rope_percentage is the percentage of features to add RoPE embeddings to
mask_fill is the masking fill value for the attention matrix
is_flash_attention specifies whether to use FlashAttention
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, rope_percentage: float = 0.25,
                 mask_fill: float = -10_000.0, *, is_flash_attention: bool = False):
        super().__init__()

        self.n_heads = n_heads
        self.mask_fill = mask_fill
Linear layer for query, key and value
        self.qkv_lin = nn.Linear(n_hidden, n_hidden * 3)
Final linear layer
        self.output = nn.Linear(n_hidden, n_hidden)
Number of features per head
        d_k = n_hidden // n_heads
RoPE embedding module
        self.rope = RoPE(int(d_k * rope_percentage))
Attention scaling factor
        self.scale = 1 / math.sqrt(d_k)
To cache causal mask
        self.causal_mask = None
Attention softmax module
        self.softmax = nn.Softmax(dim=-2)
        if is_flash_attention:
            try:
                from flash_attn.flash_attention import FlashAttention
                self.flash_attention = FlashAttention()
            except ImportError:
                logger.log('Install flash attention github.com/HazyResearch/flash-attention. '
                           'Falling back to normal attention', Text.warning)
                self.flash_attention = None
        else:
            self.flash_attention = None
    def _get_mask(self, attn: torch.Tensor):
Query and key lengths
        nq, nk = attn.shape[1:3]
Create mask
        if (
                self.causal_mask is None or
                self.causal_mask.shape[0] != nq or
                self.causal_mask.shape[1] != nk or
                self.causal_mask.device != attn.device
        ):
            self.causal_mask = torch.triu(attn.new_ones([nq, nk], dtype=torch.bool), 1 + nk - nq)
Return from cache
        return self.causal_mask[None, :, :, None]
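For example (illustrative, not from the source), with 2 new queries attending over 4 keys (2 cached + 2 new), the diagonal offset 1 + nk - nq = 3 masks only the future position of the first query:

    torch.triu(torch.ones(2, 4, dtype=torch.bool), diagonal=3)
    # tensor([[False, False, False,  True],
    #         [False, False, False, False]])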
x has shape [batch_size, seq_len, n_hidden]
    def forward(self, x: torch.Tensor):
Get query, key and value embeddings (all concatenated). The last dimension size will change from n_hidden -> 3 x n_hidden
        qkv = self.qkv_lin(x)
Split into heads by changing the shape to [batch_size, seq_len, n_heads, 3 * d_k]
        qkv = qkv.view(*qkv.shape[:-1], self.n_heads, -1)
Split into query, key and value, each of shape [batch_size, seq_len, n_heads, d_k]
        q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)
If we are caching the states of previous tokens
        if get_cache().get('use_cache', False):
Get the state ids. We use prev_state_id to retrieve previous states and next_state_id to store the next states
            prev_state_id, next_state_id = get_cache().get('state_ids')
If there's a cache
            if prev_state_id is not None:
Get the past keys and values. These will have shape [batch_size, prev_seq_len, n_heads, d_k]
                k_past, v_past = get_cache().pop(f'attn_kv_{prev_state_id}')
Offset of the current embeddings
                offset = k_past.shape[1]
Add RoPE embeddings
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)
Concatenate the past
                k = torch.cat([k_past, k], dim=1)
                v = torch.cat([v_past, v], dim=1)
            else:
Add RoPE embeddings
                q = self.rope(q)
                k = self.rope(k)
Save the current state
            get_cache().push(f'attn_kv_{next_state_id}', (k, v))
        else:
No cache - simply add RoPE embeddings
            q = self.rope(q)
            k = self.rope(k)
Use flash attention
        if self.flash_attention is not None and q.shape[1] == k.shape[1] and q.shape[-1] <= 128:
            output = self.compute_flash_attention(q, k, v)
Otherwise, use normal attention
        else:
            output = self.compute_attention(q, k, v)
Reshape from [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_hidden]
        output = output.reshape(*x.shape)
Final linear layer
        return self.output(output)
    def compute_flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
Stack them into shape [batch_size, seq_len, 3, n_heads, d_k]
        qkv = torch.stack((q, k, v), dim=2)
        d_k = qkv.shape[-1]
Pad the head size to the nearest of 32, 64 or 128
        if d_k <= 32:
            pad = 32 - d_k
        elif d_k <= 64:
            pad = 64 - d_k
        elif d_k <= 128:
            pad = 128 - d_k
        else:
            raise ValueError(f'Head size {d_k} too large for flash attention')

        if pad > 0:
            qkv = torch.cat((qkv, qkv.new_zeros(*qkv.shape[:-1], pad)), dim=-1)

        output, _ = self.flash_attention(qkv, causal=True)
The output is of shape [batch_size, seq_len, n_heads, d_k + padding]
        output = output[:, :, :, :d_k]

        return output
    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
Disable auto-casting to fp16 for attention computation
        with autocast(enabled=False):
            if q.dtype == torch.float16:
Convert to fp32 if the current dtype is fp16
                attn = torch.einsum('bihk,bjhk->bijh', q.float(), k.float())
            else:
Do not cast for bfloat
                attn = torch.einsum('bihk,bjhk->bijh', q, k)
Scale attention
            attn = attn * self.scale
Get causal mask
            mask = self._get_mask(attn)
Apply mask
            attn.masked_fill_(mask, self.mask_fill)
Attention softmax
            attn = self.softmax(attn)
Get attention weighted values
            output = torch.einsum('bijh,bjhk->bihk', attn.to(v.dtype), v)

            return output
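A minimal shape check (illustrative; small sizes chosen for the example, and it assumes caching is not enabled in the global cache so the no-cache branch runs):

    attn_layer = AttentionLayer(n_hidden=256, n_heads=4)
    x = torch.randn(2, 10, 256)       # [batch_size, seq_len, n_hidden]
    print(attn_layer(x).shape)        # torch.Size([2, 10, 256])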
class FFNLayer(nn.Module):
n_hidden is the embedding size
d_ff is the size of the hidden layer; it defaults to 4 * n_hidden
    def __init__(self, n_hidden: int = 6_144, d_ff: int = 0):
        super().__init__()

        if not d_ff:
            d_ff = n_hidden * 4
Expansion linear layer
        self.dense_h_h4 = nn.Linear(n_hidden, d_ff)
GELU activation
        self.activation = nn.GELU()
Contraction linear layer
        self.dense_h4_h = nn.Linear(d_ff, n_hidden)
x has shape [batch_size, seq_len, n_hidden]
    def forward(self, x: torch.Tensor):
        x = self.dense_h_h4(x)
        x = self.activation(x)
        x = self.dense_h4_h(x)

        return x
class TransformerLayer(NeoXModule):
n_hidden is the embedding size
n_heads is the number of heads
is_flash_attention specifies whether to use FlashAttention
Our implementation doesn't include dropout.
    def __init__(self, n_hidden: int = 6_144, n_heads: int = 64, *, is_flash_attention: bool = False):
        super().__init__()
Layer normalization before attention
        self.pre_ln_attn = nn.LayerNorm(n_hidden)
Layer normalization before FFN
        self.pre_ln_ffn = nn.LayerNorm(n_hidden)
Attention layer
        self.attention = AttentionLayer(n_hidden, n_heads, is_flash_attention=is_flash_attention)
FFN layer
        self.ffn = FFNLayer(n_hidden)
x are the embeddings of shape [batch_size, seq_len, n_hidden]
    def forward(self, x: torch.Tensor):
Residual connection
        residual = x
NeoX runs attention and feedforward network in parallel
        attn = self.attention(self.pre_ln_attn(x))
        ffn = self.ffn(self.pre_ln_ffn(x))
Add them and the residual connection
        return attn + ffn + residual
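In equation form, this parallel-residual block computes

$$ y = x + \mathrm{Attention}(\mathrm{LN}_1(x)) + \mathrm{FFN}(\mathrm{LN}_2(x)) $$

rather than the usual sequential arrangement where the FFN is applied to the attention output.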
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load transformer layer'):
Attention output transform
            checkpoint.merge_params_sum(self.attention.output.bias, 'attention.dense.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.attention.output.weight, 'attention.dense.weight', p1, p2)
Attention query, key and value transform
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.bias, 'attention.query_key_value.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.attention.qkv_lin.weight, 'attention.query_key_value.weight', p1, p2)
Layer norm before attention
            checkpoint.merge_params_duplicate(self.pre_ln_attn.bias, 'input_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_attn.weight, 'input_layernorm.weight', p1, p2)
FFN first transform (expansion)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.bias, 'mlp.dense_h_to_4h.bias', p1, p2)
            checkpoint.merge_params_dim_0(self.ffn.dense_h_h4.weight, 'mlp.dense_h_to_4h.weight', p1, p2)
FFN second transform (contraction)
            checkpoint.merge_params_sum(self.ffn.dense_h4_h.bias, 'mlp.dense_4h_to_h.bias', p1, p2)
            checkpoint.merge_params_dim_1(self.ffn.dense_h4_h.weight, 'mlp.dense_4h_to_h.weight', p1, p2)
Layer norm before FFN
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.bias, 'post_attention_layernorm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.pre_ln_ffn.weight, 'post_attention_layernorm.weight', p1, p2)
class FinalNorm(NeoXModule):
n_hidden is the embedding size
    def __init__(self, n_hidden: int = 6_144):
        super().__init__()

        self.ln = nn.LayerNorm(n_hidden)
x are the embeddings of shape [batch_size, seq_len, n_hidden]
    def forward(self, x: torch.Tensor):
        return self.ln(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final normalization layer'):
            checkpoint.merge_params_duplicate(self.ln.bias, 'norm.bias', p1, p2)
            checkpoint.merge_params_duplicate(self.ln.weight, 'norm.weight', p1, p2)
Readout layer
class ReadoutLayer(NeoXModule):
n_hidden is the embedding size
n_vocab is the size of the vocabulary
    def __init__(self, n_hidden: int = 6_144, n_vocab: int = 50_432):
        super().__init__()

        self.linear = nn.Linear(n_hidden, n_vocab, bias=False)
x are the embeddings of shape [batch_size, seq_len, n_hidden]
    def forward(self, x: torch.Tensor):
        return self.linear(x)
Code to load the checkpoint
    def load_state(self, p1: Dict[str, torch.Tensor], p2: Dict[str, torch.Tensor]):
        with monit.section('Load final linear layer'):
            checkpoint.merge_params_dim_0(self.linear.weight, 'final_linear.weight', p1, p2)
class LayerGenerator:
    pre_created_layers: Dict[Any, Optional[NeoXModule]]
The layers are generated in the same order as the checkpoints.
It gives None when a layer is not available; we use the same layer indices as NeoX, and there are two transformation layers we don't need in our implementation.
n_vocab is the number of tokens in the vocabulary
n_hidden is the number of features in the embeddings
n_layers is the number of transformer layers
n_heads is the number of attention heads
filter_layers is the set of layers to be used. All layers will be used if it is None. This is used to test smaller versions of the model with fewer layers
is_clone_layers specifies whether to clone the transformer layers (a bit faster)
dtype is the data type of the model
device is the device of the model
is_llm_int8 specifies whether to use int8 quantization
llm_int8_threshold is the threshold used to separate outlier features
is_flash_attention specifies whether to use FlashAttention
    def __init__(self, *, n_vocab: int = 50_432, n_hidden: int = 6_144,
                 n_layers: int = 44, n_heads: int = 64,
                 filter_layers: Optional[Set] = None,
                 is_clone_layers: bool = True,
                 dtype: torch.dtype = torch.float,
                 device: torch.device = torch.device('cpu'),
                 is_llm_int8: bool = False,
                 llm_int8_threshold: float = 6.0,
                 is_flash_attention: bool = False
                 ):
        if filter_layers is None:
            filter_layers = set(range(n_layers + 3))

        self.n_vocab = n_vocab
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.filter_layers = filter_layers
        self.is_clone_layers = is_clone_layers
        self.dtype = dtype
        self.device = device
        self.is_llm_int8 = is_llm_int8
        self.llm_int8_threshold = llm_int8_threshold
        self.is_flash_attention = is_flash_attention

        self.pre_created_layers = dict(
            transformer_layer=None,
        )
We move the layer to the device and convert it to the correct data type.
layer is the layer to prepare
Returns the prepared layer
    def _prepare_layer(self, layer: NeoXModule):
        return layer.to(self.device, self.dtype)
This function implements layer transformations after loading the checkpoint.
Currently, it only applies the int8 quantization.
layer is the layer to prepare
is_llm_int8 specifies whether to use int8 quantization
device is the device of the model
llm_int8_threshold is the threshold used to separate outlier features
Returns the prepared layer
    @torch.no_grad()
    def post_load_prepare(self, layer: NeoXModule, *,
                          is_llm_int8: bool = None,
                          device: torch.device = None,
                          llm_int8_threshold: float = None,
                          ):
Get default values if not specified
        if is_llm_int8 is None:
            is_llm_int8 = self.is_llm_int8
        if device is None:
            device = self.device
        if llm_int8_threshold is None:
            llm_int8_threshold = self.llm_int8_threshold
Skip if not using int8 quantization
        if not is_llm_int8:
            return layer
Only convert the linear layers in the transformer layers
        if not isinstance(layer, TransformerLayer):
            return layer
        from labml_nn.neox.utils.llm_int8 import make_llm_int8_linear
Convert the linear layers
        with monit.section('Convert to int8'):
            layer.attention.output = make_llm_int8_linear(layer.attention.output,
                                                          device=device,
                                                          threshold=llm_int8_threshold)
            layer.attention.qkv_lin = make_llm_int8_linear(layer.attention.qkv_lin,
                                                           device=device,
                                                           threshold=llm_int8_threshold)
            layer.ffn.dense_h_h4 = make_llm_int8_linear(layer.ffn.dense_h_h4,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
            layer.ffn.dense_h4_h = make_llm_int8_linear(layer.ffn.dense_h4_h,
                                                        device=device,
                                                        threshold=llm_int8_threshold)
        return layer
Copying cached layers is faster than initializing new layers because it takes time to initialize parameters.
name is the name of the layer
creator is the function to create the layer
Returns the created layer or a copy of the cached layer
    def _create_and_cache_layer(self, name: str, creator: Callable[[], NeoXModule]):
        if not self.is_clone_layers:
            return self._prepare_layer(creator())

        if self.pre_created_layers[name] is None:
            self.pre_created_layers[name] = self._prepare_layer(creator())

        layer = copy.deepcopy(self.pre_created_layers[name])
        return layer
    def _create_transformer_layer(self):
        return self._create_and_cache_layer(
            'transformer_layer',
            lambda: TransformerLayer(self.n_hidden, self.n_heads, is_flash_attention=self.is_flash_attention)
        )

    def _create_embedding_layer(self):
        return Embedding(self.n_vocab, self.n_hidden)

    def _create_final_norm_layer(self):
        return FinalNorm(self.n_hidden)

    def _create_readout_layer(self):
        return ReadoutLayer(self.n_hidden, self.n_vocab)
    @torch.no_grad()
    def get_layers(self) -> Generator[Tuple[NeoXModule, Tuple[str, str]], None, None]:
Embedding layer
        if 0 in self.filter_layers:
            with monit.section('Embedding layer'):
                layer = self._prepare_layer(self._create_embedding_layer())
            yield layer, ('layer_00-model_00-model_states.pt', 'layer_00-model_01-model_states.pt')
Transformer layers
        for i in range(self.n_layers):
Transformer layer
            if i + 1 in self.filter_layers:
                with monit.section(f'Transformer Layer {i}'):
                    yield self._create_transformer_layer(), \
                        (f'layer_{i + 2 :02d}-model_00-model_states.pt',
                         f'layer_{i + 2 :02d}-model_01-model_states.pt')
Final normalization layer
        if self.n_layers + 1 in self.filter_layers:
            with monit.section('Final norm layer'):
                layer = self._prepare_layer(self._create_final_norm_layer())
            yield layer, ('layer_47-model_00-model_states.pt', 'layer_47-model_01-model_states.pt')
Readout layer
        if self.n_layers + 2 in self.filter_layers:
            with monit.section('Readout layer'):
                layer = self._prepare_layer(self._create_readout_layer())
            yield layer, ('layer_48-model_00-model_states.pt', 'layer_48-model_01-model_states.pt')

        for k in self.pre_created_layers.keys():
            self.pre_created_layers[k] = None
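For instance (illustrative only), with the default n_layers = 44 the layer indices are 0 (embedding), 1 to 44 (transformer layers), 45 (final norm) and 46 (readout), so a small test model with just the first transformer layer could be requested as:

    generator = LayerGenerator(filter_layers={0, 1, 45, 46}, dtype=torch.float16)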
    @property
    def total_layers(self):
Total number of layers (transformer layers plus embedding, final norm and readout)
        return self.n_layers + 3

Generator to create the layers and load the checkpoint into them
    @torch.no_grad()
    def load(self) -> Generator[NeoXModule, None, None]:
        with monit.section("Layers"):
            for i, (layer, files) in enumerate(self.get_layers()):
                if files is not None:
                    layer.load_state(*checkpoint.load_checkpoint_files(files))

                layer = self.post_load_prepare(layer)

                monit.progress(min(0.99, (i + 1) / self.total_layers))
                yield layer
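A minimal end-to-end sketch of how these pieces fit together (illustrative only; it assumes the checkpoint partition files referenced by get_layers are available locally for checkpoint.py to load, and a GPU with enough memory for the model in fp16):

    generator = LayerGenerator(is_clone_layers=True,
                               dtype=torch.float16,
                               device=torch.device('cuda:0'))
    layers = list(generator.load())

    ids = torch.tensor([[1, 2, 3, 4]], device='cuda:0')   # [batch_size, seq_len] token ids
    x = ids
    with torch.no_grad():
        for layer in layers:        # embedding, transformer layers, final norm, readout
            x = layer(x)
    logits = x                      # [batch_size, seq_len, n_vocab]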