This is a tutorial/implementation of RWKV, from the paper RWKV: Reinventing RNNs for the Transformer Era, in PyTorch.
Full definition of an RWKV Language Model, all of it in this single file. References: 1) the official RWKV PyTorch implementation released by Bo Peng, and 2) the huggingface/transformers PyTorch implementation.
import torch
import torch.nn as nn
from torch.nn import functional as F

from labml_helpers.module import Module

PREV_X_TIME = 0       # previous input to the time-mixing block
NUM_STATE = 1         # numerator of the WKV recurrence
DEN_STATE = 2         # denominator of the WKV recurrence
MAX_STATE = 3         # running maximum used for numerical stability
PREV_X_CHANNEL = 4    # previous input to the channel-mixing block
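These constants index the third dimension of the recurrent state tensor used for token-by-token inference; as it is indexed throughout this file, that tensor has shape [n_layer, batch_size, 5, n_embd]. A minimal sketch of building a fresh state (the sizes are placeholders for illustration, not project defaults):

n_layer, batch_size, n_embd = 2, 1, 16                 # example sizes
state = torch.zeros(n_layer, batch_size, 5, n_embd)
state[:, :, MAX_STATE, :] = -1e38                      # matches the max_state initialization used in TimeMixing below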
class LayerNorm(Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
L2 loss wrapper: the forward pass returns the training loss unchanged; the backward pass additionally pushes a small gradient onto the largest logit at each position.
class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
To encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return grad_output, gy
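A quick standalone check of what L2Wrap does (illustrative only; the loss and logits below are made up): the loss value is unchanged, but during backward the logits receive an extra gradient of maxx * 1e-4 / (batch * time) at each position's argmax, nudging the largest logit toward 0.

logits = torch.randn(2, 3, 5, requires_grad=True)    # (batch, time, vocab)
loss = logits.square().mean()                         # any scalar loss works for the demo
wrapped = L2Wrap.apply(loss, logits)
wrapped.backward()
# logits.grad now holds d(loss)/d(logits) plus the small extra term
# at each position's argmax, exactly as computed in L2Wrap.backward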
Channel mixing: the RWKV analogue of the feed-forward block, with a token-shift mix of the current and previous inputs.
class ChannelMixing(Module):
    def __init__(self, config, layer_id):
        super().__init__()
Token shifting: each position is given access to the previous token's input (see the small example after this class)
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * n_embd
        )
Learnable matrices
        self.key_proj = nn.Linear(n_embd, intermediate_size, bias=False)
        self.value_proj = nn.Linear(intermediate_size, n_embd, bias=False)
        self.receptance_proj = nn.Linear(n_embd, n_embd, bias=False)
Learnable vectors
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

    def forward(self, x, state=None):
During inference the previous token comes from the state; during training it comes from token shifting
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_CHANNEL], :]
            state[self.layer_id, :, [PREV_X_CHANNEL], :] = x
        else:
            prev_x = self.time_shift(x)
Receptance: mix the current and previous inputs, then project
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
Key: mix the current and previous inputs, then project
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)
Value: squared ReLU of the key, projected back to the embedding size
        value = self.value_proj(torch.square(torch.relu(key)))
Output: sigmoid(receptance) gates the value
        out = F.sigmoid(receptance) * value
        return out, state
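A small sketch of what the token-shift padding does, assuming x has shape (batch, time, channel): nn.ZeroPad2d((0, 0, 1, -1)) adds one zero step at the start of the time dimension and drops the last step, so prev_x[:, t] equals x[:, t - 1], with zeros at t = 0 (standalone illustration, not part of the model classes).

time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(1, 7, dtype=torch.float32).view(1, 3, 2)   # (batch=1, time=3, channel=2)
prev_x = time_shift(x)
# x:      [[1, 2], [3, 4], [5, 6]]
# prev_x: [[0, 0], [1, 2], [3, 4]]   -> each position sees the previous token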
Time mixing: the RWKV analogue of attention, built on the WKV recurrence over keys and values.
class TimeMixing(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.layer_id = layer_id

        n_embd = config.n_embd
        attn_sz = n_embd
Learnable matrices
        self.key_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.value_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.receptance_proj = nn.Linear(n_embd, attn_sz, bias=False)
        self.output_proj = nn.Linear(attn_sz, n_embd, bias=False)
Learnable vectors
        self.time_decay = nn.Parameter(torch.empty(attn_sz))
        self.time_first = nn.Parameter(torch.empty(attn_sz))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, n_embd))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, n_embd))

x has shape (batch, time, channel)
    def forward(self, x, state=None):
During inference the previous token comes from the state; during training it comes from token shifting
        if state is not None:
            prev_x = state[self.layer_id, :, [PREV_X_TIME], :]
            state[self.layer_id, :, [PREV_X_TIME], :] = x
        else:
            prev_x = self.time_shift(x)
Receptance
        receptance = x * self.time_mix_receptance + prev_x * (1 - self.time_mix_receptance)
        receptance = self.receptance_proj(receptance)
Key
        key = x * self.time_mix_key + prev_x * (1 - self.time_mix_key)
        key = self.key_proj(key)
Value
        value = x * self.time_mix_value + prev_x * (1 - self.time_mix_value)
        value = self.value_proj(value)
WKV calculation: a weighted average of past values, computed as a recurrence with a running maximum for numerical stability
        _, seq_length, _ = key.size()
        output = torch.zeros_like(key)

        if state is None:
            num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
            max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
        else:
            num_state = state[self.layer_id, :, NUM_STATE, :]
            den_state = state[self.layer_id, :, DEN_STATE, :]
            max_state = state[self.layer_id, :, MAX_STATE, :]

        time_decay = -torch.exp(self.time_decay)

        for current_index in range(seq_length):
            current_key = key[:, current_index].float()
            current_value = value[:, current_index]
Output for the current position: the current token gets the extra time_first bonus
            max_for_output = torch.maximum(max_state, current_key + self.time_first)
            e1 = torch.exp(max_state - max_for_output)
            e2 = torch.exp(current_key + self.time_first - max_for_output)
            numerator = e1 * num_state + e2 * current_value
            denominator = e1 * den_state + e2
            output[:, current_index] = (numerator / denominator).to(output.dtype)
Update the recurrence for the next iteration: decay the history, then fold in the current token
            max_for_state = torch.maximum(max_state + time_decay, current_key)
            e1 = torch.exp(max_state + time_decay - max_for_state)
            e2 = torch.exp(current_key - max_for_state)
            num_state = e1 * num_state + e2 * current_value
            den_state = e1 * den_state + e2
            max_state = max_for_state
Write the recurrence back into the state for the next call (only when a state is being carried)
        if state is not None:
            state[self.layer_id, :, NUM_STATE, :] = num_state
            state[self.layer_id, :, DEN_STATE, :] = den_state
            state[self.layer_id, :, MAX_STATE, :] = max_state
The full project can optionally dispatch this computation to a customized CUDA kernel (the use_customized_cuda_kernel config flag); the pure PyTorch loop above already gives us the WKV output
        wkv = output
Output gating: sigmoid(receptance) gates the WKV output, followed by the output projection
        rwkv = F.sigmoid(receptance) * wkv
        rwkv = self.output_proj(rwkv)

        return rwkv, state
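To make the loop easier to follow, here is the same WKV recurrence for a single channel, written naively without the running-maximum stabilization (an illustrative sketch; naive_wkv and the values below are not part of the implementation, and in practice these raw exponentials can overflow, which is exactly why the loop above shifts everything by max_state):

def naive_wkv(k, v, u, w):
    # k, v: per-step scalars for one channel; u plays the role of time_first,
    # w plays the role of -exp(time_decay), so exp(w) is a decay in (0, 1)
    num, den = 0.0, 0.0
    out = []
    for t in range(len(k)):
        # the current token gets the one-off time_first bonus u before joining the running sums
        out.append((num + torch.exp(u + k[t]) * v[t]) / (den + torch.exp(u + k[t])))
        # decay the history by exp(w), then fold the current token in
        num = torch.exp(w) * num + torch.exp(k[t]) * v[t]
        den = torch.exp(w) * den + torch.exp(k[t])
    return torch.stack(out)

k = torch.tensor([0.1, -0.3, 0.2])
v = torch.tensor([1.0, 2.0, 3.0])
u = torch.tensor(0.5)                    # stands in for time_first
w = -torch.exp(torch.tensor(0.0))        # stands in for -exp(time_decay)
print(naive_wkv(k, v, u, w))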
An RWKV block: time mixing followed by channel mixing, each preceded by layer norm and wrapped in a residual connection.
class Block(Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = TimeMixing(config, layer_id)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.ffn = ChannelMixing(config, layer_id)

    def forward(self, x, state=None):
state has shape [n_layer, batch_size, 5, n_embd]
Time mixing with a residual connection
        residual = x
        x, state = self.attn(self.ln_1(x), state=state)
        x = x + residual
Channel mixing with a residual connection
        residual = x
        x, state = self.ffn(self.ln_2(x), state=state)
        x = x + residual
        return x, state
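A minimal sketch showing that the parallel (training-style) forward and the token-by-token (inference-style) forward of a block agree. The SimpleNamespace config and the constant fill of the time_* parameters are placeholders for this demo only; the real project configures and initializes these elsewhere.

from types import SimpleNamespace

config = SimpleNamespace(n_embd=8, intermediate_size=None, bias=True)
block = Block(config, layer_id=0)
with torch.no_grad():
    for name, p in block.named_parameters():
        if "time_" in name:
            p.fill_(0.5)                  # placeholder values so the forward pass runs

x = torch.randn(1, 4, 8)                  # (batch, time, n_embd)
with torch.no_grad():
    # whole sequence at once, no state
    y_parallel, _ = block(x)

    # one token at a time, carrying the recurrent state
    state = torch.zeros(1, 1, 5, 8)       # [n_layer, batch, 5, n_embd]
    state[:, :, MAX_STATE, :] = -1e38
    ys = []
    for t in range(4):
        y_t, state = block(x[:, t:t + 1, :], state=state)
        ys.append(y_t)
    y_sequential = torch.cat(ys, dim=1)

print(torch.allclose(y_parallel, y_sequential, atol=1e-5))   # expected: True (up to floating point)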
The full RWKV language model.
class RWKV(Module):
    def __init__(self, config, lr_init=0.0008):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.lr_init = lr_init  # used to initialize embedding parameters
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
Model layers: token embedding, a layer norm, the stack of RWKV blocks, and a final layer norm
        self.rwkv = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            ln_p=LayerNorm(config.n_embd, bias=config.bias),
            h=nn.ModuleList([Block(config, layer_id) for layer_id in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
Output linear layer
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    def forward(self, idx, targets=None, state=None, return_state=False):
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
Embedding layer
        x = self.rwkv.wte(idx)
Layer norm
        x = self.rwkv.ln_p(x)
RWKV blocks
        for block_idx, block in enumerate(self.rwkv.h):
            x, state = block(x, state)
        x = self.rwkv.ln_f(x)
Logit layer and loss function (for training)
        if targets is not None:
If we are given some desired targets, also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            if self.training:
                loss = L2Wrap.apply(loss, logits)
        else:
Inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None
Return logits and loss (and the state, if requested)
        if return_state:
            return logits, loss, state
        else:
            return logits, loss
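Finally, a minimal usage sketch. The config values are placeholders, not the project's defaults, and the constant fill of the time_* parameters only makes the forward pass run; the real project initializes them with a dedicated scheme.

from types import SimpleNamespace

config = SimpleNamespace(vocab_size=100, block_size=32, n_layer=2, n_embd=16,
                         intermediate_size=None, bias=True)
model = RWKV(config)
with torch.no_grad():
    for name, p in model.named_parameters():
        if "time_" in name:
            p.fill_(0.5)                    # placeholder initialization for the demo

# Training-style call: the whole sequence in parallel, loss computed from targets
idx = torch.randint(0, config.vocab_size, (2, 8))
targets = torch.randint(0, config.vocab_size, (2, 8))
logits, loss = model(idx, targets=targets)

# Inference-style call: one token at a time, carrying the recurrent state
state = torch.zeros(config.n_layer, 1, 5, config.n_embd)
state[:, :, MAX_STATE, :] = -1e38
logits, _, state = model(idx[:1, :1], state=state, return_state=True)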