The paper Linear Transformers Are Secretly Fast Weight Memory Systems finds similarities between linear self-attention and fast weight systems, and based on them it modifies the self-attention update rule. It also introduces a simpler, yet effective kernel function. This is a PyTorch implementation of it.
The authors have provided an official implementation of the paper, including the other variants they compare against.
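Consider a sequence of inputs $\big\{x^{(i)}\big\}^L_{i=1}$ of length $L$, where each step is a vector of size $d_{in}$; i.e. $x \in \mathbb{R}^{d_{in}}$. The fast weight model generates a weight matrix at each step to produce outputs $\big\{y^{(i)}\big\}^L_{i=1}$, $y \in \mathbb{R}^{d_{out}}$,

$$
\begin{align}
a^{(i)}, b^{(i)} &= W_a x^{(i)}, W_b x^{(i)} \\
W^{(i)} &= \sigma \Big( W^{(i-1)} + a^{(i)} \otimes b^{(i)} \Big) \\
y^{(i)} &= W^{(i)} x^{(i)}
\end{align}
$$

$\otimes$ is the outer product ($a \otimes b = a b^\top$), where elements of the two vectors are multiplied with each other to give a matrix. $\sigma$ is an activation function. $W_a$ and $W_b$ are trainable weights (parameters). $W^{(i)}$ are the fast weights that are generated at each step.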
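A minimal pure-Python sketch of the fast weight recurrence (a, b = W_a x, W_b x; W = σ(W + a ⊗ b); y = W x) on plain lists. It assumes $W^{(0)} = 0$ and picks tanh as the element-wise activation $\sigma$; both are illustrative choices, not taken from the paper. For $a \otimes b$ to match the shape of $W^{(i)} \in \mathbb{R}^{d_{out} \times d_{in}}$, `w_a` must be $d_{out} \times d_{in}$ and `w_b` must be $d_{in} \times d_{in}$. The helper names are hypothetical.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def outer(a: Vector, b: Vector) -> Matrix:
    """Outer product a ⊗ b = a b^T."""
    return [[a_i * b_j for b_j in b] for a_i in a]


def fast_weight_model(xs: List[Vector], w_a: Matrix, w_b: Matrix,
                      sigma: Callable[[float], float] = math.tanh) -> List[Vector]:
    """Run the fast weight recurrence over a sequence of input vectors.

    Assumptions (illustrative): W^{(0)} = 0 and sigma = tanh, applied element-wise.
    w_a is d_out x d_in and w_b is d_in x d_in.
    """
    d_out, d_in = len(w_a), len(w_b)
    w = [[0.0] * d_in for _ in range(d_out)]  # W^{(0)} = 0 (assumption)
    ys = []
    for x in xs:
        # a^{(i)}, b^{(i)} = W_a x^{(i)}, W_b x^{(i)}
        a, b = matvec(w_a, x), matvec(w_b, x)
        # W^{(i)} = sigma(W^{(i-1)} + a ⊗ b), element-wise
        w = [[sigma(w_ij + ab_ij) for w_ij, ab_ij in zip(w_row, ab_row)]
             for w_row, ab_row in zip(w, outer(a, b))]
        # y^{(i)} = W^{(i)} x^{(i)}
        ys.append(matvec(w, x))
    return ys
```

With identity `w_a` and `w_b` and inputs `[1.0, 0.0]` then `[0.0, 1.0]`, each step writes into a different entry of the fast weight matrix, so the outputs are `[tanh(1), 0]` and `[0, tanh(1)]`.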
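The original transformer self-attention is (omitting $\frac{1}{\sqrt{d_k}}$ for clarity):

$$
\begin{align}
y^{(i)} &= \Big[v^{(1)}, v^{(2)}, \dots, v^{(i)}\Big] \, \text{softmax}\bigg(\big[k^{(1)}, k^{(2)}, \dots, k^{(i)}\big]^\top q^{(i)}\bigg) \\
&= \sum^i_{j=1} \frac{v^{(j)} \kappa(k^{(j)}, q^{(i)})}{\sum^i_{j'=1} \kappa(k^{(j')}, q^{(i)})}
\end{align}
$$

where $\kappa(k, q) = \exp(k \cdot q)$.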
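The idea behind linearizing self-attention is to replace the softmax kernel $\kappa$ with a different kernel $\kappa'$ so that both the numerator and the denominator of the self-attention function can be calculated faster:

$$\kappa'(k, q) = \phi(k)^\top \phi(q)$$

This gives

$$
y^{(i)} = \frac{\Big(\sum^i_{j=1} v^{(j)} \otimes \phi(k^{(j)})\Big) \phi(q^{(i)})}{\Big(\sum^i_{j'=1} \phi(k^{(j')})\Big) \cdot \phi(q^{(i)})}
$$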
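With $W^{(i)} = \sum^i_{j=1} v^{(j)} \otimes \phi(k^{(j)})$ and $z^{(i)} = \sum^i_{j=1} \phi(k^{(j)})$, we can calculate them efficiently:

$$
\begin{align}
W^{(i)} &= W^{(i-1)} + v^{(i)} \otimes \phi(k^{(i)}) \\
z^{(i)} &= z^{(i-1)} + \phi(k^{(i)}) \\
y^{(i)} &= \frac{1}{z^{(i)} \cdot \phi(q^{(i)})} W^{(i)} \phi(q^{(i)})
\end{align}
$$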
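As a sanity check on the incremental form, here is a minimal pure-Python sketch comparing the direct kernel-weighted sum (recomputed over all previous keys at every step, with $\kappa'(k, q) = \phi(k) \cdot \phi(q)$) against the running updates of $W$ and $z$. The feature map `phi(x) = elu(x) + 1` is only an illustrative assumption chosen because its outputs are positive; it is not the projection this paper introduces. All function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def phi(x: Vector) -> Vector:
    """Illustrative positive feature map elu(x) + 1 (an assumption for this sketch)."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def attention_direct(qs: List[Vector], ks: List[Vector], vs: List[Vector]) -> List[Vector]:
    """y^{(i)} = sum_j v^{(j)} kappa'(k^{(j)}, q^{(i)}) / sum_j' kappa'(k^{(j')}, q^{(i)})."""
    ys = []
    for i, q in enumerate(qs):
        fq = phi(q)
        scores = [dot(phi(k), fq) for k in ks[: i + 1]]
        total = sum(scores)
        ys.append([sum(s * v[c] for s, v in zip(scores, vs[: i + 1])) / total
                   for c in range(len(vs[0]))])
    return ys


def attention_recurrent(qs: List[Vector], ks: List[Vector], vs: List[Vector]) -> List[Vector]:
    """Same outputs with W <- W + v ⊗ phi(k) and z <- z + phi(k) at every step."""
    d_v, d_phi = len(vs[0]), len(phi(ks[0]))
    w = [[0.0] * d_phi for _ in range(d_v)]
    z = [0.0] * d_phi
    ys = []
    for q, k, v in zip(qs, ks, vs):
        fk, fq = phi(k), phi(q)
        # W^{(i)} = W^{(i-1)} + v^{(i)} ⊗ phi(k^{(i)})
        w = [[w_rc + v_r * fk_c for w_rc, fk_c in zip(row, fk)]
             for row, v_r in zip(w, v)]
        # z^{(i)} = z^{(i-1)} + phi(k^{(i)})
        z = [z_c + fk_c for z_c, fk_c in zip(z, fk)]
        # y^{(i)} = W^{(i)} phi(q^{(i)}) / (z^{(i)} . phi(q^{(i)}))
        norm = dot(z, fq)
        ys.append([dot(row, fq) / norm for row in w])
    return ys
```

The recurrent version touches each key once, so its cost per step does not grow with the sequence length, while the direct version rescans every previous key.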
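This is quite similar to fast weights: $W^{(i)}$ is built up by adding an outer product at every step, just without the activation $\sigma$.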
The paper introduces a new linear attention projection function $\phi$, a new update rule for $W^{(i)} = f(W^{(i-1)})$, and changes the normalization $\frac{1}{z^{(i)} \cdot \phi(q^{(i)})}$.
Here are the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
```
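This is the new projection function $\phi$ introduced in the paper, DPFP (deterministic parameter-free projection). DPFP projects $k$ of dimensionality $d_{key}$ to dimensionality $d_{dot} = 2 d_{key} \nu$, where $\nu \in \{1, 2, \dots, 2 d_{key} - 1\}$ is a hyper-parameter.

$$\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_{j} \, \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$$

where $\big[k, -k\big]$ is the concatenation of $k$ and $-k$ to give a vector of size $2 d_{key}$, $i \in \{1, 2, \dots, \nu\}$, and $j \in \{1, 2, \dots, 2 d_{key}\}$. $x_i$ is the $i$-th element of vector $x$ and is rolled around if $i$ is larger than the number of elements in $x$.

Basically, it creates a new vector by multiplying the elements of $\text{ReLU}\big([k, -k]\big)$ with the elements of the same vector shifted by $i$.

This produces projections that are sparse (only a few elements of $\phi$ are non-zero) and nearly orthogonal ($\phi(k^{(i)}) \cdot \phi(k^{(j)}) \approx 0$ for most $i, j$ unless $k^{(i)}$ and $k^{(j)}$ are very similar).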
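A minimal pure-Python sketch of DPFP on plain lists, assuming 0-based indices (so the output index is $(i - 1) \cdot 2 d_{key} + j$ with $j$ starting at 0). The helper names `relu_pm`, `dpfp` and `dpfp_rolled` are hypothetical. `dpfp` follows the formula directly, pairing $r_j$ with $r_{i+j}$; `dpfp_rolled` pairs $r_j$ with $r_{j-i}$, which is what `torch.roll(shifts=i)` produces in the implementation. Both give the same products within each block of size $2 d_{key}$, only in a different order.

```python
from typing import List


def relu_pm(k: List[float]) -> List[float]:
    """r = ReLU([k, -k]), a vector of size 2 * d_key."""
    return [max(v, 0.0) for v in k + [-v for v in k]]


def dpfp(k: List[float], nu: int = 1) -> List[float]:
    """phi_{2 d_key (i - 1) + j}(k) = r_j * r_{i + j}, indices wrapping around.

    Hypothetical list-based helper; j is 0-based here.
    """
    r = relu_pm(k)
    n = len(r)  # 2 * d_key
    return [r[j] * r[(i + j) % n] for i in range(1, nu + 1) for j in range(n)]


def dpfp_rolled(k: List[float], nu: int = 1) -> List[float]:
    """Variant matching torch.roll(shifts=i): pairs r_j with r_{j - i}."""
    r = relu_pm(k)
    n = len(r)
    return [r[j] * r[(j - i) % n] for i in range(1, nu + 1) for j in range(n)]
```

For `k = [1.0, -2.0, 0.5]`, `relu_pm(k)` is `[1.0, 0.0, 0.5, 0.0, 2.0, 0.0]`. With $\nu = 1$ every product of neighbouring entries is zero, and with $\nu = 2$ only three of the twelve features are non-zero (`[0.0] * 6 + [0.5, 0.0, 1.0, 0.0, 2.0, 0.0]`), which illustrates how sparse the projection is.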
The paper introduces a simple normalization for $\phi$,

$$\phi'(k) = \frac{\phi(k)}{\sum^{d_{dot}}_{j=1} \phi(k)_j}$$

See the paper for the derivation.
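A minimal list-based sketch of this normalization, mirroring the `eps` term that `DPFP.forward` adds to the denominator. The `normalize` helper is hypothetical. Since DPFP outputs are non-negative, the normalized features sum to (almost exactly) one, and `eps` keeps an all-zero projection from causing a division by zero.

```python
from typing import List


def normalize(phi_k: List[float], eps: float = 1e-6) -> List[float]:
    """phi'(k) = phi(k) / (sum_j phi(k)_j + eps).

    Hypothetical helper; eps mirrors the eps argument of DPFP and guards
    against an all-zero phi(k).
    """
    total = sum(phi_k)
    return [v / (total + eps) for v in phi_k]
```

For example, the DPFP features of `k = [1.0, -2.0, 0.5]` with $\nu = 2$, `[0.0] * 6 + [0.5, 0.0, 1.0, 0.0, 2.0, 0.0]`, normalize to values that sum to $3.5 / (3.5 + 10^{-6})$, while an all-zero vector stays all zero.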
`nu` is the hyper-parameter $\nu$.
`eps` is the small value used to make sure there is no division by zero when normalizing.
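```python
def __init__(self, nu: int = 1, eps: float = 1e-6):
    super().__init__()
    self.nu = nu
    self.relu = nn.ReLU()
    self.eps = eps

def forward(self, k: torch.Tensor):
    # Get phi(k)
    k = self.dpfp(k)
    # Normalize by the sum of the elements of phi(k)
    return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)

def dpfp(self, k: torch.Tensor):
    # x = ReLU([k, -k])
    x = self.relu(torch.cat([k, -k], dim=-1))
    # Shift and roll x by i in {1, 2, ..., nu}. torch.roll moves element j to j + i
    # (wrapping around), so the products below pair x_j with x_{j-i}; these are the
    # same pairs as x_j x_{i+j}, in a different order
    x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]
    # Concatenate the nu rolled copies along the last dimension
    x_rolled = torch.cat(x_rolled, dim=-1)
    # Concatenate nu copies of x
    x_repeat = torch.cat([x] * self.nu, dim=-1)
    # Multiply them element-wise to get phi(k) of size 2 * d_key * nu
    return x_repeat * x_rolled
```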
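The paper introduces a new update rule for calculating $W^{(i)}$. The model first retrieves the current value $\bar{v}^{(i)}$ paired with the key $k^{(i)}$. It then stores a combination $v^{(i)}_{new}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$.

$$
\begin{align}
k^{(i)}, v^{(i)}, q^{(i)} &= W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \\
\bar{v}^{(i)} &= W^{(i-1)} \phi'(k^{(i)}) \\
\beta^{(i)} &= \sigma \Big( W_\beta x^{(i)} \Big) \\
v^{(i)}_{new} &= \beta^{(i)} v^{(i)} + \Big(1 - \beta^{(i)}\Big) \bar{v}^{(i)} \\
W^{(i)} &= W^{(i-1)} + v^{(i)}_{new} \otimes \phi'(k^{(i)}) - \bar{v}^{(i)} \otimes \phi'(k^{(i)}) \\
&= W^{(i-1)} + \beta^{(i)} \Big(v^{(i)} - \bar{v}^{(i)}\Big) \otimes \phi'(k^{(i)}) \\
y^{(i)} &= W^{(i)} \phi'(q^{(i)})
\end{align}
$$

where $W_\beta$ is a trainable parameter and $\sigma$ is the sigmoid function.

Note that we don't need the normalization term $z$ because $\phi'$ is normalized.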
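A minimal pure-Python sketch of one step of this update rule for a single head, assuming $\beta$ is a scalar and that the normalized key features $\phi'(k)$ are given directly. The one-hot key used in the example is an illustrative assumption (real DPFP outputs are sparse but not necessarily one-hot), and the helper names are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def retrieve(w: Matrix, key_feat: Vector) -> Vector:
    """W phi'(k): the value currently paired with a key."""
    return [dot(row, key_feat) for row in w]


def update(w: Matrix, key_feat: Vector, v: Vector, beta: float) -> Tuple[Matrix, Vector]:
    """One step for a single head (scalar beta, key_feat = phi'(k) given directly).

    v_bar = W^{(i-1)} phi'(k);  W^{(i)} = W^{(i-1)} + beta (v - v_bar) ⊗ phi'(k)
    """
    v_bar = retrieve(w, key_feat)
    new_w = [[w_rc + beta * (v_r - vb_r) * k_c for w_rc, k_c in zip(row, key_feat)]
             for row, v_r, vb_r in zip(w, v, v_bar)]
    return new_w, v_bar


# Store two values under the same (assumed one-hot) key with beta = 1
w = [[0.0] * 3 for _ in range(2)]
key = [0.0, 1.0, 0.0]
w, _ = update(w, key, [1.0, 2.0], beta=1.0)
w, v_bar = update(w, key, [5.0, 6.0], beta=1.0)
```

With $\beta = 1$ the second write replaces the value stored under the key, so `retrieve(w, key)` returns `[5.0, 6.0]`; the purely additive linear attention update $W^{(i)} = W^{(i-1)} + v^{(i)} \otimes \phi(k^{(i)})$ would instead return `[6.0, 8.0]`. Values of $\beta$ between 0 and 1 interpolate between keeping the old value and writing the new one.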
```python
def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
    super().__init__()

    # Number of features per head
    self.d_k = d_model // heads
    # Number of heads
    self.heads = heads

    # These transform the query, key and value vectors for multi-headed attention
    self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
    self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
    self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

    # Interpolation weight function beta = sigmoid(W_beta x) for each head
    self.interpolation_weight = nn.Sequential(
        PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
        nn.Sigmoid()
    )

    # Projection function phi'
    self.phi = phi

    # Output layer
    self.output = nn.Linear(d_model, d_model)
    # Dropout
    self.dropout = nn.Dropout(dropout_prob)
```
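```python
def forward(self, x: torch.Tensor):
    # Get the number of steps; x has shape [seq_len, batch_size, d_model]
    seq_len = x.shape[0]
    # phi'(q) for all steps and heads, shape [seq_len, batch_size, heads, projected size]
    query = self.phi(self.query(x))
    # phi'(k) for all steps and heads
    key = self.phi(self.key(x))
    # v for all steps and heads, shape [seq_len, batch_size, heads, d_k]
    value = self.value(x)
    # beta for all steps and heads, shape [seq_len, batch_size, heads, 1]
    beta = self.interpolation_weight(x)

    # Fast weights W, initialized to zero,
    # shape [batch_size, heads, value size, projected key size]
    weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))
    # List to store outputs
    outputs = []

    # Iterate through steps
    for i in range(seq_len):
        # Retrieve the value currently paired with the key: v_bar = W phi'(k)
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])
        # W <- W + beta (v - v_bar) ⊗ phi'(k)
        weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])
        # y = W phi'(q)
        y = torch.einsum('bhvk,bhk->bhv', weights, query[i])
        # Merge multiple heads and append to outputs
        outputs.append(y.reshape(y.shape[0], -1))

    # Stack outputs at each step into a single tensor
    x = torch.stack(outputs)

    # Output layer
    return self.output(x)
```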
This is a general transformer layer that combines fast weights self-attention and a feed-forward network.
```python
def __init__(self, *,
             d_model: int,
             attn: FastWeightsAttention,
             feed_forward: FeedForward,
             dropout_prob: float):
    super().__init__()
    # Transformer size d_model
    self.size = d_model
    # Fast weights attention module
    self.attn = attn
    # Feed-forward network
    self.feed_forward = feed_forward
    # Dropout layer
    self.dropout = nn.Dropout(dropout_prob)

    # Normalization layers
    self.norm_self_attn = nn.LayerNorm([d_model])
    self.norm_ff = nn.LayerNorm([d_model])

def forward(self, x: torch.Tensor):
    # Calculate fast weights self attention
    attn = self.attn(x)
    # Add the self attention results
    x = x + self.dropout(attn)

    # Normalize for feed-forward
    z = self.norm_ff(x)
    # Pass through the feed-forward network
    ff = self.feed_forward(z)
    # Add the feed-forward results back
    x = x + self.dropout(ff)

    return x
```
This is a general transformer module with multiple transformer layers.
```python
def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
    super().__init__()
    # Make copies of the transformer layer
    self.layers = clone_module_list(layer, n_layers)
    # Final normalization layer
    self.norm = nn.LayerNorm([layer.size])

def forward(self, x: torch.Tensor):
    for layer in self.layers:
        # Get layer output
        x = layer(x)

    # Normalize the output
    return self.norm(x)
```