Receptance Weighted Key Value (RWKV)

This is a tutorial/implementation, in PyTorch, of RWKV from the paper RWKV: Reinventing RNNs for the Transformer Era.

Full definition of an RWKV language model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng, and 2) the huggingface/transformers PyTorch implementation.

import torch
import torch.nn as nn
from torch.nn import functional as F

from labml_helpers.module import Module

PREV_X_TIME = 0
NUM_STATE = 1
DEN_STATE = 2
MAX_STATE = 3
PREV_X_CHANNEL = 4
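These five constants index the slots of a per-layer recurrent state tensor that is threaded through the blocks during inference. The allocation itself is not part of this file; a minimal sketch consistent with the indexing used below (shape (n_layer, batch, 5, n_embd), with the max slot starting at a very large negative value) could look like this; the helper name is ours:

# Hypothetical helper (not in the original file): allocate the recurrent state
# consumed by TimeMixing and ChannelMixing. Shape and initial values are
# inferred from how the state is indexed in the classes below.
def init_state(n_layer, batch_size, n_embd, device="cpu"):
    state = torch.zeros(n_layer, batch_size, 5, n_embd, device=device)
    # The running maximum used for numerical stability starts at "-infinity"
    state[:, :, MAX_STATE, :] = -1e38
    return state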

Layer normalization with an optional bias

class LayerNorm(Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

L2 loss wrapper


class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]

Add a small gradient that encourages the logits to be close to 0

        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
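One way to read this wrapper (our interpretation, not stated in the file): the forward pass returns the loss unchanged, while the backward pass adds a gradient gy that is non-zero only at the arg-max logit of each position and equals factor * max(y). That is exactly the gradient of an auxiliary penalty

$$\mathcal{L}_{\text{aux}} = \frac{10^{-4}}{2\,B\,T} \sum_{b,t} \Big(\max_{c} y_{b,t,c}\Big)^2$$

so during training the largest logit at every position is softly pulled towards zero without changing the reported loss value.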

Channel Mixing

class ChannelMixing(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

Token shifting: pad the time dimension with one zero row at the start and drop the last position, so position t sees the features of token t-1

        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )

Learnable matrices

        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)

Learnable vectors

        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

x = (Batch,Time,Channel)

    def forward(self, x, state=None):
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)

        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        value = self.value_proj(torch.square(torch.relu(key)))

        out = F.sigmoid(receptance) * value
        return out, state
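In equation form, with $\mu$ the learnable time-mix vectors, $W$ the projection matrices and $\sigma$ the sigmoid, this forward pass implements the channel-mixing equations of the paper:

$$r_t = W_r\,\big(\mu_r \odot x_t + (1-\mu_r) \odot x_{t-1}\big)$$
$$k_t = W_k\,\big(\mu_k \odot x_t + (1-\mu_k) \odot x_{t-1}\big)$$
$$o_t = \sigma(r_t) \odot \big(W_v\, \max(k_t, 0)^2\big)$$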

Time Mixing

class TimeMixing(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd

Learnable matrices

        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)

Learnable vectors

        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

x = (Batch,Time,Channel)

    def forward(self, x, state=None):
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)

        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)

        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)

        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)

WKV calculation
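The loop below evaluates the WKV operator of the paper in a numerically stable way. Writing $w = \exp(\texttt{self.time\_decay})$ for the per-channel decay rate (the code stores $-w$ in the local variable time_decay) and $u = \texttt{time\_first}$ for the bonus given to the current token, position $t$ receives

$$wkv_t = \frac{\sum_{i=1}^{t-1} e^{k_i - (t-1-i)\,w}\, v_i \;+\; e^{k_t + u}\, v_t}{\sum_{i=1}^{t-1} e^{k_i - (t-1-i)\,w} \;+\; e^{k_t + u}}$$

i.e. a weighted average of past values whose weights decay exponentially with distance, plus a separately weighted current token. The block output computed after the loop is then $W_o\big(\sigma(r_t) \odot wkv_t\big)$.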

        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]

            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)

Update state for next iteration

            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state
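Accumulating those sums of exponentials directly would overflow, so the loop keeps them in shifted form, the same idea as the log-sum-exp trick: max_state holds a running maximum $p$ of the exponents, and the stored quantities are

$$\texttt{num\_state} = A\,e^{-p}, \qquad \texttt{den\_state} = B\,e^{-p},$$

where $A$ and $B$ are the true numerator and denominator sums of the WKV ratio. Each update (one for the output above, one for the next-step state here) first re-aligns the old state and the new term to a common maximum via e1 and e2 before adding them.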

Save the recurrent state for the next call (only when a state tensor was passed in)

        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state

        # The loop above has already produced the WKV values
        wkv = output

        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state

RWKV block

class Block(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

    def forward(self, x, state=None):

time mixing

        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual

channel mixing

        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual
        return x, state
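So a block applies the familiar pre-norm residual structure, with time mixing in place of self-attention and channel mixing in place of the feed-forward network:

$$x \leftarrow x + \text{TimeMixing}(\mathrm{LN}_1(x)), \qquad x \leftarrow x + \text{ChannelMixing}(\mathrm{LN}_2(x))$$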

RWKV

class RWKV(Module):
    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd

Initialize model layers

        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))

Output linear layer

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

Embedding Layer

        x = self.rwkv.wte(idx)

Layer Norm

        x = self.rwkv.ln_p(x)

RWKV Blocks

        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)

Logit layer and loss function (for training)

        if targets is not None:

If we are given some desired targets, also calculate the loss.

            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:

inference-time mini-optimization: only forward the lm_head on the very last position

            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

Return Logits and loss

        if return_state:
            return logits, loss, state
        else:
            return logits, loss
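Finally, a minimal usage sketch. The config class, its field names and the parameter fills below are illustrative assumptions (this file does not define a config, and a real model would load trained weights); they only mirror how config and the state tensor are used above.

from dataclasses import dataclass
from typing import Optional


@dataclass
class RWKVConfig:
    # Hypothetical config; field names are inferred from how `config` is used above
    vocab_size: int = 256
    block_size: int = 128
    n_layer: int = 2
    n_embd: int = 64
    intermediate_size: Optional[int] = None  # ChannelMixing falls back to 4 * n_embd
    bias: bool = True


config = RWKVConfig()
model = RWKV(config)

# The time-mix/decay parameters are created with torch.empty, so fill them with
# finite values for a smoke test (a trained checkpoint would provide real ones).
with torch.no_grad():
    for name, p in model.named_parameters():
        if "time_mix" in name:
            p.fill_(0.5)
        elif "time_decay" in name or "time_first" in name:
            p.zero_()

# Parallel (training-style) forward over a whole sequence
idx = torch.randint(0, config.vocab_size, (1, 16))
logits, loss = model(idx, targets=idx)

# Recurrent (inference-style) forward, one token at a time with an explicit state
model.eval()
state = torch.zeros(config.n_layer, 1, 5, config.n_embd)
state[:, :, MAX_STATE, :] = -1e38
token = idx[:, :1]
with torch.no_grad():
    for _ in range(8):
        logits, _, state = model(token, state=state, return_state=True)
        token = logits[:, -1].argmax(dim=-1, keepdim=True)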