1import inspect
2import math
3
4import torch
5import torch.nn as nn
6from labml_nn.rwkv.configs import RWKVConfigs
7
8from labml_nn.rwkv import RWKV
9from labml_nn.rwkv import TimeMixing
10from labml import experiment
11from labml.configs import option
12from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Configurations

The `Configs` class below inherits from `NLPAutoRegressionConfigs` and adds the RWKV-specific settings.

class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    Extends `NLPAutoRegressionConfigs` with the RWKV model, its configurations
    and the training/optimizer hyper-parameters declared below.
    """

    # RWKV model
    model: RWKV

    # RWKV configurations
    rwkv: RWKVConfigs

    # Number of warmup iterations
    warmup_iters: int = 2_000

    # Total number of training iterations
    max_iters: int = 600_000

    # Weight decay
    weight_decay: float = 0.1

    # Custom optimizer
    beta1: float = 0.9
    beta2: float = 0.95
    optimizer = 'rwkv_optimizer'

RWKV configurations

@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
    """
    ### RWKV configurations

    Create an `RWKVConfigs` whose source and target vocabulary sizes are both
    set to the number of tokens of the experiment.

    :param c: experiment configurations
    :return: the populated `RWKVConfigs`
    """
    rwkv_conf = RWKVConfigs()

    # The same vocabulary is used for the embeddings and for generating logits
    vocab_size = c.n_tokens
    rwkv_conf.n_src_vocab = vocab_size
    rwkv_conf.n_tgt_vocab = vocab_size

    return rwkv_conf
def _init_weights(module, rwkv: RWKVConfigs):
    """
    ### Initialize the vector parameters of a `TimeMixing` module

    Any module that is not a `TimeMixing` instance is left untouched, so this
    function can be applied to every sub-module of a model.

    For a `TimeMixing` module it replaces `time_decay`, `time_first`,
    `time_mix_key`, `time_mix_value` and `time_mix_receptance` with new
    parameters computed from the module's `layer_id`, `n_layer` and `n_embd`.

    :param module: the module to (possibly) initialize
    :param rwkv: RWKV configurations; not read by this function, kept so the
        signature stays unchanged for existing callers
    """
    if not isinstance(module, TimeMixing):
        return

    layer_id = module.layer_id
    n_layer = module.n_layer
    n_embd = module.n_embd
    attn_sz = n_embd

    with torch.no_grad():
        # Position of this layer in the stack. A single-layer model has no depth
        # range, so use 0 instead of dividing by zero.
        ratio_0_to_1 = layer_id / (n_layer - 1) if n_layer > 1 else 0.0  # 0 to 1
        ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0

        # Channel positions `i / n_embd`, shaped `[1, 1, n_embd]`
        ddd = torch.tensor([i / n_embd for i in range(n_embd)]).view(1, 1, n_embd)

        # Per-channel decay speed, rising from -5 at the first channel to 3 at the last.
        # `max(..., 1)` avoids a division by zero when there is a single channel.
        decay_exponent = 0.7 + 1.3 * ratio_0_to_1
        last_channel = max(attn_sz - 1, 1)
        decay_speed = torch.tensor([-5 + 8 * (h / last_channel) ** decay_exponent for h in range(attn_sz)])
        module.time_decay = nn.Parameter(decay_speed)

        # Repeating offsets `0, 0.5, -0.5, ...` added on top of the constant `log(0.3)`
        zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
        module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

        # Mixing coefficients; key and value share the same power of the channel positions
        key_mix = torch.pow(ddd, ratio_1_to_almost0)
        module.time_mix_key = nn.Parameter(key_mix)
        module.time_mix_value = nn.Parameter(key_mix + 0.3 * ratio_0_to_1)
        module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

Create RWKV model and initialize weights

@option(Configs.model)
def _model(c: Configs):
    """
    ### Create RWKV model and initialize weights

    :param c: experiment configurations
    :return: the initialized `RWKV` model, moved to `c.device`
    """
    rwkv_conf = c.rwkv
    m = RWKV(rwkv_conf)

    # Apply custom weight initialization.
    # `nn.Module.apply` accepts a single one-argument callable, so the RWKV
    # configurations are bound through a closure instead of an extra argument.
    m.apply(lambda module: _init_weights(module, rwkv_conf))

    # Move to the device only after initialization: `_init_weights` replaces
    # parameters with newly created tensors, which would otherwise be left
    # behind on the default device.
    return m.to(c.device)
@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: Configs):
    """
    ### Create the AdamW optimizer

    Registered under the name `'rwkv_optimizer'`, which is the value that
    `Configs.optimizer` selects.

    Parameters with two or more dimensions are weight decayed with
    `c.weight_decay`; all other trainable parameters are not decayed.

    :param c: experiment configurations
    :return: the `torch.optim.AdamW` optimizer
    """
    # Start with all of the candidate parameters and keep those that require grad
    trainable_params = [p for p in c.model.parameters() if p.requires_grad]

    # Create optim groups. Any parameter that is at least 2D will be weight decayed, otherwise not.
    decay_params = [p for p in trainable_params if p.dim() >= 2]
    nodecay_params = [p for p in trainable_params if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

    # Create AdamW optimizer and use the fused version if it is available.
    # The device type is derived from `c.device`, which may be a `torch.device` or a device string.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and torch.device(c.device).type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    # NOTE(review): `learning_rate` is not declared in this file - confirm the base configurations provide it.
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer
def main():
    """
    ### Run the RWKV experiment

    Creates the experiment, applies the configuration overrides, registers the
    model for saving and loading, and runs training.
    """
    # Create experiment
    experiment.create(name="RWKV")
    # Create configs
    conf = Configs()
    # Override configurations.
    # The model is intentionally not read before this call, so that it is first
    # accessed only after the overrides below have been applied.
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of 128
        'seq_len': 128,
        # Train for 32 epochs
        'epochs': 32,
        # Batch size of 128
        'batch_size': 128,
        # Switch between training and validation 10 times per epoch
        'inner_iterations': 10,

        'rwkv.block_size': 1024,
        # Model
        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768,
    })

    print(conf.model)
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()

# Run the experiment only when this file is executed as a script
if __name__ == "__main__":
    main()