Fast weights transformer

This is a PyTorch implementation of the paper Linear Transformers Are Secretly Fast Weight Memory Systems. The paper finds similarities between linear self-attention and fast weight systems and, based on them, modifies the self-attention update rule. It also introduces a simpler yet effective kernel function.

The authors have provided an official implementation of the paper, including the other variants they compare against.

Fast weights

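Consider a sequence of inputs $\big\{x^{(i)}\big\}^L_{i=1}$ of length $L$, where each step is a vector of size $d_{in}$; i.e. $x \in \mathbb{R}^{d_{in}}$. The fast weight model generates a weight matrix at each step to produce the outputs $\big\{y^{(i)}\big\}^L_{i=1}$, $y \in \mathbb{R}^{d_{out}}$:

$$
\begin{align}
a^{(i)}, b^{(i)} &= W_a x^{(i)}, W_b x^{(i)} \\
W^{(i)} &= \sigma \Big( W^{(i-1)} + a^{(i)} \otimes b^{(i)} \Big) \\
y^{(i)} &= W^{(i)} x^{(i)}
\end{align}
$$

$\otimes$ is the outer product ($a \otimes b = a b^\top$), where each element of one vector is multiplied by each element of the other to give a matrix. $\sigma$ is an activation function. $W_a$ and $W_b$ are trainable weights (parameters). $W^{(i)}$ are the fast weights that are generated at each step.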
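A minimal PyTorch sketch of this recurrence for a single, unbatched sequence. The function name `fast_weight_sequence`, the choice of $\tanh$ for $\sigma$, and starting from $W^{(0)} = 0$ are assumptions for illustration; the text leaves $\sigma$ and the initial state open.

```python
import torch


def fast_weight_sequence(x: torch.Tensor, w_a: torch.Tensor, w_b: torch.Tensor,
                         activation=torch.tanh) -> torch.Tensor:
    """Run the fast weight recurrence over one sequence (hypothetical helper).

    x:   [seq_len, d_in] inputs x^{(i)}
    w_a: [d_out, d_in] slow weights W_a
    w_b: [d_in, d_in] slow weights W_b
    activation: sigma (tanh is an assumption)
    Returns y: [seq_len, d_out]
    """
    d_out, d_in = w_a.shape[0], x.shape[1]
    # W^{(0)} = 0 (assumption)
    fast_w = x.new_zeros((d_out, d_in))
    ys = []
    for x_i in x:
        # a^{(i)} = W_a x^{(i)}, b^{(i)} = W_b x^{(i)}
        a, b = w_a @ x_i, w_b @ x_i
        # W^{(i)} = sigma(W^{(i-1)} + a (x) b); the outer product is [d_out, d_in]
        fast_w = activation(fast_w + torch.outer(a, b))
        # y^{(i)} = W^{(i)} x^{(i)}
        ys.append(fast_w @ x_i)
    return torch.stack(ys)


torch.manual_seed(0)
x = torch.randn(5, 4)
y = fast_weight_sequence(x, torch.randn(3, 4), torch.randn(4, 4))
print(y.shape)  # torch.Size([5, 3])
```

The fast weights are state carried across steps, while $W_a$ and $W_b$ stay fixed for the whole sequence.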

Linear self-attention

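The original transformer self-attention is (omitting the $\frac{1}{\sqrt{d_k}}$ scaling for clarity)

$$
\begin{align}
y^{(i)} &= \Big[v^{(1)}, v^{(2)}, \dots, v^{(i)}\Big] \, \text{softmax}\bigg(\Big[k^{(1)}, k^{(2)}, \dots, k^{(i)}\Big]^\top q^{(i)}\bigg) \\
&= \sum^i_{j=1} \frac{v^{(j)} \kappa(k^{(j)}, q^{(i)})}{\sum^i_{j'=1} \kappa(k^{(j')}, q^{(i)})}
\end{align}
$$

where $\kappa(k, q) = \exp(k \cdot q)$.
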
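Written as a kernel-weighted average with the softmax kernel $\kappa(k, q) = \exp(k \cdot q)$, one step of causal softmax attention can be checked numerically against `torch.softmax`. The function names below are hypothetical, and the scaling factor is omitted as in the text.

```python
import torch


def softmax_attention(q_i: torch.Tensor, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # [v^{(1)}, ..., v^{(i)}] softmax([k^{(1)}, ..., k^{(i)}]^T q^{(i)})
    # keys: [i, d_k] and values: [i, d_v] hold k^{(j)}, v^{(j)} for j <= i as rows
    return values.T @ torch.softmax(keys @ q_i, dim=0)


def kernel_attention(q_i, keys, values, kappa):
    # sum_j v^{(j)} kappa(k^{(j)}, q^{(i)}) / sum_j' kappa(k^{(j')}, q^{(i)})
    weights = torch.stack([kappa(k_j, q_i) for k_j in keys])  # [i]
    return (weights[:, None] * values).sum(dim=0) / weights.sum()


def exp_kernel(k: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # kappa(k, q) = exp(k . q), the kernel behind softmax
    return torch.exp(k @ q)


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
K = torch.randn(6, 8, dtype=torch.float64)
V = torch.randn(6, 4, dtype=torch.float64)
print(torch.allclose(softmax_attention(q, K, V), kernel_attention(q, K, V, exp_kernel)))
```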
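The idea behind linearizing self-attention is to replace the softmax kernel $\kappa$ with a different kernel $\kappa'$ that factorizes through a feature map $\phi$, so that the numerator and the denominator of the self-attention function can be calculated faster:

$$\kappa'(k, q) = \phi(k)^\top \phi(q)$$

This gives

$$
y^{(i)} = \frac{\Big( \sum^i_{j=1} v^{(j)} \otimes \phi(k^{(j)}) \Big) \phi(q^{(i)})}{\Big( \sum^i_{j'=1} \phi(k^{(j')}) \Big) \cdot \phi(q^{(i)})}
$$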

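With $W^{(i)} = \sum^i_{j=1} v^{(j)} \otimes \phi(k^{(j)})$ and $z^{(i)} = \sum^i_{j=1} \phi(k^{(j)})$, we can calculate them efficiently:

$$
\begin{align}
W^{(i)} &= W^{(i-1)} + v^{(i)} \otimes \phi(k^{(i)}) \\
z^{(i)} &= z^{(i-1)} + \phi(k^{(i)}) \\
y^{(i)} &= \frac{1}{z^{(i)} \cdot \phi(q^{(i)})} W^{(i)} \phi(q^{(i)})
\end{align}
$$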

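This is quite similar to fast weights: the accumulated sum of outer products acts as a weight matrix that is updated at every step and then applied to the projected query.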
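To check that the step-by-step updates of $W^{(i)}$ and $z^{(i)}$ reproduce causal linear attention, the sketch below compares them with a direct computation over a lower-triangular kernel matrix. It assumes $\phi(x) = \text{elu}(x) + 1$, a positive feature map used in earlier linear-attention work rather than the one this paper proposes; the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def phi(x: torch.Tensor) -> torch.Tensor:
    # elu(x) + 1: a positive feature map (an assumption; the paper proposes DPFP instead)
    return F.elu(x) + 1


def linear_attention_recurrent(q, k, v):
    """q, k: [L, d_k]; v: [L, d_v]. Uses the W^{(i)} and z^{(i)} recurrences."""
    W = q.new_zeros((v.shape[1], k.shape[1]))  # W^{(0)} = 0, [d_v, d_k]
    z = q.new_zeros(k.shape[1])                # z^{(0)} = 0
    ys = []
    for q_i, k_i, v_i in zip(phi(q), phi(k), v):
        W = W + torch.outer(v_i, k_i)          # W^{(i)} = W^{(i-1)} + v^{(i)} (x) phi(k^{(i)})
        z = z + k_i                            # z^{(i)} = z^{(i-1)} + phi(k^{(i)})
        ys.append(W @ q_i / (z @ q_i))         # y^{(i)} = W^{(i)} phi(q) / (z^{(i)} . phi(q))
    return torch.stack(ys)


def linear_attention_parallel(q, k, v):
    """Same result from an explicit [L, L] matrix of kappa'(k^{(j)}, q^{(i)})."""
    scores = torch.tril(phi(q) @ phi(k).T)     # keep only j <= i
    return (scores @ v) / scores.sum(dim=-1, keepdim=True)


torch.manual_seed(0)
q = torch.randn(7, 4, dtype=torch.float64)
k = torch.randn(7, 4, dtype=torch.float64)
v = torch.randn(7, 3, dtype=torch.float64)
print(torch.allclose(linear_attention_recurrent(q, k, v), linear_attention_parallel(q, k, v)))
```

The recurrent form keeps only a $d_v \times d_k$ matrix and a $d_k$ vector as state, while the parallel form materializes an $L \times L$ matrix.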

The paper introduces a new linear attention projection function $\phi$, a new update rule for $W^{(i)} = f(W^{(i-1)})$, and a change to the normalization $\frac{1}{z^{(i)} \cdot \phi(q^{(i)})}$.

Here are the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list

Deterministic Parameter-Free Projection (DPFP)

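This is the new projection function $\phi$ introduced in the paper. DPFP projects $k$ of dimensionality $d_{key}$ to dimensionality $d_{dot} = 2 d_{key} \nu$, where $\nu \in \{1, 2, \dots, 2 d_{key} - 1\}$ is a hyper-parameter.

$$\phi_{2 d_{key} (i - 1) + j}(k) = \text{ReLU}\Big(\big[k, -k\big]\Big)_{j} \, \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$$

where $\big[k, -k\big]$ is the concatenation of $k$ and $-k$ to give a vector of size $2 d_{key}$, $i \in \{1, 2, \dots, \nu\}$, and $j \in \{1, 2, \dots, 2 d_{key}\}$. $x_i$ is the $i$-th element of vector $x$, and the index wraps around if $i$ is larger than the number of elements in $x$.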

In other words, each element of $\phi(k)$ is the product of two elements of $\text{ReLU}([k, -k])$ that are $i$ positions apart (cyclically).

This produces projections that are sparse (only a few elements of $\phi$ are non-zero) and orthogonal ($\phi(k^{(i)}) \cdot \phi(k^{(j)}) \approx 0$ for most $i, j$ unless $k^{(i)}$ and $k^{(j)}$ are very similar).
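The formula indexes the second factor at $i + j$, while a vectorized implementation built on `torch.roll(x, shifts=i)` (as in the `DPFP` class in this section) pairs element $j$ with element $j - i$. Both produce the same products, only at different output positions. A small sketch to check this; `dpfp_explicit` and `dpfp_rolled` are hypothetical helpers, and indices are 0-based.

```python
import torch
import torch.nn.functional as F


def dpfp_explicit(k: torch.Tensor, nu: int) -> torch.Tensor:
    """phi_{2 d_key (i-1) + j}(k) = r_j * r_{i+j}, with r = ReLU([k, -k]); k: [d_key]."""
    r = F.relu(torch.cat([k, -k]))       # [2 d_key]
    n = r.shape[0]
    out = []
    for i in range(1, nu + 1):
        for j in range(n):
            # Indices past the end wrap around
            out.append(r[j] * r[(i + j) % n])
    return torch.stack(out)              # [2 d_key nu]


def dpfp_rolled(k: torch.Tensor, nu: int) -> torch.Tensor:
    """Vectorized version in the style of DPFP.dpfp: rolling by +i pairs r_j with r_{j-i}."""
    r = F.relu(torch.cat([k, -k], dim=-1))
    rolled = torch.cat([r.roll(shifts=i, dims=-1) for i in range(1, nu + 1)], dim=-1)
    return torch.cat([r] * nu, dim=-1) * rolled


torch.manual_seed(0)
k = torch.randn(4)
a, b = dpfp_explicit(k, nu=2), dpfp_rolled(k, nu=2)
print(a.shape)  # torch.Size([16])
# Same multiset of products, so the sorted values match exactly
print(torch.equal(a.sort().values, b.sort().values))
```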


The paper introduces a simple normalization for $\phi$,

$$\phi'(k) = \frac{\phi(k)}{\sum^{d_{dot}}_{j=1} \phi(k)_j}$$

See the paper for the derivation.
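One extreme case of the orthogonality claim can be verified exactly. For a $k$ with no zero entries, $\text{ReLU}([k, -k])$ and $\text{ReLU}([-k, k])$ are non-zero on complementary positions, so every product that is non-zero in $\phi(k)$ is zero in $\phi(-k)$, and $\phi'(k) \cdot \phi'(-k) = 0$. The sketch also checks that $\phi'(k)$ sums to 1 up to the small `eps` in the denominator; `dpfp_normalized` is a hypothetical helper written in the style of the `DPFP` class in this section.

```python
import torch
import torch.nn.functional as F


def dpfp_normalized(k: torch.Tensor, nu: int = 1, eps: float = 1e-6) -> torch.Tensor:
    """phi'(k) = phi(k) / sum_j phi(k)_j, with eps guarding against an all-zero phi(k)."""
    r = F.relu(torch.cat([k, -k], dim=-1))
    rolled = torch.cat([r.roll(shifts=i, dims=-1) for i in range(1, nu + 1)], dim=-1)
    phi = torch.cat([r] * nu, dim=-1) * rolled
    return phi / (phi.sum(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
k = torch.randn(8)
p, p_neg = dpfp_normalized(k, nu=3), dpfp_normalized(-k, nu=3)
print(p.sum())    # close to 1
print(p @ p_neg)  # exactly zero: the two projections share no non-zero positions
```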

class DPFP(Module):
  • nu is the hyper-parameter $\nu$.
  • eps is a small value used to avoid division by zero when normalizing.
    def __init__(self, nu: int = 1, eps: float = 1e-6):
        super().__init__()
        self.nu = nu
        self.relu = nn.ReLU()
        self.eps = eps
    def forward(self, k: torch.Tensor):

Get $\phi(k)$

        k = self.dpfp(k)

Normalize by $\sum^{d_{dot}}_{j=1} \phi(k)_j$

        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)

    def dpfp(self, k: torch.Tensor):

$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$

        x = self.relu(torch.cat([k, -k], dim=-1))

Roll $x$ by $i \in \{1, 2, \dots, \nu\}$ positions, so that each element is paired with the element $i$ positions away (cyclically)

        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]

Concatenate the $\nu$ rolled copies into a single vector of size $2 d_{key} \nu$

        x_rolled = torch.cat(x_rolled, dim=-1)

Concatenate $\nu$ copies of $x$

        x_repeat = torch.cat([x] * self.nu, dim=-1)

Multiply them element-wise to get $\phi(k)$

        return x_repeat * x_rolled

Fast Weights Attention

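The paper introduces a new update rule for calculating $W^{(i)}$. The model first retrieves the current value $\bar{v}^{(i)}$ paired with the key $k^{(i)}$. It then stores a combination $v^{(i)}_{new}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$.

$$
\begin{align}
k^{(i)}, v^{(i)}, q^{(i)} &= W_k x^{(i)}, W_v x^{(i)}, W_q x^{(i)} \\
\bar{v}^{(i)} &= W^{(i-1)} \phi'(k^{(i)}) \\
\beta^{(i)} &= \sigma \Big(W_\beta x^{(i)} \Big) \\
v^{(i)}_{new} &= \beta^{(i)} v^{(i)} + \Big(1 - \beta^{(i)} \Big) \bar{v}^{(i)} \\
W^{(i)} &= W^{(i-1)} + v^{(i)}_{new} \otimes \phi'(k^{(i)}) - \bar{v}^{(i)} \otimes \phi'(k^{(i)}) \\
&= W^{(i-1)} + \beta^{(i)} \Big( v^{(i)} - \bar{v}^{(i)} \Big) \otimes \phi'(k^{(i)}) \\
y^{(i)} &= W^{(i)} \phi'(q^{(i)})
\end{align}
$$

where $W_\beta$ is a trainable parameter and $\sigma$ is the sigmoid function.

Note that we don't need the normalization term $z$ because $\phi'$ is normalized.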
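The update writes $v^{(i)}_{new}$ and erases $\bar{v}^{(i)}$ under the key $\phi'(k^{(i)})$, which collapses to a single outer product scaled by $\beta^{(i)}$ because $v^{(i)}_{new} - \bar{v}^{(i)} = \beta^{(i)} (v^{(i)} - \bar{v}^{(i)})$. A single-head sketch that checks this numerically; `delta_update` is a hypothetical name, and a softmax of random numbers stands in for the normalized key $\phi'(k^{(i)})$ (non-negative and summing to 1, like $\phi'$).

```python
import torch


def delta_update(W, phi_k, v, beta):
    """One step of the update rule for a single head (hypothetical helper).

    W: [d_v, d_dot] fast weights W^{(i-1)}; phi_k: [d_dot] projected key phi'(k^{(i)});
    v: [d_v] input value v^{(i)}; beta: interpolation weight in (0, 1).
    """
    v_bar = W @ phi_k  # value currently stored under this key
    return W + beta * torch.outer(v - v_bar, phi_k)


torch.manual_seed(0)
W = torch.randn(3, 6, dtype=torch.float64)
# Stand-in for a normalized phi'(k) (assumption)
phi_k = torch.softmax(torch.randn(6, dtype=torch.float64), dim=0)
v = torch.randn(3, dtype=torch.float64)
beta = 0.3

# Long form: W + v_new (x) phi'(k) - v_bar (x) phi'(k)
v_bar = W @ phi_k
v_new = beta * v + (1 - beta) * v_bar
W_long = W + torch.outer(v_new, phi_k) - torch.outer(v_bar, phi_k)

print(torch.allclose(delta_update(W, phi_k, v, beta), W_long))
```

After the update, reading with the same key gives $\bar{v}^{(i)} + \beta^{(i)} (v^{(i)} - \bar{v}^{(i)}) \, \|\phi'(k^{(i)})\|^2$, so the stored value moves toward $v^{(i)}$ rather than simply accumulating.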

class FastWeightsAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()

Number of features per head $d_k$

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value vectors for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

Interpolation weight function $\sigma \big(W_\beta x^{(i)} \big)$ for each head

        self.interpolation_weight = nn.Sequential(
            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
            nn.Sigmoid()
        )

Projection function $\phi'$

        self.phi = phi

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)
    def forward(self, x: torch.Tensor):

Get the number of steps $L$

        seq_len = x.shape[0]

$\phi'(q^{(i)})$ for all steps and heads

        query = self.phi(self.query(x))

$\phi'(k^{(i)})$ for all steps and heads

        key = self.phi(self.key(x))

$v^{(i)}$ for all steps and heads

        value = self.value(x)

$\beta^{(i)}$ for all steps and heads

        beta = self.interpolation_weight(x)

Initial fast weights $W^{(0)} = 0$, of shape [batch_size, heads, d_k, d_dot]

        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))

List to store outputs $y^{(i)}$

        outputs = []

Iterate through steps

        for i in range(seq_len):

$\bar{v}^{(i)} = W^{(i-1)} \phi'(k^{(i)})$

            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])

$W^{(i)} = W^{(i-1)} + \beta^{(i)} \big( v^{(i)} - \bar{v}^{(i)} \big) \otimes \phi'(k^{(i)})$

            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])

$y^{(i)} = W^{(i)} \phi'(q^{(i)})$

            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])

Merge multiple heads and append to outputs

            outputs.append(y.reshape(y.shape[0], -1))

Stack outputs at each step into a single tensor

        x = torch.stack(outputs)

Output layer

        return self.output(x)
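The class above depends on `labml_helpers` and `labml_nn`. For experimenting without them, here is a dependency-free sketch of the same computation under some assumptions: plain `nn.Linear` layers followed by a reshape stand in for `PrepareForMultiHeadAttention`, normalized DPFP is inlined, the input has shape `[seq_len, batch_size, d_model]`, and dropout is left out. `SimpleFastWeightsAttention` is a hypothetical name, not part of either library.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SimpleFastWeightsAttention(nn.Module):
    """Hypothetical sketch: x [seq_len, batch_size, d_model] -> same shape."""

    def __init__(self, heads: int, d_model: int, nu: int = 1, eps: float = 1e-6):
        super().__init__()
        self.heads, self.d_k, self.nu, self.eps = heads, d_model // heads, nu, eps
        # Linear layers stand in for PrepareForMultiHeadAttention (assumption)
        self.query = nn.Linear(d_model, heads * self.d_k, bias=False)
        self.key = nn.Linear(d_model, heads * self.d_k, bias=False)
        self.value = nn.Linear(d_model, heads * self.d_k, bias=False)
        # One interpolation weight beta per head
        self.beta = nn.Linear(d_model, heads, bias=False)
        self.output = nn.Linear(heads * self.d_k, d_model)

    def split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # [seq_len, batch_size, heads * d] -> [seq_len, batch_size, heads, d]
        return t.view(t.shape[0], t.shape[1], self.heads, -1)

    def phi(self, t: torch.Tensor) -> torch.Tensor:
        # Normalized DPFP over the last dimension: d_k -> 2 * d_k * nu
        r = F.relu(torch.cat([t, -t], dim=-1))
        rolled = torch.cat([r.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)], dim=-1)
        p = torch.cat([r] * self.nu, dim=-1) * rolled
        return p / (p.sum(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.phi(self.split_heads(self.query(x)))         # [L, B, H, d_dot]
        k = self.phi(self.split_heads(self.key(x)))           # [L, B, H, d_dot]
        v = self.split_heads(self.value(x))                   # [L, B, H, d_k]
        beta = torch.sigmoid(self.split_heads(self.beta(x)))  # [L, B, H, 1]
        # Fast weights W^{(0)} = 0, [B, H, d_k, d_dot]
        w = k.new_zeros((k.shape[1], k.shape[2], v.shape[3], k.shape[3]))
        outputs = []
        for i in range(x.shape[0]):
            # Retrieve the value currently stored under key phi'(k^{(i)})
            v_bar = torch.einsum('bhvk,bhk->bhv', w, k[i])
            # Update with interpolation weight beta
            w = w + torch.einsum('bhv,bhk->bhvk', beta[i] * (v[i] - v_bar), k[i])
            # Read out with the projected query, then merge the heads
            y = torch.einsum('bhvk,bhk->bhv', w, q[i])
            outputs.append(y.reshape(y.shape[0], -1))
        return self.output(torch.stack(outputs))


torch.manual_seed(0)
attn = SimpleFastWeightsAttention(heads=2, d_model=8, nu=1)
x = torch.randn(5, 3, 8)
print(attn(x).shape)  # torch.Size([5, 3, 8])
```

Because $y^{(i)}$ only depends on fast weights built from steps $1, \dots, i$, the loop is causal without an attention mask: changing later inputs leaves earlier outputs unchanged.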

This is a general transformer layer that combines fast weights self-attention and a feed-forward network.

class FastWeightsAttentionTransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

Transformer size $d_{model}$

        self.size = d_model

Fast weights attention module

        self.attn = attn

Feed-forward network

        self.feed_forward = feed_forward

Dropout layer

        self.dropout = nn.Dropout(dropout_prob)

Normalization layers

        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def forward(self, x: torch.Tensor):

Normalize for self-attention

        z = self.norm_self_attn(x)

Calculate fast weights self-attention

        attn = self.attn(z)

Add the self-attention results

        x = x + self.dropout(attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x

This is a general transformer module with multiple transformer layers.

class FastWeightsAttentionTransformer(Module):
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x: torch.Tensor):
        for layer in self.layers:

Get layer output

            x = layer(x)

Normalize the output

        return self.norm(x)