We have an experiment that trains a HyperLSTM to predict text on the Shakespeare dataset. Here's the link to the code:
HyperNetworks use a smaller network to generate weights of a larger network. There are two variants: static hyper-networks and dynamic hyper-networks. Static HyperNetworks have smaller networks that generate weights (kernels) of a convolutional network. Dynamic HyperNetworks generate parameters of a recurrent neural network for each step. This is an implementation of the latter.
In an RNN the parameters stay the same across steps. Dynamic HyperNetworks generate different parameters for each step. HyperLSTM has the structure of an LSTM, but the parameters of each step are changed by a smaller LSTM network.
In the basic form, a Dynamic HyperNetwork has a smaller recurrent network that generates a feature vector corresponding to each parameter tensor of the larger recurrent network. Let's say the larger network has some parameter $W_h$; the smaller network generates a feature vector $z_h$ and we dynamically compute $W_h$ as a linear transformation of $z_h$. For instance $W_h = \langle W_{hz}, z_h \rangle$, where $W_{hz}$ is a 3-d tensor parameter and $\langle \cdot, \cdot \rangle$ is a tensor-vector multiplication. $z_h$ is usually a linear transformation of the output of the smaller recurrent network.
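To make the idea concrete, here is a minimal sketch of this basic form (not part of the implementation below; the names and sizes are illustrative): the full weight matrix is generated from the feature vector by contracting a 3-d tensor with it.

```python
import torch

n_h, n_z = 8, 4                    # toy hidden size and feature-vector size
w_hz = torch.randn(n_h, n_h, n_z)  # 3-d tensor parameter W_hz
z_h = torch.randn(n_z)             # feature vector from the smaller network

# W_h = <W_hz, z_h>: contract the last axis of the 3-d tensor with z_h
w_h = torch.einsum('ijk,k->ij', w_hz, z_h)  # dynamically generated weight, shape [n_h, n_h]
```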
Large recurrent networks have large dynamically computed parameters. These are calculated using a linear transformation of the feature vector $z$, and this transformation requires an even larger weight tensor. That is, when $W_h$ has shape $N_h \times N_h$, $W_{hz}$ will be $N_h \times N_h \times N_z$.
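For example, with illustrative sizes $N_h = 1024$ and $N_z = 64$ (numbers chosen here only for the arithmetic, not taken from the paper), $W_{hz}$ alone would have $1024 \times 1024 \times 64 \approx 67$ million parameters, compared with roughly one million for a static $1024 \times 1024$ weight matrix.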
To overcome this, we compute the weight parameters of the recurrent network by dynamically scaling each row of a matrix of the same size:

$$d(z) = W_{hz} z$$

$$W_h = \begin{pmatrix} d_1(z) W_{h,1} \\ d_2(z) W_{h,2} \\ \vdots \\ d_{N_h}(z) W_{h,N_h} \end{pmatrix}$$

where $W_h$ is an $N_h \times N_h$ parameter matrix and $W_{h,i}$ are its rows.
We can further optimize this when we compute $W_h h$, as $d(z) \odot (W_h h)$, where $\odot$ stands for element-wise multiplication.
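A small sketch of why this is equivalent (again with illustrative names, separate from the implementation below): scaling the rows of $W_h$ by $d(z)$ and then multiplying by $h$ gives the same result as computing $W_h h$ first and scaling element-wise, without ever materializing a new $N_h \times N_h$ matrix.

```python
import torch

n_h, n_z = 8, 4
w_h = torch.randn(n_h, n_h)   # static parameter matrix
w_hz = torch.randn(n_h, n_z)  # maps the feature vector z to the row scales d(z)
z = torch.randn(n_z)
h = torch.randn(n_h)

d = w_hz @ z                        # d(z), one scale per row of w_h
full = (d.unsqueeze(-1) * w_h) @ h  # build the row-scaled matrix, then multiply
cheap = d * (w_h @ h)               # element-wise trick: no new N_h x N_h matrix
assert torch.allclose(full, cheap, atol=1e-5)
```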
from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module
from labml_nn.lstm import LSTMCell
For HyperLSTM the smaller network and the larger network both have the LSTM structure. This is defined in Appendix A.2.2 in the paper.
input_size is the size of the input $x_t$,
hidden_size is the size of the LSTM, and
hyper_size is the size of the smaller LSTM that alters the weights of the larger outer LSTM.
n_z is the size of the feature vectors used to alter the LSTM weights.
We use the output of the smaller LSTM to compute $z_h^{i,f,g,o}$, $z_x^{i,f,g,o}$ and $z_b^{i,f,g,o}$ using linear transformations. We calculate $d_h^{i,f,g,o}(z_h^{i,f,g,o})$, $d_x^{i,f,g,o}(z_x^{i,f,g,o})$, and $d_b^{i,f,g,o}(z_b^{i,f,g,o})$ from these, using linear transformations again. These are then used to scale the rows of weight and bias tensors of the main LSTM.
📝 Since the computation of $z$ and $d(z)$ are two sequential linear transformations, these can be combined into a single linear transformation. However, we've implemented it separately so that it matches the description in the paper.
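A toy illustration of that note (ignoring biases for brevity; these layer names are not from the implementation): two stacked nn.Linear layers with no non-linearity in between collapse into a single linear layer whose weight is the product of the two weights.

```python
import torch
from torch import nn

hyper_size, n_z, hidden_size = 16, 4, 32

z = nn.Linear(hyper_size, n_z, bias=False)   # h_hat -> z
d = nn.Linear(n_z, hidden_size, bias=False)  # z -> d(z)

# a single linear layer whose weight is the product of the two weights
combined = nn.Linear(hyper_size, hidden_size, bias=False)
with torch.no_grad():
    combined.weight.copy_(d.weight @ z.weight)

h_hat = torch.randn(3, hyper_size)
assert torch.allclose(d(z(h_hat)), combined(h_hat), atol=1e-5)
```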
def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
The input to the hyperLSTM is $\hat{x}_t = \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$, where $x_t$ is the input and $h_{t-1}$ is the output of the outer LSTM at the previous step. So the input size is hidden_size + input_size.
The output of the hyperLSTM is $\hat{h}_t$ and $\hat{c}_t$.
self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)
We compute $z_h$, $z_x$ and $z_b$ as linear transformations of the hyperLSTM output $\hat{h}_t$.
🤔 In the paper these were specified as linear transformations of $\hat{h}_{t-1}$ rather than $\hat{h}_t$; I feel that it's a typo.
self.z_h = nn.Linear(hyper_size, 4 * n_z)
self.z_x = nn.Linear(hyper_size, 4 * n_z)
self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)
d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
self.d_h = nn.ModuleList(d_h)
d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
self.d_x = nn.ModuleList(d_x)
d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
self.d_b = nn.ModuleList(d_b)
The weight matrices $W_h^{i,f,g,o}$
self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
The weight matrices $W_x^{i,f,g,o}$
self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
self.layer_norm_c = nn.LayerNorm(hidden_size)
def forward(self, x: torch.Tensor,
            h: torch.Tensor, c: torch.Tensor,
            h_hat: torch.Tensor, c_hat: torch.Tensor):
x_hat = torch.cat((h, x), dim=-1)
h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)
z_h = self.z_h(h_hat).chunk(4, dim=-1)
z_x = self.z_x(h_hat).chunk(4, dim=-1)
z_b = self.z_b(h_hat).chunk(4, dim=-1)
We calculate $i$, $f$, $g$ and $o$ in a loop
ifgo = []
for i in range(4):
    d_h = self.d_h[i](z_h[i])
    d_x = self.d_x[i](z_x[i])

    y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
        d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
        self.d_b[i](z_b[i])

    ifgo.append(self.layer_norm[i](y))
i, f, g, o = ifgo

c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

return h_next, c_next, h_hat, c_hat
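As a quick, hypothetical sanity check (not part of the original code), the cell can be driven for a single step with random tensors, assuming the imports and the HyperLSTMCell class above; all sizes are arbitrary.

```python
import torch

# toy sizes, chosen only for illustration
batch_size, input_size, hidden_size, hyper_size, n_z = 2, 10, 32, 16, 4

cell = HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)
x = torch.randn(batch_size, input_size)
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)
h_hat = torch.zeros(batch_size, hyper_size)
c_hat = torch.zeros(batch_size, hyper_size)

# one step through the cell; it returns the updated outer and hyper states
h, c, h_hat, c_hat = cell(x, h, c, h_hat, c_hat)
print(h.shape)  # torch.Size([2, 32])
```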
Create a network of n_layers of HyperLSTM cells.
def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
Store sizes to initialize state
self.n_layers = n_layers
self.hidden_size = hidden_size
self.hyper_size = hyper_size
Create cells for each layer. Note that only the first layer gets the input directly; the rest of the layers get their input from the layer below.
self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
                           [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
                            range(n_layers - 1)])
x has shape [n_steps, batch_size, input_size] and state is a tuple of $h$, $c$, $\hat{h}$, $\hat{c}$. $h$ and $c$ have shape [batch_size, hidden_size], and $\hat{h}$ and $\hat{c}$ have shape [batch_size, hyper_size].
def forward(self, x: torch.Tensor,
            state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
n_steps, batch_size = x.shape[:2]
Initialize the state with zeros if it is None
if state is None:
    h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
    c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
    h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
    c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
else:
    (h, c, h_hat, c_hat) = state
Reverse stack the tensors to get the states of each layer
📝 You can just work with the tensor itself but this is easier to debug
    h, c = list(torch.unbind(h)), list(torch.unbind(c))
    h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
Collect the outputs of the final layer at each step
out = []
for t in range(n_steps):
Input to the first layer is the input itself
    inp = x[t]
Loop through the layers
    for layer in range(self.n_layers):
Run the cell for this layer and update its state
        h[layer], c[layer], h_hat[layer], c_hat[layer] = \
            self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
Input to the next layer is the state of this layer
        inp = h[layer]
Collect the output of the final layer
    out.append(h[-1])
Stack the outputs and states
out = torch.stack(out)
h = torch.stack(h)
c = torch.stack(c)
h_hat = torch.stack(h_hat)
c_hat = torch.stack(c_hat)
return out, (h, c, h_hat, c_hat)
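A minimal usage sketch of the full module, assuming the class defined above is named HyperLSTM; the sizes are arbitrary and only for illustration.

```python
import torch

# toy sizes for illustration only
n_steps, batch_size, input_size = 5, 2, 10
hidden_size, hyper_size, n_z, n_layers = 32, 16, 4, 2

model = HyperLSTM(input_size, hidden_size, hyper_size, n_z, n_layers)
x = torch.randn(n_steps, batch_size, input_size)

# run the whole sequence; the state is initialized to zeros when None
out, state = model(x)
print(out.shape)  # torch.Size([5, 2, 32])

# the returned (h, c, h_hat, c_hat) tuple can be fed back in to continue the sequence
out2, state = model(torch.randn(n_steps, batch_size, input_size), state)
```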