Fast weights transformer

This is an annotated PyTorch implementation of the paper Linear Transformers Are Secretly Fast Weight Memory Systems. The paper finds similarities between linear self-attention and fast weight systems and, based on that, modifies the self-attention update rule. It also introduces a simpler, yet effective kernel function.

The authors provide an official implementation of the paper, which also includes the other variants they compare against.

Fast weights

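Consider a sequence of inputs $\big\{x^{(i)}\big\}^L_{i=1}$ of length $L$, where each step is a vector of size $d_{in}$; i.e. $x \in \mathbb{R}^{d_{in}}$. The fast weight model generates a weight matrix at each step to produce the outputs $\big\{y^{(i)}\big\}^L_{i=1}$, $y \in \mathbb{R}^{d_{out}}$:

$$\begin{align}
a^{(i)}, b^{(i)} &= \color{orange}{W_a} x^{(i)}, \color{orange}{W_b} x^{(i)} \\
\color{cyan}{W^{(i)}} &= \sigma \Big( \color{cyan}{W^{(i-1)}} + a^{(i)} \otimes b^{(i)} \Big) \\
y^{(i)} &= \color{cyan}{W^{(i)}} x^{(i)}
\end{align}$$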

$\otimes$ is the outer product ($a \otimes b = a b^\top$), where elements of the two vectors are multiplied with each other to give a matrix. $\sigma$ is an activation function. $\color{orange}{W_a}$ and $\color{orange}{W_b}$ are trainable weights (parameters). $\color{cyan}{W^{(i)}}$ are the fast weights that are generated at each step.
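The fast weight step described here can be sketched in a few lines. This is a minimal illustration using plain Python lists rather than tensors; the helper names (`matvec`, `outer`, `fast_weight_step`) and the choice of `tanh` as the default activation $\sigma$ are assumptions made for the example, not part of the paper or the implementation below.

```python
import math


def matvec(m, x):
    """Multiply matrix m (a list of rows) by vector x."""
    return [sum(m_rc * x_c for m_rc, x_c in zip(row, x)) for row in m]


def outer(a, b):
    """Outer product a ⊗ b = a b^T, with len(a) rows and len(b) columns."""
    return [[a_r * b_c for b_c in b] for a_r in a]


def fast_weight_step(w_prev, x, w_a, w_b, sigma=math.tanh):
    """One step: returns the new fast weights W^(i) and the output y^(i)."""
    a = matvec(w_a, x)  # size d_out
    b = matvec(w_b, x)  # size d_in
    update = outer(a, b)
    # W^(i) = sigma(W^(i-1) + a ⊗ b), applied element-wise
    w = [[sigma(p + u) for p, u in zip(p_row, u_row)]
         for p_row, u_row in zip(w_prev, update)]
    return w, matvec(w, x)


def identity(t):
    """Identity activation, used only to keep the numbers easy to follow."""
    return t


w0 = [[0.0, 0.0], [0.0, 0.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
w1, y1 = fast_weight_step(w0, [1.0, 2.0], eye, eye, sigma=identity)
print(w1)  # [[1.0, 2.0], [2.0, 4.0]]
print(y1)  # [5.0, 10.0]
```

With identity slow weights and an identity activation, the fast weights after one step are just $x \otimes x$, which makes the role of the outer product easy to see.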

Linear self-attention

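The original transformer self-attention is (omitting the $\frac{1}{\sqrt{d_k}}$ scaling for clarity),

$$y^{(i)} = \sum^i_{j=1} \frac{\kappa(k^{(j)}, q^{(i)})}{\sum^i_{j'=1} \kappa(k^{(j')}, q^{(i)})} v^{(j)}$$

where $\kappa(k, q) = \text{exp}(k \cdot q)$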

The idea behind linearizing self-attention is to replace the softmax kernel $\kappa$ with a different kernel $\kappa'$ so that we can calculate the denominator of the self-attention function faster:

$$\kappa'(k, q) = \color{lightgreen}{\phi(k)}^\top \color{lightgreen}{\phi(q)}$$

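This gives

$$y^{(i)} = \frac{\Big(\sum^i_{j=1} v^{(j)} \otimes \color{lightgreen}{\phi(k^{(j)})}\Big) \color{lightgreen}{\phi(q^{(i)})}}{\Big(\sum^i_{j'=1} \color{lightgreen}{\phi(k^{(j')})}\Big) \cdot \color{lightgreen}{\phi(q^{(i)})}}$$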

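With $\color{cyan}{W^{(i)}} = \sum^i_{j=1} v^{(j)} \otimes \color{lightgreen}{\phi(k^{(j)})}$ and $z^{(i)} = \sum^i_{j=1} \color{lightgreen}{\phi(k^{(j)})}$, we can calculate them efficiently:

$$\begin{align}
\color{cyan}{W^{(i)}} &= \color{cyan}{W^{(i-1)}} + v^{(i)} \otimes \color{lightgreen}{\phi(k^{(i)})} \\
z^{(i)} &= z^{(i-1)} + \color{lightgreen}{\phi(k^{(i)})} \\
y^{(i)} &= \frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}} \color{cyan}{W^{(i)}} \color{lightgreen}{\phi(q^{(i)})}
\end{align}$$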

This is quite similar to fast weights.
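To make the connection concrete, the sketch below computes causal linear attention twice: once by summing over all previous steps for every query, and once with the running sums $W^{(i)}$ and $z^{(i)}$. It uses plain Python lists, and the feature map `phi` here is $\text{ELU}(x) + 1$, which is only an assumption chosen for illustration (it is not the projection introduced in this paper). All function names are hypothetical.

```python
import math


def phi(x):
    """Illustrative feature map ELU(x) + 1, which is always positive."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def dot(a, b):
    return sum(a_i * b_i for a_i, b_i in zip(a, b))


def causal_attention_quadratic(qs, ks, vs):
    """For each step i, sum explicitly over all steps j <= i."""
    outputs = []
    for i, q in enumerate(qs):
        fq = phi(q)
        scores = [dot(phi(ks[j]), fq) for j in range(i + 1)]
        total = sum(scores)
        outputs.append([sum(s * vs[j][d] for j, s in enumerate(scores)) / total
                        for d in range(len(vs[0]))])
    return outputs


def causal_attention_recurrent(qs, ks, vs):
    """Same result using the running sums W (matrix) and z (vector)."""
    d_v, d_phi = len(vs[0]), len(phi(ks[0]))
    w = [[0.0] * d_phi for _ in range(d_v)]
    z = [0.0] * d_phi
    outputs = []
    for q, k, v in zip(qs, ks, vs):
        fk, fq = phi(k), phi(q)
        # W^(i) = W^(i-1) + v ⊗ phi(k)
        w = [[w_rc + v_r * fk_c for w_rc, fk_c in zip(row, fk)]
             for row, v_r in zip(w, v)]
        # z^(i) = z^(i-1) + phi(k)
        z = [z_c + fk_c for z_c, fk_c in zip(z, fk)]
        norm = dot(z, fq)
        outputs.append([dot(row, fq) / norm for row in w])
    return outputs


qs = [[0.1, -0.2], [0.5, 0.3], [-0.4, 0.9]]
ks = [[0.3, 0.7], [-0.6, 0.2], [0.8, -0.1]]
vs = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
a = causal_attention_quadratic(qs, ks, vs)
b = causal_attention_recurrent(qs, ks, vs)
max_diff = max(abs(p - r) for ra, rb in zip(a, b) for p, r in zip(ra, rb))
print(max_diff < 1e-9)  # True
```

The recurrent version does a constant amount of work per step, while the explicit version revisits every earlier key, which is where the linear versus quadratic cost in sequence length comes from.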

The paper introduces a new linear attention projection function $\color{lightgreen}{\phi}$, a new update rule for $\color{cyan}{W^{(i)}} = f(\color{cyan}{W^{(i-1)}})$, and changes the normalization $\frac{1}{z^{(i)} \cdot \color{lightgreen}{\phi(q^{(i)})}}$.

Here are the training code and a notebook for training a fast weights transformer on the Tiny Shakespeare dataset.


import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list

Deterministic Parameter Free Projection (DPFP)

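This is the new projection function $\color{lightgreen}{\phi}$ introduced in the paper. DPFP projects $k$ of dimensionality $d_{key}$ to dimensionality $d_{dot} = 2 d_{key} \nu$, where $\nu \in \{1, 2, …, 2 d_{key} - 1 \}$ is a hyper-parameter.

$$\color{lightgreen}{\phi_{2 d_{key} (i - 1) + j}(k)} = \text{ReLU}\Big(\big[k, -k\big]\Big)_{j} \text{ReLU}\Big(\big[k, -k\big]\Big)_{i + j}$$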

where $\big[k, -k\big]$ is the concatenation of $k$ and $-k$ to give a vector of size $2 d_{key}$, $i \in \{1, 2, …, \nu \}$, and $j \in \{1, 2, …, 2 d_{key}\}$. $x_i$ is the $i$-th element of vector $x$ and is rolled around if $i$ is larger than the number of elements in $x$.

Basically, it creates a new vector by multiplying the elements of $\text{ReLU}\big([k, -k]\big)$ with a copy of itself shifted by $i$.

This produces projections that are sparse (only a few elements of $\color{lightgreen}{\phi}$ are non-zero) and orthogonal ($\color{lightgreen}{\phi(k^{(i)})} \cdot \color{lightgreen}{\phi(k^{(j)})} \approx 0$ for most $i, j$, unless $k^{(i)}$ and $k^{(j)}$ are very similar).


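The paper introduces a simple normalization for $\color{lightgreen}{\phi}$,

$$\color{lightgreen}{\phi'(k)} = \frac{\color{lightgreen}{\phi(k)}}{\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}}$$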

Check the paper for derivation.
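A small worked example may help to see the sparsity and orthogonality. The sketch below is a plain Python version of the projection for a single key given as a list of floats; the function names `dpfp`, `dpfp_normalized` and `dot` are hypothetical, and the rolling direction follows the `x.roll(shifts=i, dims=-1)` call used in the implementation on this page.

```python
def dpfp(k, nu=1):
    """Unnormalized DPFP projection of a key given as a list of floats."""
    # x = ReLU([k, -k]), a vector of size 2 * d_key
    x = [max(v, 0.0) for v in list(k) + [-v for v in k]]
    projected = []
    for i in range(1, nu + 1):
        # Roll x by i positions, wrapping around at the end
        rolled = x[-i:] + x[:-i]
        # Element-wise product of x with its rolled copy
        projected.extend(a * b for a, b in zip(x, rolled))
    return projected  # size 2 * d_key * nu


def dpfp_normalized(k, nu=1, eps=1e-6):
    """DPFP divided by the sum of its elements (plus eps to avoid division by zero)."""
    projected = dpfp(k, nu)
    total = sum(projected) + eps
    return [p / total for p in projected]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


k1 = [1.0, -2.0]
k2 = [-1.0, 2.0]
print(dpfp(k1, nu=1))           # [2.0, 0.0, 0.0, 0.0]
print(dpfp(k2, nu=1))           # [0.0, 0.0, 2.0, 0.0]
print(dot(dpfp(k1), dpfp(k2)))  # 0.0
print(len(dpfp(k1, nu=3)))      # 12
```

Here $d_{key} = 2$, so $\nu = 1$ gives $d_{dot} = 4$ and $\nu = 3$ gives $d_{dot} = 12$. Each projection has a single non-zero element, and the two keys, which point in opposite directions, land on different coordinates.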

class DPFP(Module):

  • nu is the hyper-parameter $\nu$.
  • eps is the small value used to make sure there is no division-by-zero when normalizing.

    def __init__(self, nu: int = 1, eps: float = 1e-6):
        super().__init__()
        self.nu = nu
        self.relu = nn.ReLU()
        self.eps = eps

    def __call__(self, k: torch.Tensor):

Get $\color{lightgreen}{\phi(k)}$

        k = self.dpfp(k)

Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$

        return k / (torch.sum(k, dim=-1, keepdim=True) + self.eps)

    def dpfp(self, k: torch.Tensor):

$x = \text{ReLU}\Big(\big[k, -k\big]\Big)$

        x = self.relu(torch.cat([k, -k], dim=-1))

Shift and roll by $i \in \{1, 2, …, \nu \}$ to get $\nu$ rolled copies of $x$

        x_rolled = [x.roll(shifts=i, dims=-1) for i in range(1, self.nu + 1)]

Concatenate the rolled copies into a single vector of size $d_{dot} = 2 d_{key} \nu$

        x_rolled = torch.cat(x_rolled, dim=-1)

Concatenate $\nu$ copies of $x$

        x_repeat = torch.cat([x] * self.nu, dim=-1)

Multiply them element-wise to get $\color{lightgreen}{\phi(k)}$

        return x_repeat * x_rolled

Fast Weights Attention

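The paper introduces a new update rule for calculating $\color{cyan}{W^{(i)}}$. The model first retrieves the current value $\bar{v}^{(i)}$ paired with the key $k^{(i)}$. Then it stores a combination $v^{(i)}_{new}$ of the retrieved value $\bar{v}^{(i)}$ and the input $v^{(i)}$.

$$\begin{align}
\bar{v}^{(i)} &= \color{cyan}{W^{(i-1)}} \color{lightgreen}{\phi'(k^{(i)})} \\
\beta^{(i)} &= \sigma \Big(\color{orange}{W_\beta} x^{(i)} \Big) \\
v^{(i)}_{new} &= \beta^{(i)} v^{(i)} + \Big(1 - \beta^{(i)} \Big) \bar{v}^{(i)} \\
\color{cyan}{W^{(i)}} &= \color{cyan}{W^{(i-1)}} + v^{(i)}_{new} \otimes \color{lightgreen}{\phi'(k^{(i)})} - \bar{v}^{(i)} \otimes \color{lightgreen}{\phi'(k^{(i)})} \\
&= \color{cyan}{W^{(i-1)}} + \beta^{(i)} \Big( v^{(i)} - \bar{v}^{(i)} \Big) \otimes \color{lightgreen}{\phi'(k^{(i)})} \\
y^{(i)} &= \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}
\end{align}$$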

where $\color{orange}{W_\beta}$ is a trainable parameter and $\sigma$ is the sigmoid function.

Note that we don't need the normalization term $z$ because $\color{lightgreen}{\phi'}$ is normalized.
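The effect of retrieving before storing can be seen in a toy example. The sketch below uses plain Python lists and a one-hot projected key, which is an assumption made so that reading back with the same key returns exactly the stored value; the function names are hypothetical, and `additive_update` is the purely additive linear attention rule, included only for comparison.

```python
def matvec(w, x):
    """Multiply matrix w (a list of rows) by vector x."""
    return [sum(w_rc * x_c for w_rc, x_c in zip(row, x)) for row in w]


def fast_weight_update(w, key, value, beta):
    """Retrieve the value stored under `key`, then write back an interpolation."""
    v_bar = matvec(w, key)
    delta = [beta * (v - vb) for v, vb in zip(value, v_bar)]
    # W + beta * (v - v_bar) ⊗ key
    return [[w_rc + d * k_c for w_rc, k_c in zip(row, key)]
            for row, d in zip(w, delta)]


def additive_update(w, key, value):
    """Purely additive update W + v ⊗ key, for comparison."""
    return [[w_rc + v * k_c for w_rc, k_c in zip(row, key)]
            for row, v in zip(w, value)]


key = [1.0, 0.0, 0.0]  # one-hot projected key (an assumption for this demo)

w = [[0.0] * 3 for _ in range(2)]
w = fast_weight_update(w, key, [1.0, 2.0], beta=1.0)
w = fast_weight_update(w, key, [3.0, 4.0], beta=0.5)
print(matvec(w, key))  # [2.0, 3.0]

w_add = [[0.0] * 3 for _ in range(2)]
w_add = additive_update(w_add, key, [1.0, 2.0])
w_add = additive_update(w_add, key, [3.0, 4.0])
print(matvec(w_add, key))  # [4.0, 6.0]
```

With the new rule, writing `[3.0, 4.0]` with $\beta = 0.5$ under a key that already holds `[1.0, 2.0]` leaves the halfway point `[2.0, 3.0]`, whereas the additive rule accumulates both values under the same key.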

class FastWeightsAttention(Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()

Number of features per head $d_k$

        self.d_k = d_model // heads

Number of heads

        self.heads = heads

These transform the query, key and value for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

Interpolation weight function $\sigma \Big(\color{orange}{W_\beta} x^{(i)} \Big)$ for each head

        self.interpolation_weight = nn.Sequential(
            PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
            nn.Sigmoid()
        )

Projection function $\color{lightgreen}{\phi'}$

        self.phi = phi

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

    def __call__(self, x: torch.Tensor):

Get the number of steps $L$

        seq_len = x.shape[0]

$\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads

        query = self.phi(self.query(x))

$\color{lightgreen}{\phi'(k^{(i)})}$ for all steps and heads

        key = self.phi(self.key(x))

$v^{(i)}$ for all steps and heads

        value = self.value(x)

$\beta^{(i)}$ for all steps and heads

        beta = self.interpolation_weight(x)


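Initial fast weights $\color{cyan}{W^{(0)}}$, set to zero for every sample and head

        weights = key.new_zeros((key.shape[1], key.shape[2], value.shape[3], key.shape[3]))

List to store outputs $y^{(i)}$

        outputs = []

Iterate through steps

        for i in range(seq_len):

Retrieve the value $\bar{v}^{(i)} = \color{cyan}{W^{(i-1)}} \color{lightgreen}{\phi'(k^{(i)})}$ currently stored under the key

            value_existing = torch.einsum('bhvk,bhk->bhv', weights, key[i])

Update the fast weights, $\color{cyan}{W^{(i)}} = \color{cyan}{W^{(i-1)}} + \beta^{(i)} \Big( v^{(i)} - \bar{v}^{(i)} \Big) \otimes \color{lightgreen}{\phi'(k^{(i)})}$

            weights = weights + torch.einsum('bhv,bhk->bhvk', beta[i] * (value[i] - value_existing), key[i])

Compute the output $y^{(i)} = \color{cyan}{W^{(i)}} \color{lightgreen}{\phi'(q^{(i)})}$

            y = torch.einsum('bhvk,bhk->bhv', weights, query[i])

Merge multiple heads and append to outputs

            outputs.append(y.reshape(y.shape[0], -1))

Stack outputs at each step into a single tensor

        x = torch.stack(outputs)

Output layer

        return self.output(x)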

This is a general transformer layer that combines self-attention and a feed-forward network.

class FastWeightsAttentionTransformerLayer(Module):
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

Transformer size $d_{model}$

        self.size = d_model

Fast weights attention module

        self.attn = attn

Feed-forward network

        self.feed_forward = feed_forward

Dropout layer

        self.dropout = nn.Dropout(dropout_prob)

Normalization layers

        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def __call__(self, x: torch.Tensor):

Calculate fast weights self attention

        attn = self.attn(x)

Add the self attention results

        x = x + self.dropout(attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x

This is a general transformer module with multiple transformer layers

class FastWeightsAttentionTransformer(Module):
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

    def __call__(self, x: torch.Tensor):
        for i, layer in enumerate(self.layers):

Get layer output

            x = layer(x)

Normalize the output

        return self.norm(x)