This is a PyTorch implementation of the paper Linear Transformers Are Secretly Fast Weight Memory Systems. The paper finds similarities between linear self-attention and fast weight systems and modifies the self-attention update rule based on that. It also introduces a simpler, yet effective kernel function.
The authors have provided an official implementation of the paper, including the other variants they compare against.
Consider a sequence of inputs $x^{(1)}, \dots, x^{(L)}$ of length $L$, where each step is a vector of size $d_{in}$; i.e. $x^{(i)} \in \mathbb{R}^{d_{in}}$. The fast weight model generates a weight matrix at each step to produce outputs $y^{(1)}, \dots, y^{(L)}$, $y^{(i)} \in \mathbb{R}^{d_{out}}$:

\begin{align}
a^{(i)}, b^{(i)} &= W_a x^{(i)}, W_b x^{(i)} \\
W^{(i)} &= \sigma \Big( W^{(i-1)} + a^{(i)} \otimes b^{(i)} \Big) \\
y^{(i)} &= W^{(i)} x^{(i)}
\end{align}
$\otimes$ is the outer product ($a \otimes b = a b^\top$), where elements of the two vectors are multiplied with each other to give a matrix. $\sigma$ is an activation function. $W_a$ and $W_b$ are trainable weights (parameters). $W^{(i)}$ are the fast weights that are generated at each step.
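To make the recurrence concrete, here is a minimal sketch of the fast weight loop in plain PyTorch; the sizes and the choice of $\tanh$ for $\sigma$ are illustrative assumptions, not taken from the paper.

import torch
from torch import nn

# Illustrative sizes (assumptions for this sketch)
d_in, d_out, seq_len = 4, 3, 5

# Slow (trainable) weights W_a and W_b
w_a = nn.Linear(d_in, d_out, bias=False)
w_b = nn.Linear(d_in, d_in, bias=False)

x = torch.randn(seq_len, d_in)
# Fast weights W^(0) start at zero
weights = torch.zeros(d_out, d_in)

outputs = []
for i in range(seq_len):
    a, b = w_a(x[i]), w_b(x[i])
    # W^(i) = sigma(W^(i-1) + a ⊗ b); tanh stands in for sigma here
    weights = torch.tanh(weights + torch.outer(a, b))
    # y^(i) = W^(i) x^(i)
    outputs.append(weights @ x[i])

y = torch.stack(outputs)  # shape: (seq_len, d_out)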
The original transformer self-attention is (omitting the $\frac{1}{\sqrt{d_k}}$ scaling for clarity):

\begin{align}
y^{(i)} &= \Big[v^{(1)}, v^{(2)}, \dots, v^{(i)}\Big] \text{softmax} \bigg( \Big[k^{(1)}, k^{(2)}, \dots, k^{(i)}\Big]^\top q^{(i)} \bigg) \\
&= \sum^i_{j=1} \frac{v^{(j)} \kappa\big(k^{(j)}, q^{(i)}\big)}{\sum^i_{j'=1} \kappa\big(k^{(j')}, q^{(i)}\big)}
\end{align}

where $\kappa(k, q) = \exp(k \cdot q)$ is the softmax kernel.
The idea behind linearizing self-attention is to replace the softmax kernel $\kappa$ with a different kernel $\kappa'$ so that we can calculate the denominator of the self-attention function faster:

$$\kappa'(k, q) = \phi(k)^\top \phi(q)$$
This gives

$$y^{(i)} = \frac{\Big( \sum^i_{j=1} v^{(j)} \otimes \phi\big(k^{(j)}\big) \Big) \phi\big(q^{(i)}\big)}{\Big( \sum^i_{j'=1} \phi\big(k^{(j')}\big) \Big) \cdot \phi\big(q^{(i)}\big)}$$
With $W^{(i)} = \sum^i_{j=1} v^{(j)} \otimes \phi\big(k^{(j)}\big)$ and $z^{(i)} = \sum^i_{j=1} \phi\big(k^{(j)}\big)$, we can calculate them efficiently:

\begin{align}
W^{(i)} &= W^{(i-1)} + v^{(i)} \otimes \phi\big(k^{(i)}\big) \\
z^{(i)} &= z^{(i-1)} + \phi\big(k^{(i)}\big) \\
y^{(i)} &= \frac{1}{z^{(i)} \cdot \phi\big(q^{(i)}\big)} W^{(i)} \phi\big(q^{(i)}\big)
\end{align}
This is quite similar to fast weights.
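As a worked example of this incremental form, here is a minimal single-head sketch; the $\text{ELU} + 1$ feature map below is only a stand-in assumption for $\phi$ (the paper's DPFP projection is implemented later in this file).

import torch
import torch.nn.functional as F

def phi(x):
    # A simple positive feature map used as a placeholder for the kernel projection
    return F.elu(x) + 1

d_k, d_v, seq_len = 4, 4, 6
q, k, v = torch.randn(seq_len, d_k), torch.randn(seq_len, d_k), torch.randn(seq_len, d_v)

weights = torch.zeros(d_v, d_k)  # W^(0)
z = torch.zeros(d_k)             # z^(0)
ys = []
for i in range(seq_len):
    # W^(i) = W^(i-1) + v^(i) ⊗ phi(k^(i))
    weights = weights + torch.outer(v[i], phi(k[i]))
    # z^(i) = z^(i-1) + phi(k^(i))
    z = z + phi(k[i])
    # y^(i) = W^(i) phi(q^(i)) / (z^(i) · phi(q^(i)))
    ys.append(weights @ phi(q[i]) / (z @ phi(q[i])))
y = torch.stack(ys)  # shape: (seq_len, d_v)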
The paper introduces a new linear attention projection function $\phi$, a new update rule for $W^{(i)} = f\big(W^{(i-1)}\big)$, and changes the normalization $\frac{1}{z^{(i)} \cdot \phi(q^{(i)})}$.
Here are the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
This is the new projection function $\phi$ introduced in the paper: deterministic parameter-free projection (DPFP). DPFP projects $k$ of dimensionality $d_{key}$ to dimensionality $d_{dot} = 2 d_{key} \nu$, where $\nu$ is a hyper-parameter.

$$\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_{j} \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$$

where $\big[k, -k\big]$ is the concatenation of $k$ and $-k$ to give a vector of size $2 d_{key}$, $i \in \{1, 2, \dots, \nu\}$, and $j \in \{1, 2, \dots, 2 d_{key}\}$. $x_i$ is the $i$-th element of vector $x$ and is rolled around if $i$ is larger than the number of elements in $x$.

Basically, it creates a new vector by multiplying elements of $\text{ReLU}\big(\big[k, -k\big]\big)$ with the same vector shifted by $i$.

This produces projections that are sparse (only a few elements of $\phi(k)$ are non-zero) and orthogonal ($\phi\big(k^{(i)}\big) \cdot \phi\big(k^{(j)}\big) \approx 0$ for most $i, j$ unless $k^{(i)}$ and $k^{(j)}$ are very similar).
The paper also introduces a simple normalization for $\phi$,

$$\phi'(k) = \frac{\phi(k)}{\sum^{d_{dot}}_{j=1} \phi(k)_j}$$

Check the paper for the derivation.
class DPFP(Module):
nu is the hyper-parameter $\nu$. eps is the small value used to make sure there is no division-by-zero when normalizing.
    def __init__(self, nu: int = 1, eps: float = 1e-6):
        super().__init__()
        self.nu = nu
        self.relu = nn.ReLU()
        self.eps = eps
    def forward(self, k: torch.Tensor):
Get $\phi(k)$
        k = self.dpfp(k)
Normalize by $\sum^{d_{dot}}_{j=1} \phi(k)_j$ to get $\phi'(k)$
        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)
    def dpfp(self, k: torch.Tensor):
Get $x = \text{ReLU}\Big(\big[k, -k\big]\Big)$
        x = self.relu(torch.cat([k, -k], dim=-1))
Shift and roll by $i \in \{1, 2, \dots, \nu\}$ to get the $\text{ReLU}\big(\big[k, -k\big]\big)_{i + j}$ terms
        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]
Concatenate the shifted copies to get a vector of size $2 d_{key} \nu$
        x_rolled = torch.cat(x_rolled, dim=-1)
Concatenate $\nu$ copies of $x$
        x_repeat = torch.cat([x] * self.nu, dim=-1)
Multiply them element-wise to get $\phi(k)$
        return x_repeat * x_rolled
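As a quick sanity check of DPFP (the tensor shape below is an illustrative assumption): with $\nu = 3$ a key of size $d_{key} = 16$ is projected to $2 \cdot 16 \cdot 3 = 96$ features, and the normalized output sums to (approximately) one along the last dimension.

phi = DPFP(nu=3)
k = torch.randn(10, 4, 8, 16)  # e.g. (seq_len, batch_size, heads, d_key)
phi_k = phi(k)
assert phi_k.shape == (10, 4, 8, 2 * 16 * 3)
assert torch.allclose(phi_k.sum(dim=-1), torch.ones(10, 4, 8), atol=1e-3)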
The paper introduces a new update rule for calculating $W^{(i)}$. The model first retrieves the current value $\bar{v}^{(i)}$ paired with the key $k^{(i)}$. Then it stores a combination $v^{(i)}_{new}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$:

\begin{align}
k^{(i)}, v^{(i)}, q^{(i)} &= W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \\
\bar{v}^{(i)} &= W^{(i-1)} \phi'\big(k^{(i)}\big) \\
\beta^{(i)} &= \sigma \Big( W_\beta x^{(i)} \Big) \\
v^{(i)}_{new} &= \beta^{(i)} v^{(i)} + \Big( 1 - \beta^{(i)} \Big) \bar{v}^{(i)} \\
W^{(i)} &= W^{(i-1)} + \Big( v^{(i)}_{new} - \bar{v}^{(i)} \Big) \otimes \phi'\big(k^{(i)}\big)
 = W^{(i-1)} + \beta^{(i)} \Big( v^{(i)} - \bar{v}^{(i)} \Big) \otimes \phi'\big(k^{(i)}\big) \\
y^{(i)} &= W^{(i)} \phi'\big(q^{(i)}\big)
\end{align}

where $W_\beta$ is a trainable parameter and $\sigma$ is the sigmoid function.

Note that we don't need the normalization term $z$ because $\phi'$ is normalized.
class FastWeightsAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query, key and value vectors for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
Interpolation weight function $\beta^{(i)} = \sigma\big(W_\beta x^{(i)}\big)$ for each head
        self.interpolation_weight = nn.Sequential(
            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
            nn.Sigmoid()
        )
The projection function $\phi'$ (DPFP)
        self.phi = phi
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
    def forward(self, x: torch.Tensor):
Get the number of steps
        seq_len = x.shape[0]
$\phi'\big(q^{(i)}\big)$ for all steps and heads
        query = self.phi(self.query(x))
$\phi'\big(k^{(i)}\big)$ for all steps and heads
        key = self.phi(self.key(x))
$v^{(i)}$ for all steps and heads
        value = self.value(x)
$\beta^{(i)}$ for all steps and heads
        beta = self.interpolation_weight(x)
Initialize the fast weights $W^{(0)}$ to zeros
        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
List to store outputs
        outputs = []
Iterate through steps
        for i in range(seq_len):
$\bar{v}^{(i)} = W^{(i-1)} \phi'\big(k^{(i)}\big)$
            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
$W^{(i)} = W^{(i-1)} + \beta^{(i)} \big( v^{(i)} - \bar{v}^{(i)} \big) \otimes \phi'\big(k^{(i)}\big)$
            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
$y^{(i)} = W^{(i)} \phi'\big(q^{(i)}\big)$
            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
Merge multiple heads and append to outputs
            outputs.append(y.reshape(y.shape[0], -1))
Stack outputs at each step into a single tensor
        x = torch.stack(outputs)
Output layer
        return self.output(x)
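A short usage sketch for this module (the sizes are assumptions): since dimension 0 is indexed by step and the heads are merged at the end, the input is expected to have shape [seq_len, batch_size, d_model] and the output has the same shape.

# Illustrative sizes; d_k per head is 64 / 4 = 16
attn = FastWeightsAttention(heads=4, d_model=64, dropout_prob=0.1, phi=DPFP(nu=1))
x = torch.randn(20, 8, 64)  # (seq_len, batch_size, d_model)
out = attn(x)
assert out.shape == (20, 8, 64)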
This is a general transformer layer that combines self-attention and a feed-forward network.
class FastWeightsAttentionTransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
Transformer size
        self.size = d_model
Fast weights attention module
        self.attn = attn
Feed-forward network
        self.feed_forward = feed_forward
Dropout layer
        self.dropout = nn.Dropout(dropout_prob)
Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def forward(self, x: torch.Tensor):
Calculate fast weights self attention
        attn = self.attn(x)
Add the self attention results
        x = x + self.dropout(attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)
        return x
This is a general transformer module with multiple transformer layers.
class FastWeightsAttentionTransformer(Module):
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x: torch.Tensor):
        for layer in self.layers:
Get layer output
            x = layer(x)
Normalize the output
        return self.norm(x)
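Finally, a sketch of how the modules above could be wired together; the hyper-parameters and the FeedForward(d_model, d_ff, dropout) call are illustrative assumptions (see the training code linked above for the actual experiment setup).

# Illustrative hyper-parameters (assumptions for this sketch)
d_model, heads, d_ff, n_layers = 128, 4, 256, 2
layer = FastWeightsAttentionTransformerLayer(
    d_model=d_model,
    attn=FastWeightsAttention(heads, d_model, 0.1, DPFP(nu=1)),
    feed_forward=FeedForward(d_model, d_ff, 0.1),
    dropout_prob=0.1)
model = FastWeightsAttentionTransformer(layer, n_layers)

# Embedded tokens of shape (seq_len, batch_size, d_model)
x = torch.randn(25, 4, d_model)
out = model(x)  # (25, 4, d_model)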