This is a PyTorch implementation of the position-wise feedforward network used in transformers.
The FFN consists of two fully connected layers. The number of dimensions in the hidden layer, $d_{ff}$, is generally set to around four times that of the token embedding, $d_{model}$. So it is sometimes also called the expand-and-contract network.
There is an activation at the hidden layer, which is usually set to the ReLU (Rectified Linear Unit) activation, $$\max(0, x)$$

That is, the FFN function is $$FFN(x, W_1, W_2, b_1, b_2) = \max(0, x W_1 + b_1) W_2 + b_2$$ where $W_1$, $W_2$, $b_1$ and $b_2$ are learnable parameters.

Sometimes the GELU (Gaussian Error Linear Unit) activation is also used instead of ReLU: $$x \Phi(x)$$ where $\Phi(x) = P(X \le x)$ for $X \sim \mathcal{N}(0, 1)$.
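For concreteness, here is a minimal sketch of the plain (ungated) FFN described by the formula above, using illustrative sizes with $d_{ff} = 4 d_{model}$. This is not the class implemented below, just the expand-and-contract pattern.

import torch
from torch import nn

# Plain position-wise FFN: FFN(x) = max(0, x W1 + b1) W2 + b2
d_model, d_ff = 512, 2048          # illustrative sizes, d_ff = 4 * d_model
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),      # expand to the hidden size
    nn.ReLU(),                     # max(0, .)
    nn.Linear(d_ff, d_model),      # contract back to the embedding size
)

x = torch.randn(10, 32, d_model)   # (seq_len, batch, d_model)
assert ffn(x).shape == x.shape     # applied position-wise, so the shape is preserved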
This is a generic implementation that supports different variants, including Gated Linear Units (GLU). We have also implemented experiments on these variants.
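As a rough sketch of the gating idea (standalone tensors for illustration, not the class below), the activated output of the first linear layer is multiplied elementwise by a second linear projection of the input before being contracted back:

import torch
from torch import nn

# GLU-style gated hidden layer; the layer names here are hypothetical.
d_model, d_ff = 512, 2048
w1 = nn.Linear(d_model, d_ff)      # input to the activation
v = nn.Linear(d_model, d_ff)       # projection multiplied by the gate
w2 = nn.Linear(d_ff, d_model)      # contract back to d_model

x = torch.randn(4, d_model)
hidden = torch.sigmoid(w1(x)) * v(x)   # sigmoid activation gives the original GLU
out = w2(hidden)                       # shape (4, d_model)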
import torch
from torch import nn as nn

from labml_helpers.module import Module


class FeedForward(Module):
* d_model is the number of features in a token embedding
* d_ff is the number of features in the hidden layer of the FFN
* dropout is the dropout probability for the hidden layer
* is_gated specifies whether the hidden layer is gated
* bias1 specifies whether the first fully connected layer should have a learnable bias
* bias2 specifies whether the second fully connected layer should have a learnable bias
* bias_gate specifies whether the fully connected layer for the gate should have a learnable bias

    def __init__(self, d_model: int, d_ff: int,
                 dropout: float = 0.1,
                 activation=nn.ReLU(),
                 is_gated: bool = False,
                 bias1: bool = True,
                 bias2: bool = True,
                 bias_gate: bool = True):
        super().__init__()
Layer one, parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_model, d_ff, bias=bias1)
Layer two, parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = nn.Linear(d_ff, d_model, bias=bias2)
Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
Activation function
        self.activation = activation
Whether there is a gate
        self.is_gated = is_gated
        if is_gated:
If there is a gate, the linear layer to transform inputs to be multiplied by the gate, parameterized by weight $V$ and bias $b$
            self.linear_v = nn.Linear(d_model, d_ff, bias=bias_gate)

    def forward(self, x: torch.Tensor):
        g = self.activation(self.layer1(x))
If gated, $f(x W_1 + b_1) \odot (x V + b)$
        if self.is_gated:
            x = g * self.linear_v(x)
Otherwise
        else:
            x = g
Apply dropout
        x = self.dropout(x)
$(f(x W_1 + b_1) \odot (x V + b)) W_2 + b_2$ or $f(x W_1 + b_1) W_2 + b_2$ depending on whether it is gated
        return self.layer2(x)
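A quick usage sketch of the class above (shapes and hyperparameters are illustrative; the gated configuration follows the convention of the GLU Variants paper, which drops the biases):

import torch
from torch import nn

x = torch.randn(10, 32, 512)                      # (seq_len, batch, d_model)

# Standard FFN: ReLU activation, d_ff = 4 * d_model
ffn = FeedForward(d_model=512, d_ff=2048)
assert ffn(x).shape == x.shape

# GEGLU-style variant: GELU activation with a gated hidden layer and no biases
geglu = FeedForward(d_model=512, d_ff=2048, activation=nn.GELU(),
                    is_gated=True, bias1=False, bias2=False, bias_gate=False)
assert geglu(x).shape == x.shape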