Here's the code for a GPT-2 model with LoRA layers, which is trained on the Tiny Shakespeare dataset.
import torch
import torch.nn as nn

from labml_nn.lora import Linear, Embedding
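The Linear and Embedding modules imported above are LoRA versions of the standard layers: they keep the pretrained weight frozen and learn only a low-rank update. As a rough mental model, a LoRA linear layer looks like the sketch below. This is a minimal illustration, not the actual labml_nn.lora implementation; the scaling alpha / r and the zero-initialized B factor follow the LoRA paper.

class LoRALinearSketch(nn.Module):
    """Illustrative LoRA linear: y = x W^T + b + (alpha / r) * x A^T B^T.
    A sketch for intuition only, not labml_nn.lora.Linear."""

    def __init__(self, in_features: int, out_features: int, r: int, alpha: int = None, bias: bool = True):
        super().__init__()
        # Frozen base weights; in practice these are loaded from the pretrained GPT-2
        self.weight = nn.Parameter(torch.zeros(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None
        # Trainable low-rank factors: A is [r, in_features], B is [out_features, r]
        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.02)
        self.lora_b = nn.Parameter(torch.zeros(out_features, r))  # zero init so the update starts at 0
        self.scaling = (alpha if alpha is not None else r) / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = nn.functional.linear(x, self.weight, self.bias)
        low_rank = nn.functional.linear(nn.functional.linear(x, self.lora_a), self.lora_b)
        return base + self.scaling * low_rank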
class FFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, r: int):
        """
        :param d_model: is the number of dimensions
        :param d_ff: is the size of the hidden dimension
        :param r: is the LoRA rank
        """
        super().__init__()
        # The linear layers and the activation
        self.linear_in = Linear(d_model, d_ff, r=r, bias=True)
        self.linear_out = Linear(d_ff, d_model, r=r, bias=True)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: is the embeddings tensor with shape [batch_size, seq_len, d_model]
        """
        x = self.linear_in(x)
        x = self.act(x)
        x = self.linear_out(x)
        return x
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of heads
        :param r: is the LoRA rank
        """
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Linear transformation for QKV
        self.qkv_projection = Linear(d_model, d_model * 3, r=r, bias=True)
        # Output projection
        self.output_projection = Linear(d_model, d_model, r=r, bias=True)

    def _split_heads(self, x: torch.Tensor):
        """
        :param x: is the tensor with shape [batch_size, seq_len, d_model]
        """
        # Split the last dimension to [n_heads, d_head]
        x = x.view(x.shape[:-1] + (self.n_heads, self.d_head))
        # Reorder to [batch_size, n_heads, seq_len, d_head]
        return x.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: is the embeddings tensor with shape [batch_size, seq_len, d_model]
        """
        batch_size, seq_length, _ = x.shape

        # Get query, key and value
        q, k, v = self.qkv_projection(x).split(self.d_model, dim=-1)
        # Transform them from shape [batch_size, seq_len, d_model]
        # to [batch_size, n_heads, seq_len, d_head]
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Apply causal attention
        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Transform back from shape [batch_size, n_heads, seq_len, d_head]
        # to [batch_size, seq_len, d_model]
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.d_model)

        # Final projection
        return self.output_projection(attn_output)
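As a quick shape check, the attention block maps [batch_size, seq_len, d_model] back to the same shape; heads are only split internally. The snippet below is hypothetical, uses arbitrary small dimensions, and assumes labml_nn is installed for the LoRA Linear layers.

# Hypothetical shape check with arbitrary small dimensions
mha = MultiHeadAttention(d_model=64, n_heads=4, r=8)
x = torch.randn(2, 10, 64)                        # [batch_size=2, seq_len=10, d_model=64]
q, k, v = mha.qkv_projection(x).split(64, dim=-1)
print(mha._split_heads(q).shape)                  # torch.Size([2, 4, 10, 16]) -> [batch, heads, seq, d_head]
print(mha(x).shape)                               # torch.Size([2, 10, 64])    -> same shape as the input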
class Block(nn.Module):
    def __init__(self, d_model: int, n_heads: int, layer_norm_epsilon: float, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of heads
        :param layer_norm_epsilon: is the layer norm epsilon
        :param r: is the LoRA rank
        """
        super().__init__()
        # Attention pre-normalization layer
        self.attn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Attention layer
        self.attn = MultiHeadAttention(d_model, n_heads, r)
        # FFN pre-normalization layer
        self.ffn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Feed-forward network
        self.ffn = FFN(d_model, d_model * 4, r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: is the embeddings tensor with shape [batch_size, seq_len, d_model]
        """
        # Attention
        x = x + self.attn(self.attn_norm(x))
        # FFN
        x = x + self.ffn(self.ffn_norm(x))

        return x
class GPTModel(nn.Module):
    def __init__(self, *, d_model: int,
                 n_heads: int, n_layers: int,
                 n_positions: int,
                 layer_norm_epsilon: float,
                 vocab_size: int, r: int):
        """
        :param d_model: is the number of dimensions in the embeddings
        :param n_heads: is the number of attention heads
        :param n_layers: is the number of decoder layers
        :param n_positions: is the number of positional embeddings
        :param layer_norm_epsilon: is the layer norm epsilon
        :param vocab_size: is the vocabulary size
        :param r: is the LoRA rank
        """
        super().__init__()

        # Token and absolute positional embeddings
        self.token_embedding = Embedding(vocab_size, d_model, r=r)
        self.position_embedding = Embedding(n_positions, d_model, r=r)

        # Decoder blocks
        self.blocks = nn.ModuleList([Block(d_model, n_heads, layer_norm_epsilon, r=r)
                                     for _ in range(n_layers)])

        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Projection layer to logit space
        self.lm_head = Linear(d_model, vocab_size, r=r, bias=False)

    def forward(self, input_ids: torch.Tensor):
        """
        :param input_ids: has shape [batch_size, seq_len]
        """
        batch_size, seq_len = input_ids.shape

        # Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
        # Get position ids
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        # Get position embeddings
        position_embeddings = self.position_embedding(position_ids)

        # Add position embeddings
        x = token_embeddings + position_embeddings

        # Run through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.final_norm(x)
        # Get logits from projection layer
        return self.lm_head(x)
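Putting it together, here is a minimal usage sketch. The hyperparameters below are assumptions for illustration (GPT-2 small sizes and an arbitrary LoRA rank); the training code would load the pretrained GPT-2 weights into the frozen base parameters and update only the LoRA parameters.

# Minimal usage sketch; hyperparameters are assumed (GPT-2 small sizes, arbitrary LoRA rank)
model = GPTModel(d_model=768, n_heads=12, n_layers=12,
                 n_positions=1024, layer_norm_epsilon=1e-5,
                 vocab_size=50257, r=32)

input_ids = torch.randint(0, 50257, (1, 128))     # [batch_size=1, seq_len=128]
logits = model(input_ids)                         # [1, 128, 50257]

# Only the LoRA parameters are trainable; the base weights stay frozen,
# so this lists what fine-tuning would actually update
trainable = [name for name, p in model.named_parameters() if p.requires_grad]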