Here we train a transformer that uses Fuzzy Tiling Activation (FTA) in the Feed-Forward Network. We use it as a language model and train it on the Tiny Shakespeare dataset for demonstration.
However, this is probably not the ideal task for FTA; we believe FTA is better suited to modeling data with continuous variables.
import copy

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.activations.fta import FTA
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
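Before the model code, here is a minimal sketch of what the FTA module does to a tensor, which makes the sizes in the feed-forward module below easier to follow. It is an illustration only (the fta_demo and z names exist just for this sketch), and it assumes that FTA builds its tiling as arange(lower_limit, upper_limit, delta), exposes the number of tiles as expansion_factor, and flattens the per-feature tile encoding into the last dimension. It uses the torch and FTA imports above.
# A toy FTA over [-2, +2) with tile size 0.5 and fuzziness 0.1
fta_demo = FTA(-2., +2., 0.5, 0.1)
# A batch of 4 vectors with 8 features each
z = torch.randn(4, 8)
# Each feature is softly binned over the tiles, so the last dimension grows by `expansion_factor`
print(fta_demo.expansion_factor)  # number of tiles (8 under the arange assumption)
print(fta_demo(z).shape)          # expected: torch.Size([4, 8 * expansion_factor])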
FFN module with FTA activation
class FeedForwardFTA(nn.Module):
d_model is the number of features in a token embedding
d_ff is the number of features in the hidden layer of the FFN
activation is the FTA activation module
dropout is the dropout probability for the hidden layer
    def __init__(self, d_model: int, d_ff: int,
                 activation: FTA,
                 dropout: float = 0.1):
        super().__init__()
Layer one parameterized by weight and bias
        self.layer1 = nn.Linear(d_model, d_ff)
Layer two parameterized by weight and bias. Its input width is d_ff times the activation's expansion factor, because FTA expands each feature into multiple tiles.
        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)
Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
Activation function
        self.activation = activation
    def forward(self, x: torch.Tensor):
Apply the first linear layer followed by the FTA activation
        x = self.activation(self.layer1(x))
Apply dropout
        x = self.dropout(x)
Project back to d_model with the second linear layer
        return self.layer2(x)
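A quick shape check of this module (an illustrative sketch only; the sizes and the *_demo names are arbitrary, and the expected shapes assume the FTA expansion behaviour described above): the hidden layer is expanded from d_ff to d_ff * expansion_factor features by the activation, and layer2 projects back to d_model.
fta_demo = FTA(-1., +1., 0.2, 0.05)
ffn_demo = FeedForwardFTA(d_model=256, d_ff=256, activation=fta_demo, dropout=0.1)
x_demo = torch.randn(32, 2, 256)   # [seq_len, batch_size, d_model]
print(ffn_demo(x_demo).shape)      # expected: torch.Size([32, 2, 256]), back to d_model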
This is an autoregressive transformer model that uses Feed-Forward Networks with [Fuzzy Tiling Activations](index.html).
class AutoregressiveTransformer(Module):
n_tokens is the number of tokens in the vocabulary
d_model is the embedding size
n_layers is the number of transformer layers
layer is a transformer layer; we make n_layers deep copies of it for the transformer
    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
        super().__init__()
Transformer with n_layers layers
        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
Token embedding layer
        self.emb = nn.Embedding(n_tokens, d_model)
Readout layer
        self.readout = nn.Linear(d_model, n_tokens)
The mask will be initialized on the first call
        self.mask = None
x are the input tokens of shape [seq_len, batch_size]
    def forward(self, x: torch.Tensor):
Create the auto-regressive mask if it hasn't been created yet or if the sequence length has changed
        if self.mask is None or self.mask.size(0) != len(x):
Subsequent mask, which stops tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
Get the token embeddings
        x = self.emb(x)
Transformer encoder
        for layer in self.transformer_layers:
            x = layer(x=x, mask=self.mask)
Get logits
        x = self.readout(x)
Return the logits; the second value (state) is unused, so we return None
        return x, None
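As an illustration of the interface (the numbers and *_demo names here are arbitrary placeholders, not the training configuration), the model maps token indices of shape [seq_len, batch_size] to logits of shape [seq_len, batch_size, n_tokens]:
fta_demo = FTA(-1., +1., 0.2, 0.05)
layer_demo = TransformerLayer(d_model=256,
                              feed_forward=FeedForwardFTA(d_model=256, d_ff=256,
                                                          activation=fta_demo, dropout=0.1),
                              self_attn=MultiHeadAttention(4, 256, dropout_prob=0.0),
                              dropout_prob=0.0)
model_demo = AutoregressiveTransformer(n_tokens=65, d_model=256, n_layers=2, layer=layer_demo)
tokens_demo = torch.randint(0, 65, (16, 2))   # [seq_len, batch_size]
logits_demo, _ = model_demo(tokens_demo)
print(logits_demo.shape)                      # expected: torch.Size([16, 2, 65])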
class Configs(NLPAutoRegressionConfigs):
Model
    model: AutoregressiveTransformer
Number of layers
    n_layers: int = 4
alpha and beta for DeepNorm
    deep_norm_alpha: float
    deep_norm_beta: float
Number of heads in the attention
    n_heads: int = 4
Embedding size
    d_model: int = 256
Size of each attention head
    d_k: int = 16
Feed forward layer size
    d_ff: int = 256
FTA activation parameters
    fta_lower_limit: float = -1.
    fta_upper_limit: float = +1.
    fta_delta: float = 0.2
    fta_eta: float = 0.05
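To connect these defaults to the sizes in FeedForwardFTA: assuming the tiling is arange(fta_lower_limit, fta_upper_limit, fta_delta), there are (1.0 - (-1.0)) / 0.2 = 10 tiles, so the second feed-forward layer receives d_ff * 10 = 2560 input features. A small sketch under that assumption (fta_demo is just an illustrative name):
fta_demo = FTA(-1., +1., 0.2, 0.05)
print(fta_demo.expansion_factor)        # expected: 10 with these limits and delta
print(256 * fta_demo.expansion_factor)  # input width of `layer2` when d_ff is 256 (2560)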
Initialize the model
@option(Configs.model)
def _model(c: Configs):
Create FTA activation module
    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)
Create the transformer. We re-use the TransformerLayer and MultiHeadAttention implementations.
    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
                                  TransformerLayer(d_model=c.d_model,
                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
                                                                               d_ff=c.d_ff,
                                                                               activation=fta,
                                                                               dropout=0.1),
                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
                                                                                dropout_prob=0.0),
                                                   dropout_prob=0.0))
Move to the device
    return m.to(c.device)
def main():
Create experiment
    experiment.create(name="fta", writers={'screen', 'labml'})
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Use character level tokenizer
        'tokenizer': 'character',
Prompt separator is blank
        'prompt_separator': '',
Starting prompt for sampling
        'prompt': 'It is ',
Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',
Use a context size of 256
        'seq_len': 256,
Train for 32 epochs
        'epochs': 32,
Batch size of 16
        'batch_size': 16,
Switch between training and validation 10 times per epoch
        'inner_iterations': 10,
Adam optimizer with no warmup
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 3e-4,
    })
Set model(s) for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()