Low-Rank Adaptation (LoRA)

This is an implementation of Low-Rank Adaptation (LoRA) in PyTorch.

Low-Rank Adaptation (LoRA) freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the transformer. This makes it possible to efficiently fine-tune large language models, since the number of trainable parameters drops by orders of magnitude.
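
To get a sense of the factor, a rough sketch for a single $d \times k$ weight matrix: a full fine-tune trains $dk$ parameters, while a rank-$r$ decomposition trains only $r(d + k)$. The sizes below are illustrative, not from the paper:

```python
# Illustrative parameter count for one weight matrix (sizes are assumptions, not from the paper).
d, k, r = 4096, 4096, 8

full_params = d * k        # parameters updated by a full fine-tune of this matrix
lora_params = r * (d + k)  # parameters in the rank-r update B @ A

print(full_params, lora_params, full_params / lora_params)
# 16777216 65536 256.0
```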

Here's the training code for fine-tuning a GPT-2 model with LoRA on the Tiny Shakespeare dataset.

import torch
import torch.nn as nn

LoRA Linear Layer

The LoRA linear layer adds a low-rank decomposition to the pre-trained weight matrix ($W_0 \in \mathbb{R}^{d \times k}$) of the linear layer.

$W_0 + \Delta W = W_0 + BA$, where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and the rank $r \ll \min(d, k)$.

All parameters are frozen except $A$ and $B$.

$\Delta W = BA$ is initialized to zero at the beginning of training.

They multiply $\Delta W x$ by $\frac{\alpha}{r}$, where $\alpha$ is a hyper-parameter. Once $\alpha$ is tuned it can be kept the same when varying $r$.
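
Because the update is just $\frac{\alpha}{r} BA$ added to $W_0$, it can be merged into the weights after training, so inference costs nothing extra. A minimal sketch of this equivalence (the sizes and tensors are made up for the example):

```python
import torch

d, k, r, alpha = 16, 32, 4, 8  # illustrative sizes
scaling = alpha / r

w0 = torch.randn(d, k)  # frozen pre-trained weight
a = torch.randn(r, k)   # A
b = torch.randn(d, r)   # B (zero at initialization; random here to stand in for a trained value)
x = torch.randn(5, k)

# Frozen path plus scaled low-rank path, as computed during training
out_lora = x @ w0.T + scaling * (x @ a.T @ b.T)
# Same result with the update folded into the weight: W_0 + (alpha / r) B A
out_merged = x @ (w0 + scaling * b @ a).T

assert torch.allclose(out_lora, out_merged, atol=1e-5)
```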

class Linear(nn.Module):
  • in_features is the number of input features of the linear layer
  • out_features is the number of output features of the linear layer
  • bias is a flag indicating if there is a bias parameter
  • r is the rank of the decomposition
  • alpha is the scaling factor
    def __init__(self, in_features: int, out_features: int, bias: bool,
                 r: int, alpha: int = None):
        super().__init__()

Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.

        if alpha is None:
            alpha = r

The pre-trained weight $W_0$

        self.weight = nn.Parameter(torch.empty((out_features, in_features)))

Freeze it

        self.weight.requires_grad = False

        if bias:

Bias parameter (also frozen)

            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:

No bias parameter

            self.bias = None

Scaling factor $\frac{\alpha}{r}$

        self.scaling = alpha / r

Matrix $A$

        self.lora_a = nn.Parameter(torch.empty((r, in_features)))

Matrix $B$, we keep $A$ and $B$ transposed

        self.lora_b = nn.Parameter(torch.empty((out_features, r)))

        with torch.no_grad():

Initialize $A$ similar to a weight matrix in a normal linear layer

            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)

Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization

            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):

Compute $x W_0^\top + b_0$

        result = nn.functional.linear(x, self.weight, bias=self.bias)

Add $\frac{\alpha}{r} x A^\top B^\top$

        result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

        return result
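
A short usage sketch for the layer above, assuming a pre-trained nn.Linear whose weights we want to adapt (the names pretrained and lora_linear are made up for this example):

```python
import torch
import torch.nn as nn

pretrained = nn.Linear(768, 768, bias=True)  # stands in for a pre-trained layer

lora_linear = Linear(768, 768, bias=True, r=4, alpha=8)
with torch.no_grad():
    lora_linear.weight.copy_(pretrained.weight)  # load the frozen pre-trained weight
    lora_linear.bias.copy_(pretrained.bias)      # load the frozen pre-trained bias

# Only the low-rank matrices are trainable
print([n for n, p in lora_linear.named_parameters() if p.requires_grad])
# ['lora_a', 'lora_b']
```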

LoRA Embedding Layer

Similar to the LoRA linear layer, this adds a low-rank decomposition to the pre-trained embedding weights matrix ($W_0$).

class Embedding(nn.Module):
  • num_embeddings is the number of embeddings
  • embedding_dim is the number of embedding dimensions
  • r is the rank of the decomposition
  • alpha is the scaling factor
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 r: int, alpha: int = None):
        super().__init__()

Set $\alpha = r$ if it is not provided, i.e. make the scaling factor $\frac{\alpha}{r} = 1$.

        if alpha is None:
            alpha = r

The pre-trained embedding weights $W_0$ (frozen)

        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False

Scaling factor $\frac{\alpha}{r}$

        self.scaling = alpha / r

Matrix $A$

        self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))

Matrix $B$

        self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))

        with torch.no_grad():

Initialize $A$ with a normal distribution

            nn.init.normal_(self.lora_a)

Initialize $B$ to $0$ so that $\Delta W = BA$ is $0$ at initialization

            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):

Compute the embeddings

        result = nn.functional.embedding(x, self.weight)

Add $\frac{\alpha}{r} \, \text{onehot}(x) A^\top B^\top$

        result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling

        return result
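
As with the linear layer, the low-rank update can be folded back into the embedding matrix once training is done. A sketch of the merge, assuming an Embedding instance as defined above (the merge code itself is not part of the implementation; the sizes are illustrative):

```python
import torch

emb = Embedding(num_embeddings=50257, embedding_dim=768, r=4, alpha=8)
# ... load pre-trained weights into emb.weight and train lora_a / lora_b ...

with torch.no_grad():
    # Delta W = (B A)^T has shape (num_embeddings, embedding_dim)
    delta_w = (emb.lora_b @ emb.lora_a).T * emb.scaling
    merged_weight = emb.weight + delta_w  # can be used as a plain nn.Embedding weight
```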