Fuzzy Tiling Activation Experiment


Here we train a transformer that uses Fuzzy Tiling Activation in the Feed-Forward Network. We use it as a language model and train it on the Tiny Shakespeare dataset for demonstration.

However, this is probably not the ideal task for FTA, and we believe FTA is more suitable for modeling data with continuous variables.
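
To get an intuition for what the activation does, here is a quick standalone sketch (not part of the experiment code) showing how FTA expands each feature into a vector of fuzzy tile memberships. The shapes assume the labml_nn FTA implementation imported below.

import torch
from labml_nn.activations.fta import FTA

# Arguments are (lower limit, upper limit, delta, eta), matching the configuration names used later
fta = FTA(-1., 1., 0.2, 0.05)

z = torch.randn(2, 3)
# Each of the 3 features is mapped to `expansion_factor` fuzzy tile activations
print(fta.expansion_factor)  # 10 tiles, assuming the expansion is (upper - lower) / delta
print(fta(z).shape)          # torch.Size([2, 30])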

import copy

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.activations.fta import FTA
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask

FFN module with FTA activation

class FeedForwardFTA(nn.Module):
  • d_model is the number of features in a token embedding
  • d_ff is the number of features in the hidden layer of the FFN
  • activation is the FTA activation module
  • dropout is dropout probability for the hidden layer
    def __init__(self, d_model: int, d_ff: int,
                 activation: FTA,
                 dropout: float = 0.1):
        super().__init__()

Layer one parameterized by weight and bias

        self.layer1 = nn.Linear(d_model, d_ff)

Layer two parameterized by weight and bias. FTA expands each hidden feature into expansion_factor features, so this layer takes d_ff * activation.expansion_factor inputs.

        self.layer2 = nn.Linear(d_ff * activation.expansion_factor, d_model)

Hidden layer dropout

        self.dropout = nn.Dropout(dropout)

Activation function

        self.activation = activation

    def forward(self, x: torch.Tensor):

        x = self.activation(self.layer1(x))

Apply dropout

        x = self.dropout(x)

        return self.layer2(x)
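
As a quick sanity check (a sketch with toy sizes, not the experiment's configuration), the module maps a [seq_len, batch_size, d_model] tensor back to the same shape; the FTA expansion only happens inside the hidden layer.

ffn = FeedForwardFTA(d_model=8, d_ff=4, activation=FTA(-1., 1., 0.2, 0.05), dropout=0.1)

x = torch.randn(5, 2, 8)  # [seq_len, batch_size, d_model]
# layer1 gives 4 features, FTA expands them to 4 * 10 = 40, layer2 projects back to 8
print(ffn(x).shape)       # torch.Size([5, 2, 8])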

Auto-Regressive model

This is an autoregressive transformer model that uses Feed-Forward Networks with [Fuzzy Tiling Activations](index.html).

class AutoregressiveTransformer(nn.Module):
  • n_tokens is the number of tokens in the vocabulary
  • d_model is the embedding size
  • n_layers is the number of transformer layers
  • layer is the transformer layer. We use n_layers copies of it for the transformer.
    def __init__(self, n_tokens: int, d_model: int, n_layers: int, layer: TransformerLayer):
        super().__init__()

Transformer with n_layers layers

        self.transformer_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])

Token embedding layer

        self.emb = nn.Embedding(n_tokens, d_model)

Readout layer

        self.readout = nn.Linear(d_model, n_tokens)

The mask will be initialized on the first call

        self.mask = None
  • x is the input token tensor of shape [seq_len, batch_size]
    def forward(self, x: torch.Tensor):

Create auto-regressive mask

        if self.mask is None or self.mask.size(0) != len(x):

Subsequent mask, which stops tokens from attending to future tokens

            self.mask = subsequent_mask(len(x)).to(x.device)
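
The mask is essentially a lower-triangular boolean matrix. Here is a minimal sketch of the idea (the labml subsequent_mask helper may return extra dimensions, but the triangular structure is what matters):

seq_len = 5
# True below and on the diagonal: position i can attend to positions 0..i only
mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(mask)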

Get the token embeddings

        x = self.emb(x)

Transformer encoder

        for layer in self.transformer_layers:
            x = layer(x=x, mask=self.mask)

Get logits

        x = self.readout(x)

Return results (the second value is for state, which this model doesn't use)

        return x, None

Configurations

This inherits from NLPAutoRegressionConfigs

class Configs(NLPAutoRegressionConfigs):

Model

    model: AutoregressiveTransformer

Number of layers

    n_layers: int = 4

α and β for DeepNorm

    deep_norm_alpha: float
    deep_norm_beta: float

Number of heads in the attention

    n_heads: int = 4

Embedding size

    d_model: int = 256

Size of each attention head

    d_k: int = 16

Feed forward layer size

    d_ff: int = 256

FTA parameters

    fta_lower_limit: float = -1.
    fta_upper_limit: float = +1.
    fta_delta: float = 0.2
    fta_eta: float = 0.05
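
With these defaults the tiling covers [-1, 1) in steps of delta = 0.2, which gives (1 - (-1)) / 0.2 = 10 tiles per feature (assuming the FTA module derives its expansion factor from these limits). So the FFN hidden layer grows from d_ff = 256 features to 256 * 10 = 2560 features before layer2 projects back to d_model = 256.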

Initialize the model

@option(Configs.model)
def _model(c: Configs):

Create FTA activation module

    fta = FTA(c.fta_lower_limit, c.fta_upper_limit, c.fta_delta, c.fta_eta)

Create the transformer. We re-use TransformerLayer and MultiHeadAttention implementations.

    m = AutoregressiveTransformer(c.n_tokens, c.d_model, c.n_layers,
                                  TransformerLayer(d_model=c.d_model,
                                                   feed_forward=FeedForwardFTA(d_model=c.d_model,
                                                                               d_ff=c.d_ff,
                                                                               activation=fta,
                                                                               dropout=0.1),
                                                   self_attn=MultiHeadAttention(c.n_heads, c.d_model,
                                                                                dropout_prob=0.0),
                                                   dropout_prob=0.0))

Move to the device

    return m.to(c.device)
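
Note that registering this function with @option means it is not called here directly; labml's config system computes conf.model lazily, the first time it is accessed (for example by experiment.add_pytorch_models in main below).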

Create and run the experiment

def main():

Create experiment

    experiment.create(name="fta", writers={'screen', 'labml'})

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 256

        'seq_len': 256,

Train for 32 epochs

        'epochs': 32,

Batch size

        'batch_size': 16,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Adam optimizer with no warmup

        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 3e-4,
    })

Set model(s) for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()