9from typing import Optional
10
11import torch
12from torch import nn
13
14from labml_nn.transformers.fast_weights import DPFP
15from labml_nn.transformers.feed_forward import FeedForward
16from labml_nn.transformers.mha import PrepareForMultiHeadAttention
17from labml_nn.utils import clone_module_list
class FastWeightsAttention(nn.Module):
    """
    Multi-head attention backed by a fast-weight memory updated with a gated delta rule.

    Each call processes one input step:

    * ``x`` is projected into per-head queries, keys and values; queries and keys
      are additionally passed through the feature map ``phi``;
    * the value currently associated with the key is read from memory,
      ``v_bar = W k``;
    * the memory is updated, ``W <- W + beta * (v - v_bar) ⊗ k``, where
      ``beta = sigmoid(gate(x))`` is one scalar per head;
    * the output ``W q`` is read from the *updated* memory, the heads are
      concatenated and a final linear layer is applied.
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        """
        :param heads: number of attention heads
        :param d_model: model width; must be divisible by ``heads``
        :param dropout_prob: probability for ``self.dropout``
        :param phi: feature map applied to the projected queries and keys
        :raises ValueError: if ``d_model`` is not divisible by ``heads``
        """
        super().__init__()

        # Fail fast: forward() flattens the heads back into ``heads * d_k``
        # features and feeds them to ``nn.Linear(d_model, d_model)``, which only
        # works when the per-head split is exact.
        if d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})")

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # Transforms the query for multi-headed attention
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        # Transform the key and value for multi-headed attention
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

        # Per-head write strength beta in (0, 1): a 1-feature-per-head projection
        # followed by a sigmoid
        self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                  nn.Sigmoid())

        # Feature map applied to queries and keys
        self.phi = phi

        # Output layer
        self.output = nn.Linear(d_model, d_model)

        # Dropout
        # NOTE(review): this module is never applied in forward(); the
        # FastWeightsAttentionTransformerLayer in this file applies its own
        # dropout to this module's output.
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        """
        Process one input step and update the fast-weight memory.

        :param x: input for a single step; presumably ``[batch_size, d_model]``
            (the einsums below treat the projections as ``[batch, heads, features]``
            — TODO confirm against PrepareForMultiHeadAttention)
        :param weights: fast-weight memory from the previous step, laid out as
            ``[batch, heads, value_features, key_features]``, or ``None`` to start
            from an all-zero memory
        :return: ``(output, weights)`` — the attention output after the final
            linear layer (``[x.shape[0], d_model]``) and the updated memory, which
            the caller should pass back in on the next step
        """
        query = self.phi(self.query(x))
        key = self.phi(self.key(x))
        value = self.value(x)

        # Empty memory: one (value_features x key_features) matrix per batch item
        # and head. key_features is the size *after* phi, which may differ from d_k.
        if weights is None:
            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))

        # Value currently retrieved by this key: v_bar = W k
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

        # Per-head gate, broadcast over the value features
        beta = self.gate(x)

        # Delta rule: move the stored value for ``key`` towards ``value`` by a
        # fraction beta, instead of purely accumulating outer products
        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)

        # Read with the updated memory: W q
        x = torch.einsum('bhvk,bhk->bhv', weights, query)

        # Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)

        # Output layer
        return self.output(x), weights
class FastWeightsAttentionTransformerLayer(nn.Module):
    """
    A pre-norm transformer layer built around fast-weight attention.

    Both sub-layers (attention and feed-forward) normalize their input first and
    add their dropped-out result back through a residual connection.
    """

    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        """
        :param d_model: model width (used for the layer norms and exposed as ``size``)
        :param attn: fast-weight attention module
        :param feed_forward: position-wise feed-forward module
        :param dropout_prob: dropout probability applied to both sub-layer outputs
        """
        super().__init__()

        # Transformer size
        self.size = d_model

        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        """
        Run one input step through the layer.

        :param x: input for a single step; last dimension is ``d_model``
        :param weights: this layer's fast-weight memory from the previous step,
            or ``None`` on the first step
        :return: ``(x, weights)`` — the layer output and the updated memory
        """
        # Normalize before attention, matching the pre-norm feed-forward
        # sub-layer below. Without this, ``norm_self_attn`` would be a registered
        # but unused parameter and the attention would see unnormalized input.
        z = self.norm_self_attn(x)
        attn, weights = self.attn(z, weights)

        # Add the self attention results
        x = x + self.dropout(attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)

        # Pass through the feed-forward network
        ff = self.feed_forward(z)

        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, weights
class FastWeightsAttentionTransformer(nn.Module):
    """
    A stack of fast-weight transformer layers run recurrently over a sequence.

    The input is consumed one step at a time along dimension 0. Every layer keeps
    its own fast-weight memory, which is carried from one step to the next.
    """

    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        """
        :param layer: prototype layer to be copied
        :param n_layers: number of copies to stack
        """
        super().__init__()
        # Independent copies of the prototype layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization, sized from the prototype layer's width
        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x_seq: torch.Tensor):
        """
        :param x_seq: input sequence, iterated along dimension 0
        :return: the per-step outputs of the last layer, stacked back along
            dimension 0 and layer-normalized
        """
        # One memory slot per layer; ``None`` lets each layer create its own
        # zero memory on the first step.
        memories = [None] * len(self.layers)
        outputs = []

        for step_input in torch.unbind(x_seq, dim=0):
            hidden = step_input
            # Feed the step through every layer, threading each layer's memory
            for depth, block in enumerate(self.layers):
                hidden, memories[depth] = block(hidden, memories[depth])
            outputs.append(hidden)

        return self.norm(torch.stack(outputs))