1import inspect
2import math
3
4import torch
5import torch.nn as nn
6from labml_nn.rwkv.configs import RWKVConfigs
7
8from labml_nn.rwkv import RWKV
9from labml_nn.rwkv import TimeMixing
10from labml import experiment
11from labml.configs import option
12from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Experiment configurations for training RWKV. This inherits from
    ``NLPAutoRegressionConfigs`` and adds the RWKV model, its configuration
    object, and the hyper-parameters intended for the custom AdamW optimizer.
    """

    # RWKV model
    model: RWKV

    # RWKV configuration object (built by the ``'RWKV'`` option, ``_rwkv_configs``)
    rwkv: RWKVConfigs
    # number of warmup iterations
    # NOTE(review): not read anywhere in this file — confirm it is used
    warmup_iters: int = 2000
    # total number of training iterations
    # NOTE(review): not read anywhere in this file — confirm it is used
    max_iters: int = 600000
    # weight decay applied to the >= 2-D parameter group of the optimizer
    weight_decay: float = 1e-1
    # Custom optimizer: Adam(W) beta coefficients intended for the optimizer option
    beta1: float = 0.9
    beta2: float = 0.95
    # Select the optimizer option registered under the name 'rwkv_optimizer'
    optimizer = 'rwkv_optimizer'
@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
    """
    Build the ``'RWKV'`` option for ``Configs.rwkv``.

    A fresh ``RWKVConfigs`` is created, and both vocabulary sizes on it are
    set to the experiment's ``c.n_tokens``:
    ``n_src_vocab`` for the token embeddings and ``n_tgt_vocab`` for the
    generated logits.
    """
    rwkv_conf = RWKVConfigs()
    # Input embeddings and output logits share the same tokenizer, hence the
    # same vocabulary size.
    rwkv_conf.n_src_vocab = rwkv_conf.n_tgt_vocab = c.n_tokens
    return rwkv_conf
def _init_weights(module, rwkv: RWKVConfigs):
    """
    Custom initialization of the vector parameters of a ``TimeMixing`` layer.

    Meant to be called for every sub-module (e.g. via ``nn.Module.apply``).
    Modules that are not ``TimeMixing`` are left untouched. For a
    ``TimeMixing`` module, the following are replaced with new
    ``nn.Parameter`` tensors:

    * ``time_decay``
    * ``time_first``
    * ``time_mix_key``
    * ``time_mix_value``
    * ``time_mix_receptance``

    Their values depend on the layer depth (``layer_id`` out of ``n_layer``)
    and on the channel index within ``n_embd``.

    :param module: any sub-module visited during initialization.
    :param rwkv: the RWKV configuration; not used by this function, kept for
        call-site compatibility.
    """
    if not isinstance(module, TimeMixing):
        return

    layer_id = module.layer_id
    n_layer = module.n_layer
    n_embd = module.n_embd
    attn_sz = n_embd

    # Put the new parameters on the same device as the module's existing
    # parameters (if it has any), so a model already moved to GPU does not end
    # up with these tensors left on the CPU.
    existing = next(module.parameters(), None)
    device = None if existing is None else existing.device
    dtype = torch.get_default_dtype()

    with torch.no_grad():
        # 0 for the first layer up to 1 for the last; a single-layer model
        # would divide by zero, so its only layer is treated as the first.
        ratio_0_to_1 = layer_id / (n_layer - 1) if n_layer > 1 else 0.0
        # 1 for the first layer, decreasing towards (never reaching) 0
        ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)

        # Channel position i / n_embd, shaped (1, 1, n_embd). Computed in
        # float64 and then cast, like the original per-element Python division.
        ddd = (torch.arange(n_embd, dtype=torch.float64) / n_embd).to(device=device, dtype=dtype).view(1, 1, n_embd)

        # decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1);
        # max(..., 1) keeps a single-channel layer from dividing by zero.
        position = torch.arange(attn_sz, dtype=torch.float64) / max(attn_sz - 1, 1)
        decay_speed = -5 + 8 * position ** (0.7 + 1.3 * ratio_0_to_1)
        module.time_decay = nn.Parameter(decay_speed.to(device=device, dtype=dtype))

        # Repeating offsets 0, 0.5, -0.5 added on top of log(0.3)
        zigzag = ((torch.arange(attn_sz) + 1) % 3 - 1) * 0.5
        time_first = torch.ones(attn_sz) * math.log(0.3) + zigzag
        module.time_first = nn.Parameter(time_first.to(device=device))

        module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
        module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
        module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
@option(Configs.model)
def _model(c: Configs):
    """
    Create the RWKV model and apply the custom weight initialization.

    ``nn.Module.apply`` calls its callback with the sub-module only, so the
    configuration is bound through a closure instead of being passed as an
    extra argument to ``apply``, which raises ``TypeError``.

    Initialization runs before the model is moved to ``c.device``. The
    parameters that ``_init_weights`` replaces are therefore moved together
    with the rest of the model.
    """
    m = RWKV(c.rwkv)
    # Apply custom weight initialization
    rwkv_conf = c.rwkv
    m.apply(lambda module: _init_weights(module, rwkv_conf))

    return m.to(c.device)
@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: NLPAutoRegressionConfigs):
    """
    Create the AdamW optimizer for the RWKV experiment.

    The option is registered under the name ``'rwkv_optimizer'``, which is
    the value ``Configs.optimizer`` selects. Without the explicit name, it
    would be registered under the function's own name.

    Parameter groups:

    * Trainable parameters with two or more dimensions (matmul weights and
      embeddings) get ``c.weight_decay``.
    * All other trainable parameters (biases, norm gains, vectors) get no
      weight decay.

    Reads ``c.model``, ``c.weight_decay``, ``c.beta1``, ``c.beta2``,
    ``c.learning_rate`` and ``c.device``.
    """
    # Only parameters that require gradients are optimized
    params = [p for p in c.model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]
    nodecay_params = [p for p in params if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

    # Use the fused AdamW kernel when this PyTorch build offers it and the
    # model trains on CUDA. The device type is derived from ``c.device``,
    # the same value the model is moved to.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and torch.device(c.device).type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    # NOTE(review): `learning_rate` is not declared on `Configs` in this file —
    # confirm it is provided by NLPAutoRegressionConfigs.
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer
def main():
    """
    Train RWKV on Tiny Shakespeare with a character-level tokenizer.
    """
    # Create experiment
    experiment.create(name="RWKV")
    # Create configs
    conf = Configs()
    # Override configurations. `conf.model` is only read after this call;
    # reading it earlier (as a debug print used to) happens before the
    # `rwkv.*` overrides below are registered.
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of 128
        'seq_len': 128,
        # Train for 32 epochs
        'epochs': 32,
        # Batch size 128
        'batch_size': 128,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        'rwkv.block_size': 1024,
        # model: 12 layers, 12 heads, 768-dimensional embeddings
        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768,
    })

    print(conf.model)
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


if __name__ == '__main__':
    main()