GPT-2 with LoRA modules

This module defines the model only; the training code for training a GPT-2 model with LoRA on the Tiny Shakespeare dataset is provided separately.

13import torch
14import torch.nn as nn
15
16from labml_nn.lora import Linear, Embedding

Feedforward Network

class FFN(nn.Module):
    """
    ## Feedforward Network

    Applies `linear_in`, a GELU activation, then `linear_out`.
    Both linear layers are LoRA `Linear` layers of rank `r`.
    """

    def __init__(self, d_model: int, d_ff: int, r: int):
        """
        * `d_model` is the number of dimensions
        * `d_ff` is the size of the hidden dimension
        * `r` is the lora rank
        """
        super().__init__()

        # The linear layers and the activation.
        # Attribute names (and their registration order) are kept as-is because
        # they determine the state-dict keys.
        self.linear_in = Linear(d_model, d_ff, r=r, bias=True)
        self.linear_out = Linear(d_ff, d_model, r=r, bias=True)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        * `x` is the embeddings tensor with shape `[batch_size, seq_len, d_model]`
        """
        hidden = self.act(self.linear_in(x))
        return self.linear_out(hidden)

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """
    ## Multi-Head Attention

    Causal self-attention whose QKV and output projections are LoRA `Linear`
    layers of rank `r`.
    """

    def __init__(self, d_model: int, n_heads: int, r: int):
        """
        * `d_model` is the number of dimensions in the embeddings
        * `n_heads` is the number of heads
        * `r` is the lora rank

        Raises `ValueError` if `n_heads` is not positive or does not divide `d_model`.
        """
        super().__init__()
        # Validate before computing `d_head`: floor division would otherwise hide a
        # misconfiguration until `_split_heads` fails with an opaque `view` error.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Linear transformation for QKV; queries, keys and values are stacked
        # along the last dimension (hence `d_model * 3` outputs)
        self.qkv_projection = Linear(d_model, d_model * 3, r=r, bias=True)
        # Output projection
        self.output_projection = Linear(d_model, d_model, r=r, bias=True)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        * `x` is the tensor with shape `[batch_size, seq_len, d_model]`

        Returns a tensor with shape `[batch_size, n_heads, seq_len, d_head]`.
        """
        # Split last dimension to `[n_heads, d_head]`
        x = x.view(x.shape[:-1] + (self.n_heads, self.d_head))
        # Reorder to `[batch_size, head, seq_length, d_head]`
        return x.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        * `x` is the embeddings tensor with shape `[batch_size, seq_len, d_model]`

        Returns a tensor with shape `[batch_size, seq_len, d_model]`.
        """
        batch_size, seq_length, _ = x.shape

        # Get query, key and value
        q, k, v = self.qkv_projection(x).split(self.d_model, dim=-1)

        # Transform them from shape `[batch_size, seq_len, d_model]`
        # to `[batch_size, head, seq_length, d_head]`
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Apply causal attention (each position attends only to itself and earlier positions)
        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Transform from shape `[batch_size, head, seq_length, d_head]`
        # back to `[batch_size, seq_len, d_model]`
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.d_model)

        # Final projection
        return self.output_projection(attn_output)

Decoder block

class Block(nn.Module):
    """
    ## Decoder block

    Pre-normalization block: an attention sub-layer and a feed-forward sub-layer,
    each applied to a layer-normalized input and added back as a residual.
    """

    def __init__(self, d_model: int, n_heads: int, layer_norm_epsilon: float, r: int):
        """
        * `d_model` is the number of dimensions in the embeddings
        * `n_heads` is the number of heads
        * `layer_norm_epsilon` is the layer norm epsilon
        * `r` is the lora rank
        """
        super().__init__()
        # Attention pre-normalization layer
        self.attn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Attention layer
        self.attn = MultiHeadAttention(d_model, n_heads, r)
        # FFN pre-normalization layer
        self.ffn_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Feed-forward network with a hidden size of four times `d_model`
        self.ffn = FFN(d_model, d_model * 4, r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        * `x` is the embeddings tensor with shape `[batch_size, seq_len, d_model]`
        """
        # Attention sub-layer with residual connection
        attn_delta = self.attn(self.attn_norm(x))
        hidden = x + attn_delta

        # FFN sub-layer with residual connection
        ffn_delta = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_delta

GPT2 Model

class GPTModel(nn.Module):
    """
    ## GPT2 Model

    Token + absolute position embeddings, a stack of decoder `Block`s,
    a final layer norm and a projection to vocabulary logits.
    """

    def __init__(self, *, d_model: int,
                 n_heads: int, n_layers: int,
                 n_positions: int,
                 layer_norm_epsilon: float,
                 vocab_size: int, r: int):
        """
        * `d_model` is the number of dimensions in the embeddings
        * `n_heads` is the number of attention heads
        * `n_layers` is the number of decoder layers
        * `n_positions` is the number of positional embeddings
        * `layer_norm_epsilon` is the layer norm epsilon
        * `vocab_size` is the vocabulary size
        * `r` is the lora rank
        """
        super().__init__()

        # Remembered so `forward` can reject sequences longer than the
        # positional embedding table with a clear error
        self.n_positions = n_positions

        # Token and absolute positional embeddings
        self.token_embedding = Embedding(vocab_size, d_model, r=r)
        self.position_embedding = Embedding(n_positions, d_model, r=r)

        # Decoder blocks
        self.blocks = nn.ModuleList([Block(d_model, n_heads, layer_norm_epsilon, r=r)
                                     for _ in range(n_layers)])

        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        # Projection layer to logit space
        self.lm_head = Linear(d_model, vocab_size, r=r, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        * `input_ids` has shape `[batch_size, seq_len]`

        Returns logits of shape `[batch_size, seq_len, vocab_size]`.
        Raises `ValueError` if `seq_len` exceeds `n_positions`.
        """
        _, seq_len = input_ids.shape
        # Without this check the position-embedding lookup fails with an
        # out-of-range index error (or a device-side assert on GPU)
        if seq_len > self.n_positions:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {self.n_positions} positions")

        # Get token embeddings
        token_embeddings = self.token_embedding(input_ids)
        # Get position ids; the leading singleton dimension broadcasts over the batch
        position_ids = torch.arange(seq_len, device=input_ids.device)[None, :]
        # Get position embeddings
        position_embeddings = self.position_embedding(position_ids)

        # Add position embeddings
        x = token_embeddings + position_embeddings

        # Run through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.final_norm(x)
        # Get logits from projection layer
        return self.lm_head(x)