1from typing import List, Dict
2
3import torch
4from torch import nn
5
6from labml_nn.neox.model import TransformerLayer, NeoXModule
class FineTuner:
    """
    Base class for selecting and managing a subset of trainable parameters
    across a list of NeoX layers.

    Subclasses implement :meth:`get_layer_trainable_params` to choose which
    parameters of each layer are trained; :meth:`set_trainable_params` freezes
    everything else.
    """

    def __init__(self, layers: List[NeoXModule]):
        """
        :param layers: the model layers, in order. A layer's index in this list
            is used to build the name prefix of its parameters.
        """
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        """
        Collect the trainable parameters of all layers.

        :return: mapping from ``layer_XX.<name>`` to parameter, where ``XX`` is
            the zero-padded index of the layer in :attr:`layers`
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        """
        Return the parameters of a single layer that should be trained.

        Must be overridden by subclasses.

        :param layer: the layer to inspect
        :param prefix: prefix to prepend to every returned parameter name
        :return: mapping from prefixed parameter name to parameter
        """
        raise NotImplementedError

    def set_trainable_params(self):
        """
        Freeze every parameter of every layer, then re-enable gradients only for
        the parameters returned by :meth:`get_trainable_params`.
        """
        for layer in self.layers:
            # Set `requires_grad` to `False` for the entire layer.
            layer.requires_grad_(False)

        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        """
        Return the trainable parameters as CPU tensors, keyed by the names from
        :meth:`get_trainable_params`.

        Note: ``Tensor.cpu()`` returns the tensor itself when it is already in CPU
        memory. For CPU-resident parameters, the returned tensors therefore share
        storage with the live parameters.
        """
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """
        Copy values from ``state_dict`` into the trainable parameters.

        The keys are validated before any parameter is modified, so a mismatched
        ``state_dict`` leaves the parameters untouched.

        :param state_dict: mapping from parameter name to tensor, as produced by
            :meth:`state_dict`
        :raises KeyError: if a trainable parameter is missing from ``state_dict``,
            or ``state_dict`` contains a name that is not a trainable parameter
        """
        params = self.get_trainable_params()

        # Validate up front with explicit raises. An `assert` would be stripped
        # under `python -O`, and checking after loading leaves the parameters
        # partially overwritten on failure.
        missing = [n for n in params if n not in state_dict]
        if missing:
            raise KeyError(f'Missing trainable parameters in state_dict: {missing}')
        unexpected = [n for n in state_dict if n not in params]
        if unexpected:
            raise KeyError(f'Unexpected keys in state_dict: {unexpected}')

        for n, p in params.items():
            # `copy_` moves data across devices and also works for 0-dim tensors,
            # which slice assignment (`p.data[:] = ...`) does not.
            p.data.copy_(state_dict[n])
class FineTuneBiases(FineTuner):
    """
    Fine-tune only selected bias terms of transformer layers.

    For each ``TransformerLayer``, the biases of the attention output
    projection, the attention ``qkv_lin`` projection and the FFN ``dense_h_h4``
    projection are made trainable. Any other layer type contributes no
    trainable parameters.
    """

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # Only transformer layers have biases to fine-tune.
        if not isinstance(layer, TransformerLayer):
            return {}

        attention = layer.attention
        # Original author's note: no need to train the mlp bias, because it is
        # added with the attention output.
        # NOTE(review): presumably this refers to the FFN output projection's
        # bias, which is omitted here. Confirm against `TransformerLayer`.
        return {
            f'{prefix}.attention.output.bias': attention.output.bias,
            f'{prefix}.attention.qkv_lin.bias': attention.qkv_lin.bias,
            f'{prefix}.ffn.dense_h_h4.bias': layer.ffn.dense_h_h4.bias,
        }