Long Short-Term Memory (LSTM)

This is a PyTorch implementation of Long Short-Term Memory.

from typing import Optional, Tuple

import torch
from torch import nn

from labml_helpers.module import Module

Long Short-Term Memory Cell

An LSTM cell computes $c$ and $h$. $c$ is like the long-term memory, and $h$ is like the short-term memory. We use the input $x$ and $h$ to update the long-term memory. In the update, some features of $c$ are cleared with a forget gate $f$, and new features $g$ are added through an input gate $i$.

The new short-term memory is the $\tanh$ of the long-term memory multiplied by the output gate $o$.

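Note that the cell doesn’t look at the long-term memory $c$ when computing the gates; it only modifies it. Also, $c$ never goes through a linear transformation: it is only scaled element-wise by the forget gate and has new values added to it. This additive path lets gradients flow across many time steps, which is what mitigates the vanishing gradient problem.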
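
To make the gradient argument concrete, here is a rough scalar sketch. It holds the gates fixed and ignores their dependence on $h_{t-1}$, so it only tracks the direct cell path, where $\partial c_t / \partial c_{t-1} = \sigma(f_t)$. The tanh RNN it is compared against is a hypothetical baseline that is not part of the source, and the helper names are illustrative.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def cell_path_gradient(forget_preacts: List[float]) -> float:
    """Product of dc_t/dc_{t-1} = sigmoid(f_t) along the direct cell path, gates held fixed."""
    grad = 1.0
    for f in forget_preacts:
        grad *= sigmoid(f)
    return grad


def tanh_rnn_gradient(w: float, h0: float, steps: int) -> float:
    """Hypothetical baseline: dh_T/dh_0 for a scalar RNN h_t = tanh(w * h_{t-1})."""
    grad, h = 1.0, h0
    for _ in range(steps):
        h = math.tanh(w * h)
        grad *= w * (1.0 - h * h)  # chain rule: d tanh(w*h)/dh = w * (1 - tanh(w*h)^2)
    return grad


print(cell_path_gradient([5.0] * 50))   # stays around 0.71 after 50 steps
print(tanh_rnn_gradient(0.5, 0.5, 50))  # shrinks below 1e-15
```

The sketch also shows the limit of the argument: if the forget gate sits near zero, the product along the cell path vanishes as well, so the benefit depends on the gates learning to stay open when long-range information matters.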

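Here’s the update rule, where $\sigma$ is the logistic sigmoid:

$$
\begin{aligned}
c_t &= \sigma(f_t) \odot c_{t-1} + \sigma(i_t) \odot \tanh(g_t) \\
h_t &= \sigma(o_t) \odot \tanh(c_t)
\end{aligned}
$$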

$\odot$ stands for element-wise multiplication.

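Intermediate values and gates are computed as linear transformations of the hidden state and input.

$$
\begin{aligned}
i_t &= lin_x^i(x_t) + lin_h^i(h_{t-1}) \\
f_t &= lin_x^f(x_t) + lin_h^f(h_{t-1}) \\
g_t &= lin_x^g(x_t) + lin_h^g(h_{t-1}) \\
o_t &= lin_x^o(x_t) + lin_h^o(h_{t-1})
\end{aligned}
$$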
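
A minimal plain-Python sketch of one state update, taking the four pre-activations $i_t$, $f_t$, $g_t$, $o_t$ as already computed. `lstm_update` is a hypothetical helper, and vectors are plain lists instead of tensors so that every operation is visibly element-wise.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_update(i: List[float], f: List[float], g: List[float], o: List[float],
                c_prev: List[float]) -> Tuple[List[float], List[float]]:
    """One LSTM state update from the four gate pre-activations (all element-wise)."""
    # c_t = sigmoid(f_t) * c_{t-1} + sigmoid(i_t) * tanh(g_t)
    c_next = [sigmoid(fk) * ck + sigmoid(ik) * math.tanh(gk)
              for ik, fk, gk, ck in zip(i, f, g, c_prev)]
    # h_t = sigmoid(o_t) * tanh(c_t)
    h_next = [sigmoid(ok) * math.tanh(ck) for ok, ck in zip(o, c_next)]
    return h_next, c_next


# All pre-activations 0: every gate is 0.5 and tanh(g) = 0, so c is halved
h, c = lstm_update([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [1.0, -1.0])
print(c)  # [0.5, -0.5]

# Forget gate saturated open, input gate saturated shut: c passes through unchanged
h_keep, c_keep = lstm_update([-50.0], [50.0], [3.0], [0.0], [2.0])
```

The second call illustrates the gating behavior described above: with the forget gate near 1 and the input gate near 0, the long-term memory is carried forward as is.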

class LSTMCell(Module):
    def __init__(self, input_size: int, hidden_size: int, layer_norm: bool = False):
        super().__init__()

These are the linear layers that transform the input and hidden vectors. Only one of them needs a bias: their outputs are added together, so a second bias would be redundant.

This combines $lin_h^i$, $lin_h^f$, $lin_h^g$, and $lin_h^o$ transformations.

        self.hidden_lin = nn.Linear(hidden_size, 4 * hidden_size)

This combines $lin_x^i$, $lin_x^f$, $lin_x^g$, and $lin_x^o$ transformations.

        self.input_lin = nn.Linear(input_size, 4 * hidden_size, bias=False)
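
The two claims here, that a single bias suffices and that one layer with 4 * hidden_size outputs can stand in for four separate ones, can be checked numerically. This pure-Python sketch uses hypothetical helpers (`separate_gates`, `fused_gates`) and stores weights with one row per output feature, like the weight of `nn.Linear`. It compares four separate affine pairs, each with its own bias, against stacked weight matrices with one merged bias, split into four chunks in $i, f, g, o$ order.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, v: Vector) -> Vector:
    """W has one row per output feature, like nn.Linear's weight of shape (out, in)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def add(*vs: Vector) -> Vector:
    """Element-wise sum of equally sized vectors."""
    return [sum(parts) for parts in zip(*vs)]


def separate_gates(Wx, bx, Wh, bh, x, h):
    """lin_x^k(x) + lin_h^k(h) for k in (i, f, g, o), each affine map with its own bias."""
    return [add(matvec(Wx[k], x), bx[k], matvec(Wh[k], h), bh[k]) for k in range(4)]


def fused_gates(Wx, Wh, b, x, h):
    """Stack the four weight matrices per input, add one bias, then split into four chunks."""
    Wx_all = [row for W in Wx for row in W]  # (4 * hidden_size) x input_size
    Wh_all = [row for W in Wh for row in W]  # (4 * hidden_size) x hidden_size
    out = add(matvec(Wh_all, h), matvec(Wx_all, x), b)
    n = len(out) // 4                        # hidden_size
    return [out[k * n:(k + 1) * n] for k in range(4)]


# hidden_size = 2, input_size = 3; small integer weights keep the comparison exact
Wx = [[[1, 0, 2], [0, 1, -1]], [[2, 1, 0], [1, 1, 1]], [[0, 0, 1], [-1, 2, 0]], [[1, 1, 1], [0, 0, 0]]]
Wh = [[[1, 2], [0, 1]], [[-1, 0], [2, 2]], [[0, 1], [1, 0]], [[3, -1], [1, 1]]]
bx = [[1, 0], [0, 1], [2, 2], [0, -1]]
bh = [[0, 1], [1, 1], [-2, 0], [1, 1]]
x, h = [1, 2, 3], [-1, 1]

# The two biases collapse into one: b = b_x + b_h, laid out chunk by chunk
merged_bias = [v for k in range(4) for v in add(bx[k], bh[k])]
print(separate_gates(Wx, bx, Wh, bh, x, h) == fused_gates(Wx, Wh, merged_bias, x, h))  # True
```

Fusing the four maps into one matrix means one large matrix multiply per input instead of four small ones, which is usually faster on GPUs.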

Whether to apply layer normalization.

Applying layer normalization gives better results. The pre-activations $i$, $f$, $g$ and $o$ are normalized, and $c_t$ is normalized in $h_t = \sigma(o_t) \odot \tanh(\mathop{LN}(c_t))$.

        if layer_norm:
            self.layer_norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(4)])
            self.layer_norm_c = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.ModuleList([nn.Identity() for _ in range(4)])
            self.layer_norm_c = nn.Identity()
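
For reference, here is a minimal sketch of what each of these normalizations computes on one feature vector. In this cell, one is applied separately to each of the four pre-activation chunks and one to $c_t$. The sketch assumes `nn.LayerNorm`'s default `eps` of 1e-5 and its initial learned scale of 1 and shift of 0; after training, PyTorch also applies the learned per-feature scale and shift, which this sketch omits. `layer_norm` is a hypothetical helper name.

```python
import math
from typing import List


def layer_norm(v: List[float], eps: float = 1e-5) -> List[float]:
    """Zero-mean, unit-variance normalization over the feature dimension.

    Matches nn.LayerNorm(len(v)) at initialization (scale 1, shift 0), which uses the
    biased variance (divide by N) and eps = 1e-5 by default."""
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


gate_preacts = [1.0, 2.0, 3.0, 4.0]
normed = layer_norm(gate_preacts)
scaled = layer_norm([10.0 * x for x in gate_preacts])  # same result up to the eps term
```

Because the output barely changes when the input is rescaled, the normalized pre-activations stay on a fixed scale no matter how large the raw linear outputs grow.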

    def __call__(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):

We compute the linear transformations for $i_t$, $f_t$, $g_t$ and $o_t$ all at once, using the two combined linear layers.

        ifgo = self.hidden_lin(h) + self.input_lin(x)

Each layer produces an output of size 4 * hidden_size, so we split the sum into four chunks of size hidden_size.

        ifgo = ifgo.chunk(4, dim=-1)

Apply layer normalization (not in the original paper, but it gives better results).

        ifgo = [self.layer_norm[i](ifgo[i]) for i in range(4)]

        i, f, g, o = ifgo

        c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)

Optionally, apply layer norm to $c_t$.

        h_next = torch.sigmoid(o) * torch.tanh(self.layer_norm_c(c_next))

        return h_next, c_next

Multilayer LSTM

class LSTM(Module):

Create a network of n_layers stacked LSTM layers.

    def __init__(self, input_size: int, hidden_size: int, n_layers: int):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

Create a cell for each layer. Note that only the first layer gets the input directly. The remaining layers get their input from the hidden state of the layer below.

        self.cells = nn.ModuleList([LSTMCell(input_size, hidden_size)] +
                                   [LSTMCell(hidden_size, hidden_size) for _ in range(n_layers - 1)])
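
Because only the first cell reads input_size features, the parameter count is easy to work out. This sketch assumes layer_norm=False in every cell, which holds here because LSTM constructs each LSTMCell with the default. Each cell then has 4H·D input weights with no bias, plus 4H·H hidden weights and a 4H bias, where D is input_size for layer 0 and hidden_size otherwise. `lstm_param_count` is a hypothetical helper.

```python
def lstm_param_count(input_size: int, hidden_size: int, n_layers: int) -> int:
    """Parameter count of the multilayer LSTM above, assuming layer_norm=False in every cell."""
    total = 0
    for layer in range(n_layers):
        d = input_size if layer == 0 else hidden_size  # only the first layer reads x directly
        total += 4 * hidden_size * d                    # input_lin weights (no bias)
        total += 4 * hidden_size * hidden_size          # hidden_lin weights
        total += 4 * hidden_size                        # hidden_lin bias
    return total


print(lstm_param_count(10, 20, 2))  # 2480 + 3280 = 5760
```

PyTorch's built-in `nn.LSTM` keeps two bias vectors per layer (`bias_ih_l[k]` and `bias_hh_l[k]`), so for the same sizes it has 4 * hidden_size more parameters per layer.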

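x has shape [n_steps, batch_size, input_size], and state is a tuple of $h$ and $c$, each with shape [n_layers, batch_size, hidden_size]. The method returns the final layer's outputs, with shape [n_steps, batch_size, hidden_size], along with the final state in the same format as the input state.

    def __call__(self, x: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        n_steps, batch_size = x.shape[:2]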

Initialize the state with zeros if it is None.

        if state is None:
            h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
            c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.n_layers)]
        else:
            (h, c) = state

Otherwise, unstack the tensors to get the state of each layer.
Note: you could work with the stacked tensors directly, but per-layer lists are easier to debug.

            h, c = list(torch.unbind(h)), list(torch.unbind(c))

List to collect the outputs of the final layer at each time step.

        out = []
        for t in range(n_steps):

Input to the first layer is the input itself

            inp = x[t]

Loop through the layers

            for layer in range(self.n_layers):

Run the layer's cell and update its state.

                h[layer], c[layer] = self.cells[layer](inp, h[layer], c[layer])

The input to the next layer is the hidden state $h$ of this layer.

                inp = h[layer]

Collect the output $h$ of the final layer

            out.append(h[-1])

Stack the outputs and states

        out = torch.stack(out)
        h = torch.stack(h)
        c = torch.stack(c)

        return out, (h, c)
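
The loop order in `LSTM.__call__` (time steps outside, layers inside) can be mirrored in plain Python. This is a sketch under simplifying assumptions: `ScalarLSTMCell` and `stacked_lstm` are hypothetical names, the cells use input_size = hidden_size = 1 with no batch dimension, and states are plain per-layer lists of floats rather than stacked tensors.

```python
import math
from typing import List, Optional, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


class ScalarLSTMCell:
    """Hypothetical LSTM cell with input_size = hidden_size = 1, for illustration only."""

    def __init__(self, wx: Tuple[float, ...], wh: Tuple[float, ...], b: Tuple[float, ...]):
        # One weight per gate, in (i, f, g, o) order; a single bias per gate
        self.wx, self.wh, self.b = wx, wh, b

    def __call__(self, x: float, h: float, c: float) -> Tuple[float, float]:
        i, f, g, o = (self.wx[k] * x + self.wh[k] * h + self.b[k] for k in range(4))
        c_next = sigmoid(f) * c + sigmoid(i) * math.tanh(g)
        h_next = sigmoid(o) * math.tanh(c_next)
        return h_next, c_next


State = Tuple[List[float], List[float]]


def stacked_lstm(cells: List[ScalarLSTMCell], xs: List[float],
                 state: Optional[State] = None) -> Tuple[List[float], State]:
    n_layers = len(cells)
    if state is None:
        # One zero (h, c) pair per layer
        h, c = [0.0] * n_layers, [0.0] * n_layers
    else:
        # Copy so the caller's lists are not modified in place
        h, c = list(state[0]), list(state[1])
    out = []
    for x in xs:                       # outer loop: time steps
        inp = x                        # the first layer sees the raw input
        for layer in range(n_layers):  # inner loop: layers, bottom to top
            h[layer], c[layer] = cells[layer](inp, h[layer], c[layer])
            inp = h[layer]             # the next layer sees this layer's hidden state
        out.append(h[-1])              # collect the final layer's output
    return out, (h, c)


cells = [ScalarLSTMCell((1.0, 1.0, 1.0, 1.0), (0.5, 0.5, 0.5, 0.5), (0.0, 0.0, 0.0, 0.0))
         for _ in range(3)]
xs = [1.0, -1.0, 0.5, 2.0]
out, (h, c) = stacked_lstm(cells, xs)

# Feeding the sequence in two pieces and passing the state along gives the same result
out_a, state_a = stacked_lstm(cells, xs[:2])
out_b, state_b = stacked_lstm(cells, xs[2:], state_a)
```

Because the returned state has the same per-layer layout as the state it accepts, a long sequence can be processed in chunks: running the two halves with the state passed along produces exactly the outputs of a single pass.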