from typing import List, Dict

import torch
from torch import nn

from labml_nn.neox.model import TransformerLayer, NeoXModule


class FineTuner:
    """Fine-tune only a selected subset of the model's parameters."""

    def __init__(self, layers: List[NeoXModule]):
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        # Collect trainable parameters from every layer, keyed by a per-layer prefix
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # Subclasses select the trainable parameters of a single layer
        raise NotImplementedError

    def set_trainable_params(self):
        # Set `requires_grad` to `False` for every layer, freezing the entire model
        for layer in self.layers:
            layer.requires_grad_(False)

        # Then turn gradients back on only for the trainable parameters
        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        # Save only the trainable parameters, moved to CPU
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        # Copy saved values into the corresponding trainable parameters
        params = self.get_trainable_params()
        for n, p in params.items():
            p.data[:] = state_dict[n].to(p.data.device)

        # Make sure the checkpoint does not contain unexpected parameters
        for n in state_dict.keys():
            assert n in params, n


class FineTuneBiases(FineTuner):
    """Fine-tune only the biases of the transformer layers."""

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        params = {}

        if isinstance(layer, TransformerLayer):
            # No need to train the MLP output bias because it is added together with the attention output
            params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
            params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
            params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
        else:
            # Other module types have no biases selected for fine-tuning
            pass

        return params
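

# A minimal usage sketch, not part of the original module: it assumes `layers` is a
# list of loaded `NeoXModule` layers, and the function name, optimizer, and learning
# rate below are illustrative choices rather than the repository's training setup.
def _finetune_biases_example(layers: List[NeoXModule]):
    # Select only the bias parameters of the transformer layers
    fine_tuner = FineTuneBiases(layers)
    # Freeze the whole model, then re-enable gradients for the selected biases
    fine_tuner.set_trainable_params()
    # Optimize just the trainable parameters
    optimizer = torch.optim.Adam(fine_tuner.get_trainable_params().values(), lr=1e-4)
    # The fine-tuned biases can be saved and restored independently of the full model
    checkpoint = fine_tuner.state_dict()
    fine_tuner.load_state_dict(checkpoint)
    return fine_tuner, optimizer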