This is a PyTorch implementation of the paper Accessing Higher-level Representations in Sequential Transformers with Feedback Memory.
Normal transformers process tokens in parallel, and each transformer layer attends to the outputs of the previous layer. The Feedback Transformer instead attends to the outputs of all layers at previous steps. This adds recurrence, so tokens have to be processed one by one, which slows down training significantly (about 5X to 10X depending on the sequence length). However, the Feedback Transformer is faster at prediction time, because the next token can be predicted incrementally if the memory vectors are cached.
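As a rough sketch of the difference (notation ours, not from the paper), with $x^l_t$ the output of layer $l$ at step $t$ and $m_t$ the memory vector at step $t$:

$$x^l_t = \text{Attn}\big(x^{l-1}_t,\; x^{l-1}_{\le t}\big) \quad \text{(standard transformer)}$$

$$x^l_t = \text{Attn}\big(x^{l-1}_t,\; m_{<t}\big) \quad \text{(Feedback Transformer, as implemented below)}$$

Since $m_{t-1}$ depends on every layer at step $t-1$, no layer at step $t$ can start before the whole of step $t-1$ has finished; hence the token-by-token processing.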
In order to speed up training, the paper discusses starting with a short sequence length and gradually increasing it. It also discusses using a pretrained parallel transformer as the starting point.
The original feedback transformer doesn't keep the outputs of all layers. Instead, it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.
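Concretely (this matches the weights parameter and the einsum in the code below), the memory at step $t$ is

$$m_t = \sum_{l=0}^{L} \text{softmax}(w)_l \, x^l_t$$

where $w \in \mathbb{R}^{L+1}$ is a learned weight vector and $x^0_t$ is the token embedding. Only $m_t$, not the individual layer outputs, needs to be cached during prediction.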
The updated feedback transformer shares the weights used to calculate keys and values across the layers. We then calculate the keys and values for each step only once and keep them cached. The second half of this file implements this. We implemented a custom PyTorch function to improve performance.
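In sketch form (notation ours): with shared projections $W_K$ and $W_V$, each step contributes a single key and a single value,

$$k_t = W_K m_t, \qquad v_t = W_V m_t,$$

which every layer reuses, instead of each layer projecting the memory separately.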
Here's the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.
import math
from typing import Optional

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
This module computes recurrent attention, similar to the attention in the original transformer paper.
class FeedbackAttention(Module):

* heads is the number of attention heads
* d_model is the number of features in the transformer
* dropout_prob is the attention dropout probability
* is_kv_precomputed is whether the key and value tensors are already calculated

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
        super().__init__()
Number of features per head
        self.d_k = d_model // heads
Number of heads
        self.heads = heads
These transform the query for multi-headed attention.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
These transform the key and value for multi-headed attention.
        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
Keys and values are already calculated
        else:
            self.key = None
            self.value = None
Output layer
        self.output = nn.Linear(d_model, d_model)
Dropout
        self.dropout = nn.Dropout(dropout_prob)
Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
Softmax for attention along the time dimension of the keys (dim=0, since keys are laid out as [seq_len, ...])
        self.softmax = nn.Softmax(dim=0)
Number of relative positions
        self.P = 2 ** 12
Relative positional embeddings for the keys, relative to the query.
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
Relative positional embedding bias for the keys, relative to the query.
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
Positional embeddings for the query are independent of the position of the query.
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
We store the attention weights so they can be used for logging or other computations if needed.
        self.attn = None
We use relative positional encodings for attention, similar to relative multi-head attention from the Transformer-XL paper.
Attention from the current step's query to the key at step $j$ (relative to the current step) is

$$A_j = Q^\top K_j + Q^\top U^K_j + {U^Q}^\top K_j + {U^Q}^\top U^K_j$$

where $Q$ and $K_j$ are linear transformations of the original embeddings, and $U^Q$ and $U^K_j$ are linear transformations of the positional encodings.
We replace the term ${U^Q}^\top U^K_j$ with $S_j$.
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
Relative positional embeddings $U^K_j$ for the keys
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
Query positional bias $U^Q$ (the same for every query position)
        query_pos_bias = self.query_pos_bias[None, :, :]
Positional bias $S_j$ for the keys
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
$(Q + U^Q)^\top K_j$; the first and third terms
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
$Q^\top U^K_j + S_j$; the second and fourth terms
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
Sum the terms
        return ac + bd
* query has shape [batch_size, d_model]
* key and value have shape [seq_len, batch_size, d_model]

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor):
Prepare query, key and value for attention computation. key and value will then have shape [seq_len, batch_size, heads, d_k] and query will have shape [batch_size, heads, d_k].
        query = self.query(query)
        if self.key:
            key = self.key(key)
        if self.value:
            value = self.value(value)
Compute attention scores. Results in a tensor of shape [seq_len, batch_size, heads].
        scores = self.get_scores(query, key)
Scale scores
        scores *= self.scale
Softmax
        attn = self.softmax(scores)
Apply dropout
        attn = self.dropout(attn)
Multiply by the values
        x = torch.einsum("jbh,jbhd->bhd", attn, value)
Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)
Output layer
        return self.output(x)
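As a minimal shape check (a sketch, not part of the original file; the sizes are arbitrary):

attn = FeedbackAttention(heads=4, d_model=64)
# One query vector per batch element, and a memory of 10 previous steps
query = torch.zeros(8, 64)        # [batch_size, d_model]
memory = torch.zeros(10, 8, 64)   # [seq_len, batch_size, d_model]
out = attn(query=query, key=memory, value=memory)
assert out.shape == (8, 64)       # one output vector per batch element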
This implements a single transformer layer in the feedback transformer.
class FeedbackTransformerLayer(Module):

* d_model is the number of features in the transformer
* attn is the feedback attention module
* feed_forward is the position-wise feed-forward layer
* dropout_prob is the dropout probability for the dropout layers after attention and feed-forward

    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()
Transformer size
        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def forward(self, *,
                x: torch.Tensor,
                key: Optional[torch.Tensor],
                value: Optional[torch.Tensor]):
If there is memory
        if key is not None:
Normalize the vectors before doing self attention
            z = self.norm_self_attn(x)
Run through attention, with the memory as keys and values
            self_attn = self.attn(query=z, key=key, value=value)
Add the attention results
            x = x + self.dropout(self_attn)
Normalize for feed-forward
        z = self.norm_ff(x)
Pass through the feed-forward network
        ff = self.feed_forward(z)
Add the feed-forward results back
        x = x + self.dropout(ff)

        return x
class FeedbackTransformer(Module):

* layer is the feedback transformer layer, which we clone for each layer
* n_layers is the number of layers in the transformer

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
Memory vectors are computed as a weighted sum of the representations of each layer. This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax for the weights before taking the weighted sum
        self.softmax = nn.Softmax(0)
* x_seq is the input with shape [seq_len, batch_size, d_model]

    def forward(self, x_seq: torch.Tensor):
Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
List to store the outputs
        res = []
List to store the memory vectors
        mem = []
For each input step
        for x in x_seq:
List to store layer outputs
            layer_outputs = [x]
If there is memory, stack it into a tensor
            mem_tensor = torch.stack(mem) if mem else None
Run through each layer
            for layer in self.layers:
Get layer output
                x = layer(x=x, key=mem_tensor, value=mem_tensor)
Append it to the list of layer outputs
                layer_outputs.append(x)
Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
Calculate the memory vector as a weighted sum of layer outputs
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
Append the output to the results
            res.append(x)
Stack the output tensors
        res = torch.stack(res)
Normalize the output
        return self.norm(res)
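A rough usage sketch (not from the original file; the hyperparameters are arbitrary, and we assume FeedForward takes d_model and a hidden size d_ff):

d_model, heads, n_layers = 64, 4, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model),
                                 feed_forward=FeedForward(d_model, d_ff=256),
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers)
x = torch.zeros(10, 8, d_model)   # [seq_len, batch_size, d_model]
out = model(x)                    # [seq_len, batch_size, d_model]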
We implement a custom function instead of appending to a Python list and then doing torch.stack. This greatly improves performance over calling torch.stack at each step along the sequence. Every time torch.stack is called, it creates a new tensor, while this method and the accompanying class Stack share memory for each step.
class StackFunction(torch.autograd.Function):

* ctx is the context of the function (which lets us cache stuff)
* memory is the shared memory tensor where we stack and store the values of each step (keys & values)
* memory_grad is the shared memory tensor to store and accumulate gradients of each step
* last is the last value stacked
* n is the number of steps (i.e. the size of the stack)

This returns the stacked tensor for steps up to n.

    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):
Cache the accumulated gradients
        ctx._mem_grad = memory_grad
Cache the size of the stack
        ctx._n = n
Return the stack
        return memory[:n + 1]
* grad_output is the gradient with respect to the output of the forward function

This accumulates the gradients in the shared memory tensor and returns the gradients with respect to the last result in the stack.

    @staticmethod
    def backward(ctx, grad_output):
Get the current size of the stack
        n = ctx._n
Get the accumulated gradients
        memory_grad = ctx._mem_grad
Add the gradients
        memory_grad[:n + 1] += grad_output
Return the gradients w.r.t. the last value in the stack
        return None, None, memory_grad[n], None
class Stack:

* max_len is the maximum size of the stack

    def __init__(self, max_len: int):
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1
* n is the size of the stack
* value is the tensor that needs to be added to the stack

    def append(self, n: int, value: torch.Tensor):
You need to get (use) the stack after adding a value. Otherwise this implementation fails.
        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
Do this without gradients
        with torch.no_grad():
Initialize the shared memory tensor to keep the stack
            if self.memory is None or self.memory.shape[1:] != value.shape:
This should only happen when the stack is empty
                assert n == 0
Create a tensor for the stack
                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
Create a tensor to accumulate the gradients
                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
The memory is already initialized, but we are resetting the stack. This could have been another function like reset, but we found this easier to use.
            elif n == 0:
Reset the accumulated gradients
                self.memory_grad.fill_(0.)
Set the value in the correct position of the stack
            self.memory.data[n] = value.detach()
Keep track of the size of the stack (for debugging)
            self.n = n
Keep track of the last value added to the stack. We need this to be passed on to StackFunction in order to get the gradients propagated backwards.
        self.last = value
Returns the stack
    def get(self):
Keep track of the size of the stack when it was used. This is used for a sanity check in append.
        self.last_get_n = self.n
Take it all through StackFunction so that StackFunction.backward is called by PyTorch during backpropagation.
        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
To release the memory
    def free(self):
        self.memory = None
        self.memory_grad = None
        self.last = None
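A toy loop illustrating the append/get protocol (a sketch, not part of the original file). Each step appends a value, gets the stack, and uses the result, so gradients can flow back through the shared memory:

stack = Stack(max_len=8)
x = torch.ones(2, 3, requires_grad=True)
v = x
total = torch.zeros(())
for step in range(4):
    stack.append(step, v)           # push the current value
    s = stack.get()                 # [step + 1, 2, 3], a view of the shared memory
    v = s.sum(dim=0)                # next value depends on everything stacked so far
    total = total + v.sum()
total.backward()                    # gradients reach x through the shared-memory stack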
This is the updated feedback transformer module that caches the keys and values.
class FeedbackTransformerKV(Module):

* layer is the feedback transformer layer, which we clone for each layer
* n_layers is the number of layers in the transformer
* d_model is the number of features in the transformer
* heads is the number of attention heads

    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
        super().__init__()
Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
Memory vectors are computed as a weighted sum of the representations of each layer. This is the weights parameter for that.
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
Softmax for the weights before taking the weighted sum
        self.softmax = nn.Softmax(0)
Number of features per head
        d_k = d_model // heads
Module to transform embeddings (memory) to get keys
        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
Module to transform embeddings (memory) to get values
        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
Memory for stacked keys
        self.mem_key = Stack(512)
Memory for stacked values
        self.mem_value = Stack(512)
* x_seq is the input with shape [seq_len, batch_size, d_model]

    def forward(self, x_seq: torch.Tensor):
Split the input to a list along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
List to store the outputs
        res = []
For each input step
        for step, x in enumerate(x_seq):
List to store layer outputs
            layer_outputs = [x]
Stacks of keys and values
            key_tensor = None
            value_tensor = None
Get the key and value tensors if we are beyond the initial step
            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()
Run through each layer
            for layer in self.layers:
Get layer output
                x = layer(x=x, key=key_tensor, value=value_tensor)
Append it to the list of layer outputs
                layer_outputs.append(x)
Stack the layer outputs into a tensor
            layer_outputs = torch.stack(layer_outputs)
Calculate the memory vector as a weighted sum of layer outputs
            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
Calculate the keys from the memory and add them to the stack
            self.mem_key.append(step, self.key(mem))
Calculate the values from the memory and add them to the stack
            self.mem_value.append(step, self.value(mem))
Append the output to the results
            res.append(x)
Stack the output tensors
        res = torch.stack(res)
Normalize the output
        return self.norm(res)
Free the memory used by the key and value stacks
    def free(self):
        self.mem_key.free()
        self.mem_value.free()
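A rough usage sketch (not from the original file; hyperparameters are arbitrary). Note is_kv_precomputed=True, since the attention receives keys and values that are already projected, and free() releases the shared stacks once the sequence is done:

d_model, heads, n_layers = 64, 4, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, is_kv_precomputed=True),
                                 feed_forward=FeedForward(d_model, d_ff=256),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers, d_model, heads)
x = torch.zeros(10, 8, d_model)   # [seq_len, batch_size, d_model]
out = model(x)                    # [seq_len, batch_size, d_model]
out.sum().backward()              # backprop through the shared-memory stacks
model.free()                      # release the stacked keys and values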