Fuzzy Tiling Activation Experiment


Here we train a transformer that uses Fuzzy Tiling Activation in the Feed-Forward Network. We use it in a language model and train it on the Tiny Shakespeare dataset for demonstration.

However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.
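
For reference, FTA turns each scalar feature into a vector of soft bin activations, so its output has expansion_factor times as many features as its input. The sketch below illustrates the idea with the same lower limit, upper limit, δ and η used later in this experiment; it is a simplified stand-in, not the labml_nn.activations.fta.FTA implementation.

```python
import torch

def fta_sketch(z: torch.Tensor, lower: float = -1., upper: float = 1.,
               delta: float = 0.2, eta: float = 0.05) -> torch.Tensor:
    """Map each scalar feature onto a vector of soft bin activations (illustrative)."""
    c = torch.arange(lower, upper, delta)            # tiling vector: left edges of the bins
    z = z.unsqueeze(-1)                              # [..., d, 1] so it broadcasts against c
    # Distance of z from the bin [c, c + delta]; zero when z falls inside the bin
    d = torch.clip(c - z, min=0.) + torch.clip(z - delta - c, min=0.)
    # Fuzzy indicator: grows linearly up to eta, then saturates at 1
    fuzzy = torch.where(d <= eta, d, torch.ones_like(d))
    # Activation is 1 inside a bin, falls off linearly over a width of eta, then 0
    return (1. - fuzzy).flatten(-2)                  # [..., d * number_of_bins]

x = torch.randn(3, 4).clamp(-1., 1.)                 # 3 tokens, 4 features, kept inside the tiling range
print(fta_sketch(x).shape)                           # torch.Size([3, 40]): each feature becomes 10 bins
```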

import copy

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_helpers.module import Module
from labml_nn.activations.fta import FTA
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask

FFN module with FTA activation

class FeedForwardFTA(nn.Module):
  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • activation is the FTA activation module
  • dropout is the dropout probability for the hidden layer
    def __init__(self, d_model: int, d_ff: int,
                 activation: FTA,
                 dropout: float = 0.1):
        super().__init__()

Layer one parameterized by weight and bias

        self.layer1 = nn.Linear(d_model, d_ff)

Layer two parameterized by weight and bias

        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)

Hidden layer dropout

        self.dropout = nn.Dropout(dropout)

Activation function

        self.activation = activation
    def forward(self, x: torch.Tensor):

        x = self.activation(self.layer1(x))

Apply dropout

        x = self.dropout(x)

        return self.layer2(x)
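
As a quick, illustrative shape check of this module (assuming the imports and the class above, plus the default hyper-parameters used later in this experiment): the FTA below has an expansion factor of 10, so the hidden representation fed to layer2 has 256 × 10 = 2560 features, while the output stays at d_model.

```python
# Illustrative shape check, assuming the imports and FeedForwardFTA above
fta = FTA(-1., 1., 0.2, 0.05)                       # same arguments as the model builder below
ffn = FeedForwardFTA(d_model=256, d_ff=256, activation=fta, dropout=0.1)

x = torch.randn(32, 16, 256)                        # [seq_len, batch_size, d_model]
print(ffn(x).shape)                                 # torch.Size([32, 16, 256])
```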

Auto-Regressive model

This is an autoregressive transformer model that uses Feed-Forward Networks with [Fuzzy Tiling Activations](index.html).

class AutoregressiveTransformer(Module):
  • n_tokens is the number of tokens in the vocabulary
  • d_model is the embedding size
  • n_layers is the number of transformer layers
  • layer is the transformer layer. We use n_layers copies of it for the transformer.
    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
        super().__init__()

Transformer with n_layers layers

        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])

Token embedding layer

        self.emb = nn.Embedding(n_tokens, d_model)

Readout layer

        self.readout = nn.Linear(d_model, n_tokens)

The mask will be initialized on the first call

        self.mask = None
  • x is the input token tensor of shape [seq_len, batch_size]
    def forward(self, x: torch.Tensor):

Create auto-regressive mask

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask that prevents tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)

Get the token embeddings

        x = self.emb(x)

Transformer encoder

        for layer in self.transformer_layers:
            x = layer(x=x, mask=self.mask)

Get logits

        x = self.readout(x)

Return results (the second value is for state; this model has no state, so it is None)

        return x, None
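
For illustration, here is a small smoke test of the wiring above. It assumes the imports and classes defined so far; the vocabulary size of 65 is only a stand-in (roughly the Tiny Shakespeare character vocabulary) and the other sizes follow the configurations below.

```python
# Illustrative smoke test of AutoregressiveTransformer (assumed sizes, stand-in vocabulary)
fta = FTA(-1., 1., 0.2, 0.05)
layer = TransformerLayer(d_model=256,
                         feed_forward=FeedForwardFTA(d_model=256, d_ff=256,
                                                     activation=fta, dropout=0.1),
                         self_attn=MultiHeadAttention(4, 256, dropout_prob=0.0),
                         dropout_prob=0.0)
model = AutoregressiveTransformer(n_tokens=65, d_model=256, n_layers=4, layer=layer)

tokens = torch.randint(0, 65, (16, 2))   # [seq_len, batch_size]
logits, _ = model(tokens)                # the second return value is the (absent) state
print(logits.shape)                      # torch.Size([16, 2, 65])
```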

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

Model

    model: AutoregressiveTransformer

Number of layers

    n_layers: int = 4

α and β for DeepNorm

    deep_norm_alpha: float
    deep_norm_beta: float

Number of heads in the attention

    n_heads: int = 4

Embedding size

    d_model: int = 256

Size of each attention head

    d_k: int = 16

Feed forward layer size

    d_ff: int = 256

FTA

    fta_lower_limit: float = -1.
    fta_upper_limit: float = +1.
    fta_delta: float = 0.2
    fta_eta: float = 0.05
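
With these values the tiling vector covers [−1, 1) in steps of 0.2, i.e. 10 bins, so the FTA expansion factor is 10 and the second FFN layer sees 256 × 10 = 2560 input features. A quick, illustrative check using the same positional arguments as the model builder below:

```python
# Illustrative check of the expansion implied by these settings
fta = FTA(-1., 1., 0.2, 0.05)
print(fta.expansion_factor)         # should print 10: one tile per 0.2-wide bin over [-1, 1)
print(256 * fta.expansion_factor)   # 2560: the input width of the FFN's second linear layer
```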

Initialize the model

@option(Configs.model)
def _model(c: Configs):

Create FTA activation module

    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)

Create the transformer. We re-use TransformerLayer and MultiHeadAttention implementations.

    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
                                  TransformerLayer(d_model=c.d_model,
                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
                                                                               d_ff=c.d_ff,
                                                                               activation=fta,
                                                                               dropout=0.1),
                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
                                                                                dropout_prob=0.0),
                                                   dropout_prob=0.0))

Move to the device

    return m.to(c.device)

Create and run the experiment

def main():

Create experiment

    experiment.create(name="fta", writers={'screen', 'labml'})

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 256

        'seq_len': 256,

Train for 32 epochs

        'epochs': 32,

Batch size

        'batch_size': 16,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Adam optimizer with no warmup

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 3e-4,
    })

Set model(s) for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()