Feedback Transformer

This is a PyTorch implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.

Normal transformers process tokens in parallel. Each transformer layer pays attention to the outputs of the previous layer. The feedback transformer instead pays attention to the outputs of all layers in previous steps. This adds recurrence, so tokens must be processed one at a time, which slows down training significantly (about 5X - 10X depending on the sequence length). However, the feedback transformer is faster at prediction time, because you can predict the next token from the cached memory vectors.

In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.
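Growing the sequence length during training can be sketched as a simple schedule. The linear ramp, the function name seq_len_at and every number below are assumptions made for illustration; they are not taken from the paper or from this implementation.

```python
def seq_len_at(step: int, total_steps: int, start_len: int = 32, max_len: int = 512) -> int:
    """Hypothetical schedule: ramp linearly from start_len to max_len over the first half of training."""
    ramp_steps = max(1, total_steps // 2)
    frac = min(1.0, step / ramp_steps)
    return int(start_len + frac * (max_len - start_len))


# Sequence lengths at a few points of a 1000-step run
schedule = [seq_len_at(s, total_steps=1000) for s in (0, 250, 500, 1000)]
print(schedule)
```

Short sequences keep the token-by-token loop cheap early on, and the full length is only reached once the ramp ends.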

The original feedback transformer doesn't keep the outputs of all layers. Instead it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.
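As a minimal sketch of that weighted sum, the snippet below mixes per-layer outputs with softmax-normalized weights, using the same einsum pattern as the module further down. The helper name memory_vector and the tensor sizes are assumptions for illustration.

```python
import torch


def memory_vector(layer_outputs: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """layer_outputs: [n_layers + 1, batch_size, d_model]; weights: [n_layers + 1]."""
    return torch.einsum('lbd,l->bd', layer_outputs, torch.softmax(weights, dim=0))


torch.manual_seed(0)
layer_outputs = torch.randn(4, 2, 8)  # input embedding plus three layer outputs
weights = torch.ones(4)               # equal weights, as at initialization

mem = memory_vector(layer_outputs, weights)
print(mem.shape)
```

With equal weights the softmax is uniform, so the memory vector is just the mean over layers. Whatever the number of layers, one vector of size d_model is stored per step.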

The updated feedback transformer shares the weights $W^l_k$ and $W^l_v$ used to calculate keys and values among the layers. We then calculate the keys and values for each step only once and keep them cached. The second half of this file implements this. We implemented a custom PyTorch function to improve performance.
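The benefit of a shared projection can be illustrated with a small sketch: projecting each memory vector once, when it is created, gives the same keys as projecting the whole memory again at every step. The nn.Linear stand-in and all sizes are assumptions; the module below uses PrepareForMultiHeadAttention instead.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, batch_size, steps = 8, 2, 4

# One key projection shared by every layer (values work the same way)
key_proj = nn.Linear(d_model, d_model, bias=False)

# Stand-ins for the memory vectors produced at each step
memory = [torch.randn(batch_size, d_model) for _ in range(steps)]

# Without caching: project the whole memory whenever it is needed
recomputed = key_proj(torch.stack(memory))

# With caching: project each memory vector once and keep the result
cached = torch.stack([key_proj(m) for m in memory])

same = torch.allclose(recomputed, cached, atol=1e-6)
print(same, cached.shape)
```

If every layer had its own projection, the cache would have to hold one set of keys and values per layer, so sharing the weights is what makes a single cache sufficient.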

Here's the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.


import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list

Feedback Attention

This module computes recurrent attention similar to attention from the original transformer paper.

class FeedbackAttention(Module):
  • heads is the number of attention heads
  • d_model is the number of features in the transformer
  • dropout_prob is the attention dropout probability
  • is_kv_precomputed is whether key, value tensors are already calculated
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
        super().__init__()

Number of features per head

        self.d_k = d_model // heads

        self.heads = heads

This transforms the query for multi-headed attention.

        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

These transform the key and value for multi-headed attention.

        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

Keys and values are already calculated

        else:
            self.key = None
            self.value = None

Output layer

        self.output = nn.Linear(d_model, d_model)

Dropout

        self.dropout = nn.Dropout(dropout_prob)

Scaling factor before the softmax

        self.scale = 1 / math.sqrt(self.d_k)

Softmax for attention along the time dimension of key

        self.softmax = nn.Softmax(dim=0)

Number of relative positions

        self.P = 2 ** 12

Relative positional embeddings for key relative to the query.

        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)

Relative positional embedding bias for key relative to the query.

        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)

Positional embeddings for the query are independent of the position of the query

        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)

We store the attentions so that they can be used for logging, or other computations if needed

        self.attn = None

Get attention scores

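We use relative positional encodings for attention, similar to relative multi-head attention from the Transformer-XL paper.

Attention from current step's query to key in step $j$ (relative to current step) is,

\begin{align}
A_j &= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
    &= (Q + U^Q)^\top (K_j + U^K_j) \\
    &= \underset{A}{Q^\top K_j} + \underset{B}{Q^\top U^K_j} + \underset{C}{{U^Q}^\top K_j} + \underset{D}{{U^Q}^\top U^K_j}
\end{align}

where $Q, K_j$ are linear transformations of original embeddings $X^q, X^k_j$ and $U^Q, U^K_j$ are linear transformations of positional encodings $P_q, P_j$.

We replace term $D$ with $S_j$.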
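The score computation can be checked numerically against a term-by-term sum. In the sketch below, u_q, u_k and s are assumed names for the query positional bias, the key positional embeddings and the key positional bias; the sizes are arbitrary and the einsum patterns mirror get_scores.

```python
import torch

torch.manual_seed(0)
seq_len, batch_size, heads, d_k = 5, 2, 3, 4

q = torch.randn(batch_size, heads, d_k)           # query of the current step
k = torch.randn(seq_len, batch_size, heads, d_k)  # keys from memory
u_q = torch.randn(heads, d_k)                     # query positional bias
u_k = torch.randn(seq_len, heads, d_k)            # key positional embeddings
s = torch.randn(seq_len, heads)                   # key positional bias

# Grouped form: two einsum calls, as in get_scores
ac = torch.einsum('bhd,jbhd->jbh', q + u_q[None, :, :], k)
bd = torch.einsum('bhd,jhd->jbh', q, u_k) + s[:, None, :]
scores = ac + bd

# Expanded form: content-content, content-position, position-content, bias
a = torch.einsum('bhd,jbhd->jbh', q, k)
b = torch.einsum('bhd,jhd->jbh', q, u_k)
c = torch.einsum('hd,jbhd->jbh', u_q, k)
expected = a + b + c + s[:, None, :]

print(torch.allclose(scores, expected, atol=1e-5))
```

Grouping the terms this way needs only two einsum calls over the memory instead of three.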

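    def get_scores(self, query: torch.Tensor, key: torch.Tensor):

Positional embeddings of the keys, one row per step in memory

        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]

Positional bias of the query, broadcast over the batch

        query_pos_bias = self.query_pos_bias[None, :, :]

Positional bias of the keys

        key_pos_bias = self.key_pos_bias[-key.shape[0]:]

Scores of the query, plus its positional bias, against the keys

        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)

Scores of the query against the key positional embeddings, plus the key positional bias

        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]

Total attention scores

        return ac + bd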
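The negative slicing in get_scores picks the last rows of the positional tables, so the most recent memory entry always uses the same row no matter how long the memory is. A small sketch, with P = 8 as an assumed stand-in for the 2 ** 12 rows used in the module:

```python
import torch

P = 8                            # assumed small table size for illustration
position_rows = torch.arange(P)  # row indices standing in for embedding rows

# Rows selected for memories of different lengths
picked = {mem_len: position_rows[-mem_len:].tolist() for mem_len in (1, 3, 5)}
print(picked)
```

Row P - 1 always belongs to the newest key, row P - 2 to the one before it, and so on. A memory longer than P steps would not line up with the table, so P bounds the usable memory length.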
  • query has shape [batch_size, d_model]
  • key and value have shape [seq_len, batch_size, d_model]
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor):

Prepare query, key and value for attention computation. key and value will then have shape [seq_len, batch_size, heads, d_k] and query will have shape [batch_size, heads, d_k].

        query = self.query(query)
        if self.key is not None:
            key = self.key(key)
        if self.value is not None:
            value = self.value(value)

Compute attention scores. Results in a tensor of shape [seq_len, batch_size, heads]

        scores = self.get_scores(query, key)

Scale scores

        scores *= self.scale

Softmax

        attn = self.softmax(scores)

Apply dropout

        attn = self.dropout(attn)

Multiply by the values

        x = torch.einsum("jbh,jbhd->bhd", attn, value)

Concatenate multiple heads

        x = x.reshape(x.shape[0], -1)

Output layer

        return self.output(x)
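The shapes in the forward pass can be traced with a stripped-down sketch. It assumes already-projected query, key and value tensors of arbitrary size, and it leaves out the positional terms, dropout and output layer.

```python
import math

import torch

torch.manual_seed(0)
seq_len, batch_size, heads, d_k = 6, 2, 4, 8

query = torch.randn(batch_size, heads, d_k)           # a single step
key = torch.randn(seq_len, batch_size, heads, d_k)    # memory keys
value = torch.randn(seq_len, batch_size, heads, d_k)  # memory values

# One score per memory step, batch element and head
scores = torch.einsum('bhd,jbhd->jbh', query, key) / math.sqrt(d_k)

# Normalize over the time dimension of the memory (dim 0)
attn = torch.softmax(scores, dim=0)

# Weighted sum of values, then concatenate the heads
x = torch.einsum('jbh,jbhd->bhd', attn, value)
x = x.reshape(batch_size, -1)
print(x.shape)
```

Because the query covers a single step, there is no query-time axis, which is why the softmax runs over dim 0.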

Feedback Transformer Layer

This implements a single transformer layer in the feedback transformer.

class FeedbackTransformerLayer(Module):
  • d_model is the number of features in the transformer
  • attn is the feedback attention module
  • feed_forward is the position-wise feed forward layer
  • dropout_prob is the dropout probability for dropout layers after attention and feed-forward
    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

Transformer size

        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

Normalization layers

        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def forward(self, *,
                x: torch.Tensor,
                key: Optional[torch.Tensor],
                value: Optional[torch.Tensor]):

If there is memory

        if key is not None:

Normalize the vectors before doing attention

            z = self.norm_self_attn(x)

Run through attention, with keys and values taken from the memory

            self_attn = self.attn(query=z, key=key, value=value)

Add the attention results

            x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the feed-forward network

        ff = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x

Feedback Transformer Module

class FeedbackTransformer(Module):
  • layer is the feedback transformer layer, which we clone for each layer
  • n_layers is the number of layers in the transformer
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.

        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax for weights before taking the weighted sum

        self.softmax = nn.Softmax(0)
  • x_seq is the input with shape [seq_len, batch_size, d_model]
    def forward(self, x_seq: torch.Tensor):

Split the input to a list along the sequence axis

        x_seq = torch.unbind(x_seq, dim=0)

List to store the outputs

        res = []

List to store the memory vectors

        mem = []

For each input step

        for x in x_seq:

List to store layer outputs

            layer_outputs = [x]

If there is memory, stack the memory vectors into a tensor

            mem_tensor = torch.stack(mem) if mem else None

Run through each layer

            for layer in self.layers:

Get layer output

                x = layer(x=x, key=mem_tensor, value=mem_tensor)

Append it to the list of layer outputs

                layer_outputs.append(x)

Stack the layer outputs to a tensor

            layer_outputs = torch.stack(layer_outputs)

Calculate the memory vector as a weighted sum of layer outputs

            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))

Append the output to results

            res.append(x)

Stack the output tensors

        res = torch.stack(res)

Normalize the output

        return self.norm(res)
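The recurrence can be tried out in a self-contained toy model. This is a sketch under stated assumptions, not the module above: it swaps in torch.nn.MultiheadAttention and a single linear layer for FeedbackAttention and FeedForward, drops layer normalization and relative positions, and the class name TinyFeedback is hypothetical.

```python
import torch
from torch import nn


class TinyFeedback(nn.Module):
    """Toy feedback loop: every layer attends to one shared memory of past steps."""

    def __init__(self, d_model: int = 16, heads: int = 2, n_layers: int = 2):
        super().__init__()
        self.attn = nn.ModuleList([nn.MultiheadAttention(d_model, heads) for _ in range(n_layers)])
        self.ff = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
        self.weights = nn.Parameter(torch.ones(n_layers + 1))

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        mem, res = [], []
        for x in torch.unbind(x_seq, dim=0):  # one step at a time
            layer_outputs = [x]
            mem_tensor = torch.stack(mem) if mem else None
            for attn, ff in zip(self.attn, self.ff):
                if mem_tensor is not None:
                    # Query is the current step; keys and values are the memory
                    a, _ = attn(x[None], mem_tensor, mem_tensor)
                    x = x + a[0]
                x = x + torch.relu(ff(x))
                layer_outputs.append(x)
            # Memory vector: softmax-weighted sum over all layer outputs
            w = torch.softmax(self.weights, dim=0)
            mem.append(torch.einsum('lbd,l->bd', torch.stack(layer_outputs), w))
            res.append(x)
        return torch.stack(res)


torch.manual_seed(0)
model = TinyFeedback()
x = torch.randn(5, 3, 16)  # [seq_len, batch_size, d_model]
out = model(x)

# Changing the last input step must leave all earlier outputs unchanged
x_changed = x.clone()
x_changed[-1] += 1.0
out_changed = model(x_changed)
print(out.shape)
```

No attention mask is needed: the memory only ever contains earlier steps, so each output depends on the current and previous inputs only.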

Shared keys and values among layers

Stack Function implementation

We implement a custom function instead of appending to a Python list and then calling torch.stack. This greatly improves performance over calling torch.stack at each step along the sequence. Every time torch.stack is called, it creates a new tensor, while this method and the accompanying class Stack share memory for each step.
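The memory-sharing claim can be illustrated without autograd. In this sketch, with assumed sizes, values are written into a preallocated buffer, and slicing that buffer gives the same contents as torch.stack without allocating anything new.

```python
import torch

torch.manual_seed(0)
max_len, batch_size, d_model = 4, 2, 3
values = [torch.randn(batch_size, d_model) for _ in range(max_len)]

# Preallocated buffer that holds the whole stack
memory = torch.zeros(max_len, batch_size, d_model)

shares_storage = []
matches_stack = []
for n, value in enumerate(values):
    memory[n] = value                      # write the new step in place
    view = memory[:n + 1]                  # slicing returns a view of the buffer
    stacked = torch.stack(values[:n + 1])  # torch.stack allocates a new tensor
    shares_storage.append(view.data_ptr() == memory.data_ptr())
    matches_stack.append(torch.equal(view, stacked))

stack_copies = stacked.data_ptr() != memory.data_ptr()
print(all(shares_storage), all(matches_stack), stack_copies)
```

A plain slice like this does not carry gradients back to the individual values that were written in, which is the gap the custom autograd function fills.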

class StackFunction(torch.autograd.Function):
  • ctx is the context of the function (which lets us cache stuff)
  • memory is the shared memory tensor where we stack and store the values of each step (keys & values)
  • memory_grad is the shared memory tensor to store and accumulate gradients of each step
  • last is the last value stacked
  • n is the index of the last value stacked (the stack then holds n + 1 values)

This returns the stacked tensor for steps up to n.

    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):

Cache accumulated gradients

        ctx._mem_grad = memory_grad

Cache the index of the last value in the stack

        ctx._n = n

Return the stack

        return memory[:n + 1]
  • grad_output is the gradient with respect to the output of the forward function above

This accumulates the gradients in the shared memory tensor and returns the gradients with respect to the last value in the stack.

    @staticmethod
    def backward(ctx, grad_output):

Get the index of the last value in the stack

        n = ctx._n

Get the accumulated gradients

        memory_grad = ctx._mem_grad

Add the gradients

        memory_grad[:n + 1] += grad_output

Return the gradients w.r.t. the last value in the stack

        return None, None, memory_grad[n], None
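The accumulation in backward can be traced by hand with plain tensors. The sketch below does not use autograd; it assumes a stack of three values and made-up upstream gradients, and it replays the backward calls in the order backpropagation would reach them, largest stack first.

```python
import torch

max_len, d_model = 4, 3
memory_grad = torch.zeros(max_len, d_model)

# Hypothetical upstream gradients for stacks that ended at index 2, 1 and 0
grad_outputs = {n: torch.full((n + 1, d_model), float(n + 1)) for n in (2, 1, 0)}

returned = {}
for n in (2, 1, 0):  # later steps are backpropagated first
    memory_grad[:n + 1] += grad_outputs[n]  # same update as in backward
    returned[n] = memory_grad[n].clone()    # gradient handed to the value at index n

print({n: g[0].item() for n, g in returned.items()})
```

The value at index 0 is part of all three stacks, so it receives 3 + 2 + 1 = 6, while the value at index 2 only appears in the largest stack and receives 3.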

Stack Module

This uses the stack function defined above, and does the necessary initializations.

class Stack:
  • max_len is the maximum size of the stack
    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1
  • n is the index in the stack at which the value is stored
  • value is the tensor that needs to be added to the stack
    def append(self, n: int, value: torch.Tensor):

You need to get (use) the stack after adding each value, before adding the next one. Otherwise this implementation fails.

        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"

Do this without gradients

        with torch.no_grad():

Initialize the shared memory tensor to keep the stack

            if self.memory is None or self.memory.shape[1:] != value.shape:

This should only happen when the stack is empty

                assert n == 0

Create a tensor for the stack

                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)

Create a tensor to accumulate the gradients

                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)

The memory is already initialized but we are resetting the stack.

This could have been another function like reset, but we found this easier to use.

            elif n == 0:

Reset accumulated gradients

                self.memory_grad.fill_(0.)

Set the value in the correct position of the stack

            self.memory.data[n] = value.detach()

Keep track of the index of the last value added to the stack

            self.n = n

Keep track of the last value added to the stack. We need this to be passed on to StackFunction in order to get the gradients propagated backwards.

        self.last = value

Returns the stack

    def get(self):

Keep track of the size of the stack when it was used. This is used for a sanity check in append.

        self.last_get_n = self.n

Take it all through StackFunction so that StackFunction.backward is called by PyTorch during backpropagation.

        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)

To release memory

    def free(self):
        self.memory = None
        self.memory_grad = None
        self.last = None

Updated Feedback Transformer Module

This is the updated feedback transformer module that caches the keys and values.

class FeedbackTransformerKV(Module):
  • layer is the feedback transformer layer, which we clone for each layer
  • n_layers is the number of layers in the transformer
  • d_model is the number of features in the transformer
  • heads is the number of attention heads
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

Memory vectors are computed as a weighted sum of representations of each layer. This is the weights parameter for that.

        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)

Softmax for weights before taking the weighted sum

        self.softmax = nn.Softmax(0)

Number of features in a head

        d_k = d_model // heads

Module to transform embeddings (memory) to get keys

        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

Module to transform embeddings (memory) to get values

        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)

Memory for stacked keys

        self.mem_key = Stack(512)

Memory for stacked values

        self.mem_value = Stack(512)
  • x_seq is the input with shape [seq_len, batch_size, d_model]
    def forward(self, x_seq: torch.Tensor):

Split the input to a list along the sequence axis

        x_seq = torch.unbind(x_seq, dim=0)

List to store the outputs

        res = []

For each input step

        for step, x in enumerate(x_seq):

List to store layer outputs

            layer_outputs = [x]

Stack of keys and values

            key_tensor = None
            value_tensor = None

Get the keys and values tensors if we are beyond the initial step

            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()

Run through each layer

            for layer in self.layers:

Get layer output

                x = layer(x=x, key=key_tensor, value=value_tensor)

Append it to the list of layer outputs

                layer_outputs.append(x)

Stack the layer outputs to a tensor

            layer_outputs = torch.stack(layer_outputs)

Calculate the memory vector as a weighted sum of layer outputs

            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))

Calculate the keys from memory and add them to the stack

            self.mem_key.append(step, self.key(mem))

Calculate the values from memory and add them to the stack

            self.mem_value.append(step, self.value(mem))

Append the output to results

            res.append(x)

Stack the output tensors

        res = torch.stack(res)

Normalize the output

        return self.norm(res)

To release memory

    def free(self):
        self.mem_key.free()
        self.mem_value.free()