This is an implementation of Low-Rank Adaptation (LoRA) in PyTorch.
Low-Rank Adaptation (LoRA) freezes pre-trained model weights and injects trainable rank decomposition matrices into each layer of the transformer. This makes it possible to efficiently fine-tune large language models by reducing the number of trainable parameters by a large factor.
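As a rough, hypothetical illustration of the savings (the sizes below are made up and not taken from the training code), a single $d \times k$ linear layer fine-tuned with a rank-$r$ update trains only $r(d + k)$ parameters instead of $dk$:

# Hypothetical sizes for one linear layer; only the ratio matters here.
d, k, r = 768, 768, 4
full_params = d * k            # fine-tuning the full weight matrix
lora_params = r * k + d * r    # A is r x k, B is d x r
print(full_params, lora_params, full_params // lora_params)  # 589824 6144 96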
Here's the training code for fine-tuning a GPT2 model with LoRA on the Tiny Shakespeare dataset.
import torch
import torch.nn as nn
The LoRA linear layer adds a low-rank decomposition to the pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ of the linear layer.

$W_0 + \Delta W = W_0 + BA$

where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$.

All parameters are frozen except $A$ and $B$.

$\Delta W$ is initialized to zero at the beginning of training.

They multiply $\Delta W x$ by $\frac{\alpha}{r}$, where $\alpha$ is a hyper-parameter. Once $\alpha$ is tuned it can be kept the same when varying $r$.
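A minimal sketch (the tensors here are illustrative, not part of the implementation below) of why this is equivalent to a linear layer whose weight is $W_0 + \frac{\alpha}{r} BA$:

# Illustrative check: x W0^T + (alpha/r) x A^T B^T == x (W0 + (alpha/r) BA)^T
d, k, r, alpha = 8, 16, 2, 2
w0 = torch.randn(d, k)                     # frozen pre-trained weight
a = torch.randn(r, k)                      # trainable matrix A
b = torch.randn(d, r)                      # trainable matrix B
x = torch.randn(3, k)                      # a batch of inputs
adapted = x @ w0.T + (x @ a.T @ b.T) * (alpha / r)
merged = x @ (w0 + (alpha / r) * (b @ a)).T
assert torch.allclose(adapted, merged, atol=1e-5)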
class Linear(nn.Module):
in_features is the number of input features of the linear layer
out_features is the number of output features of the linear layer
bias is a flag indicating if there is a bias parameter
r is the rank of the decomposition
alpha is the scaling factor

def __init__(self, in_features: int, out_features: int, bias: bool,
             r: int, alpha: int = None):
super().__init__()
Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
if alpha is None:
    alpha = r
The pre-trained weight $W_0$
self.weight = nn.Parameter(torch.empty((out_features, in_features)))
Freeze it
self.weight.requires_grad = False
if bias:
Bias parameter (also frozen)
    self.bias = nn.Parameter(torch.empty(out_features))
    self.bias.requires_grad = False
else:
No bias parameter
    self.bias = None
Scaling factor $\frac{\alpha}{r}$
self.scaling = alpha / r
Matrix $A \in \mathbb{R}^{r \times k}$
self.lora_a = nn.Parameter(torch.empty((r, in_features)))
Matrix $B \in \mathbb{R}^{d \times r}$; we keep $A$ and $B$ transposed
self.lora_b = nn.Parameter(torch.empty((out_features, r)))

with torch.no_grad():
Initialize $A$ similarly to a weight matrix in a normal linear layer
    nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization
    nn.init.zeros_(self.lora_b)
def forward(self, x: torch.Tensor):
Compute $x W_0^\top + b_0$
result = nn.functional.linear(x, self.weight, bias=self.bias)
Add $\frac{\alpha}{r} x A^\top B^\top = \frac{\alpha}{r} x (\Delta W)^\top$
result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

return result
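A small usage sketch (not part of the original code; the pre-trained layer here is just a stand-in): because lora_b starts at zero, the wrapped layer reproduces the frozen pre-trained output before any training.

# Hypothetical example: load a "pre-trained" nn.Linear into the LoRA layer above.
base = nn.Linear(16, 8)
lora_linear = Linear(16, 8, bias=True, r=4, alpha=8)
with torch.no_grad():
    lora_linear.weight.copy_(base.weight)   # frozen W0
    lora_linear.bias.copy_(base.bias)       # frozen bias
x = torch.randn(2, 16)
assert torch.allclose(lora_linear(x), base(x))  # lora_b == 0, so outputs match
# Only the low-rank matrices are updated during fine-tuning.
trainable = [n for n, p in lora_linear.named_parameters() if p.requires_grad]
assert trainable == ['lora_a', 'lora_b']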
Similar to the LoRA linear layer, this adds a low-rank decomposition to the pre-trained embedding weights matrix $W_0$.
class Embedding(nn.Module):
num_embeddings is the number of embeddings
embedding_dim is the number of embedding dimensions
r is the rank of the decomposition
alpha is the scaling factor

def __init__(self, num_embeddings: int, embedding_dim: int,
             r: int, alpha: int = None):
super().__init__()
Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.
if alpha is None:
    alpha = r
The pre-trained embedding weights (frozen)
self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
self.weight.requires_grad = False
Scaling factor $\frac{\alpha}{r}$
self.scaling = alpha / r
Matrix $A$
self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
Matrix $B$
self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

with torch.no_grad():
Initialize $A$ with a normal distribution
    nn.init.normal_(self.lora_a)
Initialize $B$ to $0$ so that the update $\Delta W$ is $0$ at initialization
    nn.init.zeros_(self.lora_b)
def forward(self, x: torch.Tensor):
Compute the embeddings
result = nn.functional.embedding(x, self.weight)
Add the scaled update $\frac{\alpha}{r} A^\top B^\top$, looked up for the same token indices
result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

return result
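And a similar sketch for the embedding layer (again illustrative; the pre-trained embedding is a stand-in): the zero-initialized lora_b means the embeddings are unchanged at initialization.

# Hypothetical example: load "pre-trained" embeddings into the LoRA embedding above.
base = nn.Embedding(100, 32)
lora_emb = Embedding(100, 32, r=4, alpha=8)
with torch.no_grad():
    lora_emb.weight.copy_(base.weight)      # frozen pre-trained embeddings
ids = torch.randint(0, 100, (2, 5))
assert torch.allclose(lora_emb(ids), base(ids))  # lora_b == 0, so outputs match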