Here we train a transformer that uses Fuzzy Tiling Activations (FTA) in the feed-forward network. We use it as an autoregressive language model and train it on the Tiny Shakespeare dataset for demonstration.
However, this is probably not the ideal task for FTA; we believe FTA is better suited to modeling data with continuous variables.
import copy

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.activations.fta import FTA
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
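As a quick aside (not part of the original experiment code), FTA expands each input feature into one activation per tile, which is why the feed-forward module below sizes its second layer by activation.expansion_factor. A minimal sketch, assuming the settings used later in the configurations and that FTA expands the last dimension of its input:

# Rough illustration of FTA's feature expansion (values are assumptions)
fta = FTA(-1., +1., 0.2, 0.05)
z = torch.randn(4, 256)
print(fta.expansion_factor)  # number of tiles per feature (10 with these settings)
print(fta(z).shape)          # last dimension grows by the expansion factor, e.g. [4, 2560]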
FFN module with FTA activation

class FeedForwardFTA(nn.Module):

d_model is the number of features in a token embedding
d_ff is the number of features in the hidden layer of the FFN
activation is the FTA activation module
dropout is the dropout probability for the hidden layer

    def __init__(self, d_model: int, d_ff: int,
                 activation: FTA,
                 dropout: float = 0.1):
        super().__init__()
Layer one parameterized by weight and bias
        self.layer1 = nn.Linear(d_model, d_ff)
Layer two parameterized by weight and bias
        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
Activation function
        self.activation = activation
    def forward(self, x: torch.Tensor):
Apply the first linear layer and the FTA activation
        x = self.activation(self.layer1(x))
Apply dropout
        x = self.dropout(x)
Apply the second linear layer
        return self.layer2(x)
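A quick shape check (again, not from the original file; the sizes are assumptions matching the configurations below) makes the behaviour concrete: the expansion happens inside the module, so the output dimensionality matches d_model.

# Hypothetical shape check for FeedForwardFTA
ffn = FeedForwardFTA(d_model=256, d_ff=256,
                     activation=FTA(-1., +1., 0.2, 0.05),
                     dropout=0.1)
x = torch.randn(32, 16, 256)  # [seq_len, batch_size, d_model]
print(ffn(x).shape)           # back to [32, 16, 256]; the FTA expansion is internal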
This is an autoregressive transformer model that uses feed-forward networks with [Fuzzy Tiling Activations](index.html).
class AutoregressiveTransformer(nn.Module):
n_tokens is the number of tokens in the vocabulary
d_model is the embedding size
n_layers is the number of transformer layers
layer is a transformer layer; we make n_layers deep copies of it for the transformer

    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
        super().__init__()
Transformer with n_layers layers
        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
Token embedding layer
        self.emb = nn.Embedding(n_tokens, d_model)
Readout layer
        self.readout = nn.Linear(d_model, n_tokens)
The mask will be initialized on the first call
        self.mask = None
x are the input tokens of shape [seq_len, batch_size]

    def forward(self, x: torch.Tensor):
Create auto-regressive mask
        if self.mask is None or self.mask.size(0) != len(x):
Subsequent mask, will mask out tokens from seeing future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
Get the token embeddings
        x = self.emb(x)
Transformer encoder
        for layer in self.transformer_layers:
            x = layer(x=x, mask=self.mask)
Get logits
        x = self.readout(x)
Return results
        return x, None
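For intuition only (this is an illustration with plain PyTorch, not the labml subsequent_mask helper, whose exact output shape may differ), a causal mask for a short sequence can be built with torch.tril; row i is True only for positions j <= i, so a token cannot attend to future tokens:

# Conceptual causal mask for seq_len = 5 (illustration only)
causal = torch.tril(torch.ones(5, 5)).bool()
print(causal)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         ...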
class Configs(NLPAutoRegressionConfigs):
Model
    model: AutoregressiveTransformer
Number of layers
    n_layers: int = 4
alpha and beta for DeepNorm
    deep_norm_alpha: float
    deep_norm_beta: float
Number of heads in the attention
    n_heads: int = 4
Embedding size
    d_model: int = 256
Size of each attention head
    d_k: int = 16
Feed-forward layer size
    d_ff: int = 256
FTA activation parameters
    fta_lower_limit: float = -1.
    fta_upper_limit: float = +1.
    fta_delta: float = 0.2
    fta_eta: float = 0.05
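With these defaults the tiling covers [-1, +1] in steps of delta = 0.2, giving roughly (1 - (-1)) / 0.2 = 10 tiles per feature (the exact count depends on how the implementation treats the upper limit), so the FFN hidden layer expands from d_ff = 256 to about 256 × 10 = 2560 features before the second linear layer.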
@option(Configs.model)
def _model(c: Configs):
Create FTA activation module
    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
Create the transformer. We re-use the TransformerLayer and MultiHeadAttention implementations.
    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
                                  TransformerLayer(d_model=c.d_model,
                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
                                                                               d_ff=c.d_ff,
                                                                               activation=fta,
                                                                               dropout=0.1),
                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
                                                                                dropout_prob=0.0),
                                                   dropout_prob=0.0))
Move to the device
    return m.to(c.device)
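As a usage sketch (hypothetical; the vocabulary size, shapes, and hyper-parameters here are assumptions, and the original experiment builds the model through the labml configuration machinery instead), the same components can be assembled by hand to check the end-to-end shapes:

# Hypothetical standalone construction, bypassing the config framework
fta = FTA(-1., +1., 0.2, 0.05)
layer = TransformerLayer(d_model=256,
                         feed_forward=FeedForwardFTA(d_model=256, d_ff=256,
                                                     activation=fta, dropout=0.1),
                         self_attn=MultiHeadAttention(4, 256, dropout_prob=0.0),
                         dropout_prob=0.0)
model = AutoregressiveTransformer(n_tokens=65, d_model=256, n_layers=4, layer=layer)
tokens = torch.randint(0, 65, (256, 16))  # [seq_len, batch_size]
logits, _ = model(tokens)
print(logits.shape)                       # [256, 16, 65]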
def main():
Create experiment
    experiment.create(name="fta", writers={'screen', 'labml'})
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character-level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of 256
        'seq_len': 256,
Train for 32 epochs
        'epochs': 32,
Batch size
        'batch_size': 16,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Adam optimizer with no warmup
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 3e-4,
    })
Set model(s) for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()