Feedback Transformer

This is a PyTorch implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.

Normal transformers process tokens in parallel, and each transformer layer attends to the outputs of the previous layer. The Feedback Transformer instead attends to the outputs of all layers from previous steps. This adds recurrence, so tokens have to be processed one by one, which slows down training significantly (about 5X to 10X, depending on the sequence length). However, the Feedback Transformer is faster at inference time, because the memory vectors can be cached and reused when predicting the next token.

In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.

The original feedback transformer doesn't keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.

The updated feedback transformer shares the weights used to calculate keys and values among the layers. We then calculate the keys and values for each step only once and keep them cached. The second half of this file implements this. We implemented a custom PyTorch function to improve performance.

Here's the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
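As a rough illustration of the recurrence (a toy sketch with made-up dimensions, not the modules defined below): the memory for each step is a softmax-weighted sum of that step's layer outputs, and every layer of the next step attends over the memory of all previous steps.

import torch

# Toy dimensions, chosen only for illustration
seq_len, batch_size, d_model, n_layers = 4, 2, 8, 3
x_seq = torch.randn(seq_len, batch_size, d_model)
weights = torch.softmax(torch.ones(n_layers + 1), dim=0)

memory = []  # one memory vector per processed step
for x in x_seq:  # token-by-token; this is the recurrence that slows training down
    layer_outputs = [x]
    for _ in range(n_layers):
        # stand-in for a transformer layer that attends over `memory`
        x = x + (torch.stack(memory).mean(dim=0) if memory else 0)
        layer_outputs.append(x)
    # the memory vector is a weighted sum of all layer outputs of this step
    memory.append(torch.einsum('lbd,l->bd', torch.stack(layer_outputs), weights))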


import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list

Feedback Attention

This module computes recurrent attention, similar to the attention in the original transformer paper.

class FeedbackAttention(Module):
  • heads is the number of attention heads
  • d_model is the number of features in the transformer
  • dropout_prob is the attention dropout probability
  • is_kv_precomputed is whether key, value tensors are already calculated

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

        self.heads = heads

These transform the query for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

These transform the key and value for multi-headed attention.

        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Keys and values are already calculated

        else:
            self.key = None
            self.value = None

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

Softmax for attention along the time dimension of key

        self.softmax = nn.Softmax(dim=0)

Number of relative positions

        self.P = 2 ** 12

Relative positional embeddings for key relative to the query.

        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)

Relative positional embedding bias for key relative to the query.

        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)

The positional embedding for the query is independent of the position of the query.

        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

We store the attention weights so that they can be used for logging or other computations if needed.

        self.attn = None
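For example (a quick sanity-check sketch; the hyper-parameter values here are arbitrary), with heads=8 and d_model=512 each head gets d_k = 64 features, and the relative positional tables cover P = 4096 positions:

attn = FeedbackAttention(heads=8, d_model=512)
assert attn.d_k == 64
assert attn.key_pos_embeddings.shape == (4096, 8, 64)
assert attn.key_pos_bias.shape == (4096, 8)
assert attn.query_pos_bias.shape == (8, 64)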

Get attention scores

We use relative positional encodings for attention, similar to relative multi-head attention from the Transformer-XL paper.

Attention from the current step's query to the key at step $j$ (relative to the current step) is,

\begin{align}
A_{j} &= Q^\top K_j \\
      &= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
      &= (Q + U^Q)^\top(K_j + U^K_j)
\end{align}

where $Q$, $K_j$ are linear transformations of the original embeddings $X^q$, $X^k_j$, and $U^Q$, $U^K_j$ are linear transformations of the positional encodings $P_q$, $P_j$.

We replace the term $U^Q$ with a learned bias $v$ (query_pos_bias), since the query is always at the current step and its position does not change.

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

        # Relative positional embeddings for the keys
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]

        # Positional bias for the query (independent of the query position)
        query_pos_bias = self.query_pos_bias[None, :, :]

        # Relative positional bias for the keys
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]

        # Content-based scores
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)

        # Position-based scores
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

        # Sum the two components
        return ac + bd
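A small shape sketch for get_scores (the tensors here are already in multi-head form, as produced by PrepareForMultiHeadAttention; the values are random and only the shapes matter):

attn = FeedbackAttention(heads=4, d_model=64)
query = torch.randn(2, 4, 16)       # [batch_size, heads, d_k]
key = torch.randn(10, 2, 4, 16)     # [seq_len, batch_size, heads, d_k]
scores = attn.get_scores(query, key)
assert scores.shape == (10, 2, 4)   # [seq_len, batch_size, heads]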
  • query has shape [batch_size, d_model]
  • key and value have shape [seq_len, batch_size, d_model]

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor):

Prepare query, key and value for the attention computation. key and value will then have shape [seq_len, batch_size, heads, d_k] and query will have shape [batch_size, heads, d_k].

        query = self.query(query)
        if self.key:
            key = self.key(key)
        if self.value:
            value = self.value(value)

Compute attention scores. Results in a tensor of shape [seq_len, batch_size, heads]

        scores = self.get_scores(query, key)

Scale scores

        scores *= self.scale

Softmax

        attn = self.softmax(scores)

Apply dropout

        attn = self.dropout(attn)

Multiply by the values

        x = torch.einsum("jbh,jbhd->bhd", attn, value)

Concatenate multiple heads

        x = x.reshape(x.shape[0], -1)

Output layer

        return self.output(x)
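Usage sketch (the sizes are arbitrary): a single query step attends over a memory of 10 previous steps and produces one output vector per batch element.

attn = FeedbackAttention(heads=4, d_model=64)
query = torch.randn(2, 64)          # [batch_size, d_model], the current step
key = torch.randn(10, 2, 64)        # [seq_len, batch_size, d_model], memory so far
value = torch.randn(10, 2, 64)
out = attn(query=query, key=key, value=value)
assert out.shape == (2, 64)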

Feedback Transformer Layer

This implements a single transformer layer in the feedback transformer.

class FeedbackTransformerLayer(Module):
  • d_model is the number of features in the transformer
  • attn is the feedback attention module
  • feed_forward is the position-wise feed-forward layer
  • dropout_prob is the dropout probability for dropout layers after attention and feed-forward

    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

Transformer size

        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

Normalization layers

        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                key: Optional[torch.Tensor],
                value: Optional[torch.Tensor]):

If there is memory

        if key is not None:

Normalize the vectors before doing self attention

            z = self.norm_self_attn(x)

Run through attention, where the keys and values come from the memory of previous steps

            self_attn = self.attn(query=z, key=key, value=value)

Add the attention results

            x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x
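A usage sketch for a single layer (the FeedForward(d_model, d_ff) call assumes the labml_nn.transformers.feed_forward signature imported above; the sizes are arbitrary):

layer = FeedbackTransformerLayer(d_model=64,
                                 attn=FeedbackAttention(heads=4, d_model=64),
                                 feed_forward=FeedForward(64, 256),
                                 dropout_prob=0.1)
x = torch.randn(2, 64)              # current step, [batch_size, d_model]
mem = torch.randn(5, 2, 64)         # memory of 5 previous steps
out = layer(x=x, key=mem, value=mem)
assert out.shape == (2, 64)

# On the very first step there is no memory yet, so key and value are None
# and only the feed-forward path runs
out0 = layer(x=x, key=None, value=None)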

Feedback Transformer Module

class FeedbackTransformer(Module):
  • layer is the feedback transformer layer, which we clone for each layer
  • n_layers is the number of layers in the transformer

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.

        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax for weights before taking the weighted sum

        self.softmax = nn.Softmax(0)
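There are n_layers + 1 weights because the input embedding (the "layer 0" output) is included in the weighted sum alongside the n_layers layer outputs. A minimal sketch of that sum (arbitrary sizes):

n_layers, batch_size, d_model = 4, 2, 64
layer_outputs = torch.stack([torch.randn(batch_size, d_model) for _ in range(n_layers + 1)])
w = torch.softmax(torch.ones(n_layers + 1), dim=0)
mem = torch.einsum('lbd,l->bd', layer_outputs, w)   # [batch_size, d_model]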
  • x_seq is the input with shape [seq_len, batch_size, d_model]
    def forward(self, x_seq: torch.Tensor):

Split the input to a list along the sequence axis

        x_seq = torch.unbind(x_seq, dim=0)

List to store the outputs

        res = []

List to store the memory vectors

        mem = []

For each input step

        for x in x_seq:

List to store layer outputs

            layer_outputs = [x]

If there is memory, stack the memory vectors into a tensor

            mem_tensor = torch.stack(mem) if mem else None

Run through each layer

            for layer in self.layers:

Get layer output

                x = layer(x=x, key=mem_tensor, value=mem_tensor)

Append it to the list of layer outputs

                layer_outputs.append(x)

Stack the layer outputs into a tensor

            layer_outputs = torch.stack(layer_outputs)

Calculate the memory vector as a weighted sum of layer outputs

            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))

Append the output to results

            res.append(x)

Stack the output tensors

        res = torch.stack(res)

Normalize the output

        return self.norm(res)
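Putting the first-half model together (a sketch; the hyper-parameters are arbitrary and FeedForward follows the signature imported above):

attn = FeedbackAttention(heads=4, d_model=64)
layer = FeedbackTransformerLayer(d_model=64, attn=attn,
                                 feed_forward=FeedForward(64, 256),
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers=6)

x_seq = torch.randn(16, 2, 64)      # [seq_len, batch_size, d_model]
out = model(x_seq)
assert out.shape == (16, 2, 64)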

Shared keys and values among layers

Stack Function implementation

We implement a custom autograd function instead of appending to a Python list and then doing torch.stack. This greatly improves performance compared to calling torch.stack at each step along the sequence. Every time torch.stack is called, it creates a new tensor, while this function and the accompanying Stack class reuse a single shared memory tensor across steps.

class StackFunction(torch.autograd.Function):
  • ctx is the context of the function (which lets us cache stuff)
  • memory is the shared memory tensor where we stack and store the values of each step (keys & values)
  • memory_grad is the shared memory tensor to store and accumulate gradients of each step
  • last is the last value stacked
  • n is the number of steps (i.e. size of the stack)

This returns the stacked tensor for steps up to n.

    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):

Cache accumulated gradients

        ctx._mem_grad = memory_grad

Cache the size of the stack

        ctx._n = n

Return the stack

        return memory[:n + 1]
  • grad_output is the gradient with respect to the output of the forward function above

This accumulates the gradients in the shared memory tensor and returns the gradients with respect to the last result in the stack.

    @staticmethod
    def backward(ctx, grad_output):

Get the current size of the stack

        n = ctx._n

Get the accumulated gradients

        memory_grad = ctx._mem_grad

Add the gradients

        memory_grad[:n + 1] += grad_output

Return the gradients w.r.t. the last value in the stack

        return None, None, memory_grad[n], None
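A small sketch of this gradient flow (tensor sizes are arbitrary): after backward, the upstream gradient is accumulated into memory_grad for every entry, while only last, which is still connected to the autograd graph, receives a gradient directly.

memory = torch.zeros(4, 2, 3)
memory_grad = torch.zeros(4, 2, 3)
last = torch.randn(2, 3, requires_grad=True)
memory[2] = last.detach()                                 # pretend steps 0..2 were appended

out = StackFunction.apply(memory, memory_grad, last, 2)   # stacked steps 0..2
out.sum().backward()

assert torch.allclose(last.grad, torch.ones(2, 3))        # gradient reaches the last entry
assert torch.allclose(memory_grad[:3], torch.ones(3, 2, 3))  # accumulated for all entries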

Stack Module

This uses the stack function defined above, and does the necessary initializations.

class Stack:
  • max_len is the maximum size of the stack

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1
  • n is the size of the stack
  • value is the tensor that needs to be added to the stack

    def append(self, n: int, value: torch.Tensor):

You need to get (use) the stack after adding a value. Otherwise this implementation fails.

        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"

Do this without gradients

        with torch.no_grad():

Initialize the shared memory tensor to keep the stack

            if self.memory is None or self.memory.shape[1:] != value.shape:

This should only happen when the stack is empty

                assert n == 0

Create a tensor for the stack

                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)

Create a tensor to accumulate the gradients

                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)

The memory is already initialized but we are resetting the stack.

This could have been another function like reset, but we found this easier to use.

            elif n == 0:

Reset accumulated gradients

                self.memory_grad.fill_(0.)

Set the value in the correct position of the stack

            self.memory.data[n] = value.detach()

Keep track of the size of the stack (for debugging)

            self.n = n

Keep track of the last value added to the stack. We need this to be passed on to StackFunction in order to get the gradients propagated backwards.

        self.last = value

Returns the stack

    def get(self):

Keep track of the size of the stack when it was used. This is used for a sanity check in append.

        self.last_get_n = self.n

Take it all through StackFunction so that StackFunction.backward is called by PyTorch during backpropagation.

        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)

To release memory

    def free(self):
        self.memory = None
        self.memory_grad = None
        self.last = None
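Usage sketch for Stack (arbitrary sizes): values appended step by step share one pre-allocated buffer, get() must be called after every append, and free() releases the buffers when the sequence is done.

stack = Stack(max_len=8)
for step in range(3):
    v = torch.randn(2, 64, requires_grad=True)
    stack.append(step, v)
    stacked = stack.get()                  # view of steps 0..step via StackFunction
    assert stacked.shape == (step + 1, 2, 64)
stack.free()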

Updated Feedback Transformer Module

This is the updated feedback transformer module that caches the keys and values.

class FeedbackTransformerKV(Module):
  • layer is the feedback transformer layer, which we clone for each layer
  • n_layers is the number of layers in the transformer
  • d_model is the number of features in the transformer
  • heads is the number of attention heads

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.

        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax for weights before taking the weighted sum

        self.softmax = nn.Softmax(0)

Number of features in a head

        d_k = d_model // heads

Module to transform embeddings (memory) to get keys

        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

Module to transform embeddings (memory) to get values

        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

Memory for stacked keys

        self.mem_key = Stack(512)

Memory for stacked values

        self.mem_value = Stack(512)
  • x_seq is the input with shape [seq_len, batch_size, d_model]
    def forward(self, x_seq: torch.Tensor):

Split the input to a list along the sequence axis

        x_seq = torch.unbind(x_seq, dim=0)

List to store the outputs

        res = []

For each input step

        for step, x in enumerate(x_seq):

List to store layer outputs

            layer_outputs = [x]

Stack of keys and values

            key_tensor = None
            value_tensor = None

Get the keys and values tensors if we are beyond the initial step

            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()

Run through each layer

            for layer in self.layers:

Get layer output

                x = layer(x=x, key=key_tensor, value=value_tensor)

Append it to the list of layer outputs

                layer_outputs.append(x)

Stack the layer outputs into a tensor

            layer_outputs = torch.stack(layer_outputs)

Calculate the memory vector as a weighted sum of layer outputs

            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))

Calculate the keys from memory and add them to the stack

            self.mem_key.append(step, self.key(mem))

Calculate the values from memory and add them to the stack

            self.mem_value.append(step, self.value(mem))

Append the output to results

            res.append(x)

Stack the output tensors

        res = torch.stack(res)

Normalize the output

        return self.norm(res)

To release memory

    def free(self):
        self.mem_key.free()
        self.mem_value.free()
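Finally, a usage sketch for the shared key/value version (arbitrary hyper-parameters): the attention module is built with is_kv_precomputed=True, since FeedbackTransformerKV computes the keys and values from the memory itself.

d_model, heads = 64, 4
attn = FeedbackAttention(heads, d_model, is_kv_precomputed=True)
layer = FeedbackTransformerLayer(d_model=d_model, attn=attn,
                                 feed_forward=FeedForward(d_model, 256),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers=6, d_model=d_model, heads=heads)

x_seq = torch.randn(16, 2, d_model)   # [seq_len, batch_size, d_model]
out = model(x_seq)                    # [16, 2, 64]
model.free()                          # release the cached key/value stacks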