我们已经实现了论文 Hyper Networks 中介绍的 Hyper LSTM,并使用 PyTorch 进行了注释。David Ha的这篇博客文章很好地解释了HyperNetworks。
我们有一个实验可以训练 HyperLSTM 来预测莎士比亚数据集上的文本。以下是代码链接:experiment.py
HyperNetworks 使用较小的网络来生成较大网络的权重。有两种变体:静态超网络和动态超网络。静态超网络具有较小的网络,用于生成卷积网络的权重(内核)。动态超网络为每个步骤生成循环神经网络的参数。这是后者的实现。
在 RNN 中,每个步骤的参数保持不变。动态超网络为每个步骤生成不同的参数。HyperLSTM 具有 LSTM 的结构,但每个步骤的参数都由较小的 LSTM 网络更改。
在基本形式中,Dynamic HyperNetwork 具有较小的循环网络,该网络生成与较大循环网络的每个参数张量对应的特征向量。假设较大的网络有一些参数,较小的网络生成一个特征向量,我们动态计算为的线性变换。例如,其中是三维张量参数,是张量向量乘法。通常是较小的循环网络输出的线性变换。
大型循环网络具有大量的动态计算参数。这些是使用特征向量的线性变换计算的。而且这种变换需要更大的权重张量。也就是说,当有形状时,将是。
为了克服这个问题,我们通过动态缩放相同大小的矩阵的每一行来计算循环网络的权重参数。
其中是参数矩阵。
我们可以在计算时进一步对其进行优化,因为其中代表逐元素乘法。
72from typing import Optional, Tuple
73
74import torch
75from torch import nn
76
77from labml_helpers.module import Module
78from labml_nn.lstm import LSTMCell
81class HyperLSTMCell(Module):
input_size
是输入的大小,hidden_size
是 LSTM 的大小,hyper_size
是较小的 LSTM 的大小,它会改变更大的外部 LSTM。n_z
是用于改变 LSTM 权重的特征向量的大小。
我们使用较小的 LSTM 的输出进行计算,并使用线性变换。我们再次使用线性变换进行计算、和计算。然后使用它们来缩放主 LSTM 的权重和偏置张量的行。
📝 由于和的计算是两个连续的线性变换,因此可以将它们组合成单个线性变换。但是,我们已经单独实现了这一点,以便它与论文中的描述相匹配。
89 def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int):
107 super().__init__()
120 self.hyper = LSTMCell(hidden_size + input_size, hyper_size, layer_norm=True)
🤔 在报纸上指定了它,因为我觉得这是一个错字。
126 self.z_h = nn.Linear(hyper_size, 4 * n_z)
128 self.z_x = nn.Linear(hyper_size, 4 * n_z)
130 self.z_b = nn.Linear(hyper_size, 4 * n_z, bias=False)
133 d_h = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
134 self.d_h = nn.ModuleList(d_h)
136 d_x = [nn.Linear(n_z, hidden_size, bias=False) for _ in range(4)]
137 self.d_x = nn.ModuleList(d_x)
139 d_b = [nn.Linear(n_z, hidden_size) for _ in range(4)]
140 self.d_b = nn.ModuleList(d_b)
权重矩阵
143 self.w_h = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, hidden_size)) for _ in range(4)])
权重矩阵
145 self.w_x = nn.ParameterList([nn.Parameter(torch.zeros(hidden_size, input_size)) for _ in range(4)])
层规范化
148 self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
149 self.layer_norm_c = nn.LayerNorm(hidden_size)
151 def forward(self, x: torch.Tensor,
152 h: torch.Tensor, c: torch.Tensor,
153 h_hat: torch.Tensor, c_hat: torch.Tensor):
160 x_hat = torch.cat((h, x), dim=-1)
162 h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)
165 z_h = self.z_h(h_hat).chunk(4, dim=-1)
167 z_x = self.z_x(h_hat).chunk(4, dim=-1)
169 z_b = self.z_b(h_hat).chunk(4, dim=-1)
我们循环计算、和
172 ifgo = []
173 for i in range(4):
175 d_h = self.d_h[i](z_h[i])
177 d_x = self.d_x[i](z_x[i])
184 y = d_h * torch.einsum('ij,bj->bi', self.w_h[i], h) + \
185 d_x * torch.einsum('ij,bj->bi', self.w_x[i], x) + \
186 self.d_b[i](z_b[i])
187
188 ifgo.append(self.layer_norm[i](y))
191 i, f, g, o = ifgo
194 c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
197 h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))
198
199 return h_next, c_next, h_hat, c_hat
202class HyperLSTM(Module):
创建一个由 HyperLSTMn_layers
组成的网络。
207 def __init__(self, input_size: int, hidden_size: int, hyper_size: int, n_z: int, n_layers: int):
212 super().__init__()
存储大小以初始化状态
215 self.n_layers = n_layers
216 self.hidden_size = hidden_size
217 self.hyper_size = hyper_size
为每层创建单元。请注意,只有第一层直接获得输入。其余图层从下面的图层获取输入
221 self.cells = nn.ModuleList([HyperLSTMCell(input_size, hidden_size, hyper_size, n_z)] +
222 [HyperLSTMCell(hidden_size, hidden_size, hyper_size, n_z) for _ in
223 range(n_layers - 1)])
x
有形状[n_steps, batch_size, input_size]
和state
是的元组。有形状[batch_size, hidden_size]
和形状[batch_size, hyper_size]
。225 def forward(self, x: torch.Tensor,
226 state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = None):
233 n_steps, batch_size = x.shape[:2]
使用零初始化状态如果None
236 if state is None:
237 h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
238 c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
239 h_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
240 c_hat = [x.new_zeros(batch_size, self.hyper_size) for _ in range(self.n_layers)]
242 else:
243 (h, c, h_hat, c_hat) = state
247 h, c = list(torch.unbind(h)), list(torch.unbind(c))
248 h_hat, c_hat = list(torch.unbind(h_hat)), list(torch.unbind(c_hat))
在每一步收集最后一层的输出
251 out = []
252 for t in range(n_steps):
第一层的输入是输入本身
254 inp = x[t]
循环穿过图层
256 for layer in range(self.n_layers):
获取图层的状态
258 h[layer], c[layer], h_hat[layer], c_hat[layer] = \
259 self.cells[layer](inp, h[layer], c[layer], h_hat[layer], c_hat[layer])
下一层的输入是该图层的状态
261 inp = h[layer]
收集最后一层的输出
263 out.append(h[-1])
堆叠输出和状态
266 out = torch.stack(out)
267 h = torch.stack(h)
268 c = torch.stack(c)
269 h_hat = torch.stack(h_hat)
270 c_hat = torch.stack(c_hat)
273 return out, (h, c, h_hat, c_hat)