from typing import Optional

import torch
from torch import nn

from labml_nn.transformers.fast_weights import DPFP
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list


class FastWeightsAttention(nn.Module):
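    """
    A sketch of the per-step fast weight update that `forward` below implements
    with `einsum` operations; $\phi'$ is the DPFP feature map, $\sigma$ is the
    sigmoid gate (whose linear projection is written $W_\beta$), and $W^{(i)}$ is
    the per-head fast weight memory:

    \begin{align}
    \bar{v}^{(i)} &= W^{(i-1)} \phi'\big(k^{(i)}\big) \\
    \beta^{(i)} &= \sigma\big(W_\beta \, x^{(i)}\big) \\
    W^{(i)} &= W^{(i-1)} + \beta^{(i)} \big(v^{(i)} - \bar{v}^{(i)}\big) \otimes \phi'\big(k^{(i)}\big) \\
    y^{(i)} &= W^{(i)} \phi'\big(q^{(i)}\big)
    \end{align}
    """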
    def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # This transforms the `query` for multi-headed attention
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        # These transform the `key` and `value` for multi-headed attention
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)

        # Interpolation gate $\sigma(W_\beta x)$, one scalar per head
        self.gate = nn.Sequential(PrepareForMultiHeadAttention(d_model, heads, 1, bias=False),
                                  nn.Sigmoid())

        # The DPFP feature map $\phi'$ applied to queries and keys
        self.phi = phi

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        # Apply the feature map $\phi'$ to the projected queries and keys
        query = self.phi(self.query(x))
        key = self.phi(self.key(x))
        value = self.value(x)

        # Start with an empty fast weight memory on the first step;
        # its shape is `[batch_size, heads, d_value, d_phi_key]`
        if weights is None:
            weights = key.new_zeros((key.shape[0], key.shape[1], value.shape[2], key.shape[2]))

        # Retrieve the value currently stored against this key,
        # $\bar{v}^{(i)} = W^{(i-1)} \phi'(k^{(i)})$
        value_existing = torch.einsum('bhvk,bhk->bhv', weights, key)

        # Interpolation gate $\beta^{(i)} = \sigma(W_\beta x^{(i)})$
        beta = self.gate(x)

        # Update the fast weight memory,
        # $W^{(i)} = W^{(i-1)} + \beta^{(i)} (v^{(i)} - \bar{v}^{(i)}) \otimes \phi'(k^{(i)})$
        weights = weights + torch.einsum('bhv,bhk->bhvk', beta * (value - value_existing), key)

        # Query the updated memory, $y^{(i)} = W^{(i)} \phi'(q^{(i)})$
        x = torch.einsum('bhvk,bhk->bhv', weights, query)

        # Concatenate multiple heads
        x = x.reshape(x.shape[0], -1)

        # Output layer
        return self.output(x), weights
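
# A minimal single-step sketch (hypothetical values, not from the original source),
# assuming the `DPFP(nu=1)` constructor from `labml_nn.transformers.fast_weights`;
# with `nu=1` the feature map doubles the per-head key/query dimension:
#
#     attn = FastWeightsAttention(heads=4, d_model=64, dropout_prob=0.1, phi=DPFP(nu=1))
#     x = torch.rand(32, 64)       # one step for a batch of 32
#     y, w = attn(x, None)         # y: [32, 64], w: [32, 4, 16, 32] fast weight memory
#     y, w = attn(x, w)            # the memory is carried to the next step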
class FastWeightsAttentionTransformerLayer(nn.Module):
    def __init__(self, *,
                 d_model: int,
                 attn: FastWeightsAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
        super().__init__()

        # Transformer size
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)

        # Normalization layers
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
    def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
        # Fast weights attention, carrying the memory from the previous step
        attn, weights = self.attn(x, weights)
        # Add the self attention results
        x = x + self.dropout(attn)

        # Normalize for feed-forward
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results back
        x = x + self.dropout(ff)

        return x, weights

class FastWeightsAttentionTransformer(nn.Module):
    def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
        super().__init__()
        # Make copies of the transformer layer
        self.layers = clone_module_list(layer, n_layers)
        # Final normalization layer
        self.norm = nn.LayerNorm([layer.size])
    def forward(self, x_seq: torch.Tensor):
        # Split the input into a list of steps along the sequence axis
        x_seq = torch.unbind(x_seq, dim=0)
        # List to store the outputs
        res = []
        # Fast weight memories, one per layer, starting empty
        weights = [None for _ in range(len(self.layers))]

        # For each input step
        for x in x_seq:
            # Run through each layer
            for i, layer in enumerate(self.layers):
                # Get layer output and the updated fast weight memory
                x, weights[i] = layer(x, weights[i])

            # Collect the output of the final layer for this step
            res.append(x)

        # Stack the output tensors
        res = torch.stack(res)
        # Normalize the output
        return self.norm(res)
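
# The sketch below (hypothetical, not part of the original module) wires the classes
# together and runs a random sequence through them; it assumes the `DPFP(nu=...)` and
# `FeedForward(d_model, d_ff, dropout)` constructor signatures from `labml_nn`.
def _usage_sketch():
    heads, d_model, d_ff, dropout = 4, 64, 256, 0.1
    attn = FastWeightsAttention(heads, d_model, dropout, DPFP(nu=1))
    layer = FastWeightsAttentionTransformerLayer(d_model=d_model,
                                                 attn=attn,
                                                 feed_forward=FeedForward(d_model, d_ff, dropout),
                                                 dropout_prob=dropout)
    model = FastWeightsAttentionTransformer(layer, n_layers=2)

    # A random sequence of 10 steps for a batch of 3
    x_seq = torch.rand(10, 3, d_model)
    # The output has the same shape as the input sequence
    out = model(x_seq)
    assert out.shape == (10, 3, d_model)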