28from functools import partial
29from typing import Dict, NamedTuple, Tuple, Any, Callable
30from typing import List, TypeVar, Generic
31from typing import Union, Optional
32
33import jax
34import jax.numpy as jnp
35import numpy as np
36
37from labml import lab, monit, experiment, tracker
38from labml import logger
39from labml.logger import Text
40from labml.utils.download import download_file

Module

This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.

You can skip these modules to get into the models directly.

The module stores parameters and sub-modules separately. When we want to transform any method to a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values to the object.

This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.

class Module:
    """
    ## Module

    Base class for all modules. It handles parameters and transforms methods
    to pure functions for JAX to compile and differentiate.

    Parameters (JAX arrays) and sub-modules are stored separately from
    ordinary attributes. To transform a method to a pure function, the
    parameters of the module and its sub-modules are passed in as an argument
    and assigned back onto the object before the method runs.

    This is based on the blog post
    "From PyTorch to JAX: towards neural net frameworks that purify stateful code".
    """

    # Store all parameters and sub-modules in dictionaries
    _submodules: Dict[str, 'Module']
    _params: Dict[str, jnp.ndarray]

    def __init__(self):
        """
        ### Initialize

        Both assignments go through `__setattr__`; dictionaries are neither
        modules nor arrays, so they are stored in `__dict__`.
        """
        self._params = {}
        self._submodules = {}

    def __getattr__(self, attr_name: str):
        """
        ### Get attribute

        Python calls `__getattr__` only after normal attribute lookup
        (which includes `__dict__`) has failed. So this resolves parameters
        and sub-modules by name, and otherwise the attribute does not exist.

        Raises `AttributeError` (not `KeyError`) for unknown names so that
        `hasattr`, `getattr(obj, name, default)`, `copy` and `pickle`
        behave correctly.
        """
        # Read the stores straight from `__dict__`: accessing `self._params`
        # before `__init__` has run (e.g. on a freshly created copy or while
        # unpickling) would otherwise re-enter `__getattr__` and recurse forever.
        state = self.__dict__

        # If the attribute is a parameter
        params = state.get('_params', {})
        if attr_name in params:
            return params[attr_name]

        # If the attribute is a sub-module
        submodules = state.get('_submodules', {})
        if attr_name in submodules:
            return submodules[attr_name]

        # Normal lookup already searched `__dict__`, so the attribute is missing
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {attr_name!r}")

    def __setattr__(self, key: str, value: Any):
        """
        ### Set attribute

        Modules go to `_submodules`, JAX arrays go to `_params`, and
        everything else becomes an ordinary attribute in `__dict__`.
        """
        # If the value is also a module
        if isinstance(value, Module):
            self._submodules[key] = value
        # If the value is a JAX array
        elif isinstance(value, jnp.ndarray):
            self._params[key] = value
        # Otherwise add it to `__dict__`
        else:
            self.__dict__[key] = value

    def _clear_params(self):
        """
        ### Clear parameters

        Clears out all the parameters of this module and, recursively, of its
        sub-modules. Used when a method is called as a pure function: the
        parameters are cleared first and then replaced with the ones passed in.
        """
        # Clear parameters of the module
        self._params = {}
        # Recursively clear parameters of submodules
        for sm in self._submodules.values():
            sm._clear_params()

    def get_params(self) -> Dict[str, jnp.ndarray]:
        """
        ### Collect all the parameters

        Recursively collects the parameters of the module and its sub-modules
        into a new dictionary whose keys have the form
        `module_name/module_name/param_name`.
        """
        # Parameters of the model (copied so callers can't mutate our store)
        params = self._params.copy()
        # Parameters of the submodules, prefixed with the sub-module name
        for sm_name, sm in self._submodules.items():
            for name, value in sm.get_params().items():
                params[sm_name + "/" + name] = value

        return params

    def _set_params(self, params: Dict[str, jnp.ndarray]):
        """
        ### Set all the parameters

        `params` keys have the form `module_name/module_name/param_name`.
        """
        for name, value in params.items():
            # Split to get module names and parameter name
            self._set_param(name.split("/"), value)

    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        """
        ### Set a single parameter

        This is called by `_set_params`. `param_path` is the key split on `/`.
        """
        # No module names; i.e. a parameter of this module
        if len(param_path) == 1:
            self._params[param_path[0]] = value
        # Parameter of a submodule
        else:
            self._submodules[param_path[0]]._set_param(param_path[1:], value)

    def purify(self, method: Callable) -> Callable:
        """
        ### Transform a member method to a pure function

        Returns a function that accepts a dictionary of parameters as its first
        argument, followed by the method's own arguments. For example,

        ```python
        params = model.get_params()
        pure_function = model.purify(model.calculate_loss)
        output = pure_function(params, data)
        ```

        Note that the passed parameters remain assigned to the object after
        the call; the previous parameters are not restored.
        """

        def pure_method(params: Dict[str, jnp.ndarray], *args, **kwargs):
            # Clear parameters in the object
            self._clear_params()
            # Assign the passed parameters
            self._set_params(params)
            # Invoke the method and return the result
            return method(*args, **kwargs)

        return pure_method

Type for generics in the module list class

M = TypeVar('M', bound=Module)


class ModuleList(Module, Generic[M]):
    """
    ## Module list

    Stores a list of modules. This is needed for the transformer decoder to
    hold its list of transformer layers. Sub-modules are addressed by their
    integer position, so parameter names look like `0/param_name`.
    """

    # Sub-modules are kept in a list instead of a dictionary
    _submodules: List[M]

    def __init__(self, modules: List[M]):
        """
        Initialize with a list of modules.
        """
        super().__init__()
        # Copy so later changes to the caller's list don't silently add or
        # remove layers from this module
        self._submodules = list(modules)

    def __getitem__(self, idx: int) -> M:
        """
        Get the `idx`-th module
        """
        return self._submodules[idx]

    def __setitem__(self, key, value):
        """
        This is not supported
        """
        raise NotImplementedError

    def __len__(self):
        """
        Number of modules
        """
        return len(self._submodules)

    def __getattr__(self, item):
        """
        Override `__getattr__` of `Module`: a module list has no named
        parameters or sub-modules, only ordinary attributes.

        Raises `AttributeError` (not `KeyError`) for unknown names so
        `hasattr` and `getattr(obj, name, default)` work.
        """
        try:
            return self.__dict__[item]
        except KeyError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}") from None

    def __setattr__(self, key, value):
        """
        Override `__setattr__` of `Module`: everything, including arrays,
        becomes an ordinary attribute.
        """
        self.__dict__[key] = value

    def _clear_params(self):
        """
        Clear all parameters
        """
        self._params = {}
        for sm in self._submodules:
            sm._clear_params()

    def get_params(self):
        """
        Get all parameters, keyed as `index/param_name`.
        """
        # Copy so building the result never mutates this module's own store
        params = self._params.copy()
        for i, sm in enumerate(self._submodules):
            for name, value in sm.get_params().items():
                params[f'{i}/{name}'] = value
        return params

    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        """
        Set a parameter; the first path element is the sub-module index.
        """
        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)

Embedding layer

This maintains embeddings by id.

class Embedding(Module):
    """
    ## Embedding layer

    Holds a learnable table of embedding vectors that is indexed by id.
    """

    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
        """
        * `rnd_key` is the PRNG state
        * `n_embeddings` is the number of embeddings
        * `n_dim` is the size of an embedding
        """
        super().__init__()
        # Table of shape `[n_embeddings, n_dim]`, sampled from $\mathcal{N}(0, 1)$
        table_shape = (n_embeddings, n_dim)
        self.embeddings = jax.random.normal(rnd_key, table_shape)

    def __call__(self, ids: jnp.ndarray):
        """
        Look up the embedding row for every id in `ids`.
        """
        table = self.embeddings
        return table[ids, :]

Embed tokens and add parameterized positional encodings

This is based on our PyTorch implementation.

class EmbeddingsWithLearnedPositionalEncoding(Module):
    """
    ## Embed tokens and add parameterized positional encodings

    This is based on our PyTorch implementation.
    """

    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
        """
        * `rnd_key` is the PRNG state
        * `n_vocab` is the vocabulary size
        * `d_model` is the embedding size
        * `max_len` is the maximum sequence length (to initialize positional encodings)
        """
        super().__init__()
        # Learned positional encodings, one row per position, initialized to zeros
        self.positional_encodings = jnp.zeros((max_len, d_model))
        # Coefficient $\frac{1}{\sqrt{d_{model}}}$; note it scales the *token
        # embeddings*, not the positional encodings
        self.pe_coef = 1 / d_model ** 0.5
        # Token embedding table
        self.embeddings = Embedding(rnd_key, n_vocab, d_model)

    def __call__(self, x: jnp.ndarray):
        """
        Return scaled token embeddings of `x` plus the positional encodings
        for the first `x.shape[0]` positions.
        """
        n_positions = x.shape[0]
        scaled_tokens = self.embeddings(x) * self.pe_coef
        return scaled_tokens + self.positional_encodings[:n_positions]

Linear Layer

This is a simple linear layer with a weight matrix and a bias vector

class Linear(Module):
    """
    ## Linear Layer

    A simple affine layer $y = x W + b$ with a weight matrix $W$ and a bias
    vector $b$.
    """

    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
        """
        * `rnd_key` is the PRNG state
        * `in_features` is the number of features in the input
        * `out_features` is the number of features in the output
        """
        super().__init__()
        # Weights drawn from $\mathcal{U}\Big(-\frac{1}{\sqrt{d_{in}}}, \frac{1}{\sqrt{d_{in}}}\Big)$
        bound = 1 / in_features ** 0.5
        weight_shape = (in_features, out_features)
        self.weight = jax.random.uniform(rnd_key, weight_shape, minval=-bound, maxval=bound)
        # Biases start at zero
        self.bias = jnp.zeros((out_features,))

    def __call__(self, x: jnp.ndarray):
        """
        Multiply by weights and add the bias.
        """
        projected = jnp.matmul(x, self.weight)
        return projected + self.bias

Layer Normalization

This implements the layer normalization from the paper Layer Normalization.

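When input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence, $\gamma \in \mathbb{R}^{C}$ and $\beta \in \mathbb{R}^{C}$: $$\text{LN}(X) = \gamma \frac{X - \underset{C}{\mathbb{E}}[X]}{\sqrt{\underset{C}{Var}[X] + \epsilon}} + \beta$$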

This is based on our PyTorch implementation.

class LayerNorm(Module):
    r"""
    ## Layer Normalization

    $$\text{LN}(X) = \gamma \frac{X - \mathbb{E}[X]}{\sqrt{Var[X] + \epsilon}} + \beta$$

    with the mean and variance taken over the trailing `normalized_shape` axes.
    This is based on our PyTorch implementation.
    """

    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
                 eps: float = 1e-5, elementwise_affine: bool = True):
        r"""
        * `normalized_shape` is the shape of the elements (except the batch).
          The input should then be `[*, normalized_shape[0], normalized_shape[1], ...]`
        * `eps` is $\epsilon$, used in $\sqrt{Var[X] + \epsilon}$ for numerical stability
        * `elementwise_affine` is whether to scale and shift the normalized value
        """
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_shape = tuple(normalized_shape)

        # Create parameters for $\gamma$ (gain) and $\beta$ (bias)
        if elementwise_affine:
            self.gain = jnp.ones(self.normalized_shape)
            self.bias = jnp.zeros(self.normalized_shape)

    def __call__(self, x: jnp.ndarray):
        # Sanity check to make sure the trailing dimensions of `x` match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

        # The trailing axes to calculate the mean and variance on
        axes = tuple(-(i + 1) for i in range(len(self.normalized_shape)))

        # Mean over the normalized axes
        mean = x.mean(axis=axes, keepdims=True)

        # Variance from the centred values. The shortcut
        # $\mathbb{E}[X^2] - \mathbb{E}[X]^2$ suffers catastrophic cancellation
        # when the mean is large relative to the spread and can even come out
        # negative (NaN after the square root if it exceeds $\epsilon$).
        centered = x - mean
        var = (centered ** 2).mean(axis=axes, keepdims=True)

        # Normalize
        x_norm = centered / (var + self.eps) ** 0.5

        # Scale and shift: $\gamma \hat{X} + \beta$
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

        return x_norm

Multi-Head Attention Module

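This computes scaled multi-headed attention from the paper Attention Is All You Need for given query ($Q$), key ($K$) and value ($V$) vectors: $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$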

In simple terms, it finds keys that matches the query, and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the $softmax$, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the axis of the sequence (or time) for keys.

This is based on our PyTorch implementation.

class MultiHeadAttention(Module):
    r"""
    ## Multi-Head Attention Module

    Computes scaled multi-headed attention from the paper
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762):

    $$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

    Softmax is calculated along the key sequence axis.
    This is based on our PyTorch implementation.
    """

    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
        """
        * `rnd_key` is the PRNG state
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        """
        super().__init__()

        # Split the PRNG state (four sub-keys are used below)
        _, *rnd_keys = jax.random.split(rnd_key, 5)

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        self.query = Linear(rnd_keys[0], d_model, d_model)
        self.key = Linear(rnd_keys[1], d_model, d_model)
        self.value = Linear(rnd_keys[2], d_model, d_model)

        # Output layer
        self.output = Linear(rnd_keys[3], d_model, d_model)
        # Scaling factor $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / self.d_k ** 0.5

    def __call__(self, *,
                 query: jnp.ndarray,
                 key: jnp.ndarray,
                 value: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None):
        """
        `query`, `key` and `value` store collections of query, key and value
        vectors, each of shape `[seq_len, d_model]`.

        `mask` has shape `[seq_len, seq_len]`; `mask[i, j]` indicates whether
        the query at position `i` can see the key-value at position `j`
        (entries equal to zero are masked out).
        """
        # Get sequence length
        seq_len = len(query)

        if mask is not None:
            # Check mask shape
            assert mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]
            # Same mask applied to all heads
            mask = mask[:, :, None]

        # Apply linear transformations
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Split the last dimension into `heads` and `d_k`:
        # `[seq_len, d_model]` -> `[seq_len, heads, d_k]`
        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)

        # Attention scores $Q K^\top$, shape `[seq_len, seq_len, heads]`
        scores = jnp.einsum('ihd,jhd->ijh', query, key)
        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask. Use `jnp.where` instead of adding `(mask == 0) * -inf`:
        # the latter evaluates `0 * -inf = NaN` at every *visible* position.
        # NOTE(review): a query row with every key masked would still give NaN
        # after the softmax; the causal mask always keeps the diagonal.
        if mask is not None:
            scores = jnp.where(mask == 0, float('-inf'), scores)

        # Softmax attention along the key sequence dimension
        attn = jax.nn.softmax(scores, axis=1)

        # Multiply by values
        x = jnp.einsum("ijh,jhd->ihd", attn, value)

        # Concatenate multiple heads
        x = x.reshape(seq_len, -1)

        # Output layer
        return self.output(x)

Position-wise Feed-Forward layer

This is based on our PyTorch implementation.

class FeedForward(Module):
    """
    ## Position-wise Feed-Forward layer

    Two linear layers with an activation in between.
    This is based on our PyTorch implementation.
    """

    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
                 activation=jax.nn.relu):
        """
        * `rnd_key` is the PRNG state
        * `d_model` is the number of features in a token embedding
        * `d_ff` is the number of features in the hidden layer of the FFN
        * `activation` is the activation function
        """
        super().__init__()
        # Five-way split; sub-keys 1 and 2 initialize the two layers
        keys = jax.random.split(rnd_key, 5)
        # First layer, `d_model -> d_ff`
        self.layer1 = Linear(keys[1], d_model, d_ff)
        # Second layer, `d_ff -> d_model`
        self.layer2 = Linear(keys[2], d_ff, d_model)
        # Activation function
        self.activation = activation

    def __call__(self, x: jnp.ndarray):
        """
        Apply `layer2(activation(layer1(x)))`.
        """
        hidden = self.activation(self.layer1(x))
        return self.layer2(hidden)

Transformer Layer

This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer layer normalization.

class TransformerLayer(Module):
    """
    ## Transformer Layer

    Multi-head self attention followed by a position-wise feed-forward layer,
    each wrapped in a residual connection with pre-layer normalization.
    """

    def __init__(self,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward):
        """
        * `d_model` is the token embedding size
        * `self_attn` is the self attention module
        * `feed_forward` is the feed forward module
        """
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # One layer norm in front of each sub-layer
        self.norm_self_attn = LayerNorm([d_model])
        self.norm_ff = LayerNorm([d_model])

    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):
        # Self attention on the normalized input (queries, keys and values all
        # come from it), added back onto the residual stream
        normed = self.norm_self_attn(x)
        x = x + self.self_attn(query=normed, key=normed, value=normed, mask=mask)

        # Feed-forward on the normalized result, added back as well
        normed = self.norm_ff(x)
        return x + self.feed_forward(normed)

Cross Entropy Loss

class CrossEntropyLoss(Module):
    """
    ## Cross Entropy Loss

    Mean negative log-likelihood of the target ids under the model logits.
    """

    def __init__(self):
        super().__init__()
        # Vectorize the single-position loss over the leading axis of both
        # arguments with `jax.vmap`
        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0))

    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):
        """
        Negative log-probability of `target` given the logits `output` for
        one position.
        """
        log_probs = jax.nn.log_softmax(output)
        return -log_probs[target]

    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):
        """
        * `output` is the model outputs of shape `[seq_len, n_vocab]`
        * `target` is the target of shape `[seq_len]`

        Uses the vectorized loss and averages it; a Python `for` loop over
        positions would be roughly 10x slower.
        """
        per_position = self._loss_vmap(output, target)
        return per_position.mean()

Autoregressive Transformer

This is the transformer decoder with embedding and output layers.

class AutoregressiveTransformer(Module):
    """
    ## Autoregressive Transformer

    The transformer decoder together with its embedding and output layers.
    """
    layers: ModuleList[TransformerLayer]

    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
        """
        * `rnd_key` is the PRNG state
        * `n_vocab` is the vocabulary size
        * `d_model` is the number of features in a token embedding
        * `n_layers` is the number of transformer layers
        * `heads` is the number of attention heads
        * `d_ff` is the number of features in the hidden layer of the FFN
        """
        super().__init__()
        self.n_vocab = n_vocab
        self.d_model = d_model
        self.loss_func = CrossEntropyLoss()

        # Build the transformer layers, splitting off fresh sub-keys for each
        blocks = []
        for _ in range(n_layers):
            rnd_key, attn_key, ff_key = jax.random.split(rnd_key, 3)
            blocks.append(TransformerLayer(d_model,
                                           MultiHeadAttention(attn_key, heads, d_model),
                                           FeedForward(ff_key, d_model, d_ff)))
        self.layers = ModuleList(blocks)

        # Keys for the embedding and output layers
        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)
        # Token embeddings with learned positional encodings
        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)
        # Final normalization and output projection to logits
        self.norm = LayerNorm([d_model])
        self.output = Linear(out_key, d_model, n_vocab)

    def __call__(self, x: jnp.ndarray):
        """
        Return logits for every position of the token sequence `x`.
        """
        seq_len = len(x)
        # Causal mask: position `i` may only attend to positions `j <= i`
        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))

        hidden = self.embeddings(x)
        # `ModuleList` supports indexing, so it iterates layer by layer
        for layer in self.layers:
            hidden = layer(hidden, mask)

        return self.output(self.norm(hidden))

    def get_loss(self, x: jnp.ndarray):
        """
        Next-token cross entropy loss for the sequence `x`.
        """
        logits = self(x)
        # Logits at position `t` are scored against the token at `t + 1`
        return self.loss_func(logits[:-1], x[1:])

    def sample(self, seq: jnp.ndarray, length: int = 20):
        """
        Greedily extend the starting sequence `seq` by `length` tokens and
        return the whole sequence.
        """
        for _ in range(length):
            # Highest-probability next token
            next_token = jnp.argmax(self(seq)[-1])
            seq = jnp.concatenate((seq, next_token[None]))

        return seq

This is a named tuple for storing Adam optimizer state for a parameter

# Adam optimizer state for a single parameter: the uncorrected first moment
# `m` and second moment `v`
AdamState = NamedTuple('AdamState', [('m', jnp.ndarray), ('v', jnp.ndarray)])

Adam Optimizer

This is from the paper Adam: A Method for Stochastic Optimization.

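For parameter $\theta_{t-1}$ and gradient $g_t$ at step $t$, the Adam update is,

$$m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) \cdot g_t$$
$$v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) \cdot g_t^2$$
$$\hat{m}_t \leftarrow \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t \leftarrow \frac{v_t}{1-\beta_2^t}$$
$$\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper parameters. $m_t$ and $v_t$ are the first and second order moments. $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero error, but also acts as a form of hyper-parameter that acts against variance in gradients.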

class Adam:
    r"""
    ## Adam Optimizer

    Implements Adam from the paper
    [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980),
    with gradients clipped to `[-1, 1]` before the moment updates.
    """

    def __init__(self, params: Dict,
                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16, ):
        r"""
        * `params` is the tree-map of parameters
        * `lr` is the learning rate $\alpha$
        * `betas` is a tuple of ($\beta_1$, $\beta_2$)
        * `eps` is $\epsilon$
        """
        super().__init__()
        # Hyper-parameters
        self.lr, self.betas, self.eps = lr, betas, eps
        # Number of steps taken so far
        self._n_steps = 0
        # A zero-initialized state for each parameter
        self.states = jax.tree.map(self._init_state, params)
        # JIT-compiled per-parameter functions
        self._update_state_jit = jax.jit(self._update_state)
        self._step_jit = jax.jit(self._step)

    def _init_state(self, param: jnp.ndarray):
        """
        Zero first and second moments shaped like `param`.
        """
        return AdamState(m=jnp.zeros_like(param), v=jnp.zeros_like(param))

    def step(self, params: Dict, grads: Dict):
        """
        Take one optimizer step and return the updated parameters.

        * `params` is a tree-map of parameters
        * `grads` is a tree-map of gradients
        """
        self._n_steps += 1
        # Fold the gradients into the moment estimates
        self.states = jax.tree.map(self._update_state_jit, grads, self.states)
        # Apply the bias-corrected update to every parameter
        apply_update = partial(self._step_jit, self._n_steps)
        return jax.tree.map(apply_update, params, self.states)

    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):
        r"""
        Adam update of a single parameter:
        $\theta - \alpha \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}$
        """
        beta1, beta2 = self.betas[0], self.betas[1]
        m, v = state
        # Bias corrections $1 - \beta_1^t$ and $1 - \beta_2^t$
        correction1 = 1 - beta1 ** n_steps
        correction2 = 1 - beta2 ** n_steps
        # Both corrections folded into the step size
        step_size = self.lr * (correction2 ** 0.5) / correction1
        return param - step_size * m / ((v ** 0.5) + self.eps)

    def _update_state(self, grad, state: AdamState):
        r"""
        Update the uncorrected moments:
        $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$ and
        $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$
        """
        beta1, beta2 = self.betas[0], self.betas[1]
        # Clip gradients to [-1, 1]
        clipped = jnp.clip(grad, -1, 1)
        m_new = beta1 * state.m + (1 - beta1) * clipped
        v_new = beta2 * state.v + (1 - beta2) * (clipped ** 2)
        return AdamState(m_new, v_new)

Tiny Shakespeare dataset

class TinyShakespeare:
    """
    ## Tiny Shakespeare dataset

    Character-level dataset. Iterating yields shuffled batches of shape
    `[batch_size, seq_len]` of token ids.
    """

    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
        """
        * `rnd_key` is the PRNG state
        * `seq_len` is the sequence length of a sample
        * `batch_size` is the batch size
        """
        self.batch_size = batch_size
        # PRNG key for shuffling the samples
        _, self.rnd_key = jax.random.split(rnd_key)

        # Local path of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download if it doesn't exist
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        if not path.exists():
            download_file(url, path)

        # Read the file with an explicit encoding so decoding does not depend
        # on the platform's locale default
        with open(str(path), 'r', encoding='utf-8') as f:
            self.text = f.read()

        # Get the characters/tokens
        tokens = sorted(set(self.text))
        # Number of tokens
        self.n_tokens = len(tokens)
        # Map tokens to ids
        self.stoi = {t: i for i, t in enumerate(tokens)}
        # Id to token/character
        self.itos = tokens

        # The text as a list of ids
        data = jnp.array([self.stoi[s] for s in self.text])
        # Number of batches
        self.n_batches = len(data) // (seq_len * batch_size)
        # Truncate to a whole number of batches
        data = data[:self.n_batches * seq_len * batch_size]
        # Reshape into samples (better to use random offsets, but let's ignore that here)
        self.data = data.reshape((-1, seq_len))
        # List of sample indexes
        self.idx = jnp.arange(len(self.data))

    def __iter__(self):
        """
        Setup for iteration: reset the step counter and reshuffle.
        """
        # Iteration step
        self._iter_idx = 0
        # Split PRNG key
        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)
        # Shuffle sample indexes
        self.idx = jax.random.permutation(rnd_key, self.idx)

        return self

    def __len__(self):
        """
        Number of batches
        """
        return self.n_batches

    def __next__(self):
        """
        Get next batch
        """
        # Stop iteration after iterating through all batches
        if self._iter_idx >= self.n_batches:
            raise StopIteration()

        # Sample indexes for the batch
        start = self._iter_idx * self.batch_size
        idx = self.idx[start:start + self.batch_size]
        # Increment iteration step
        self._iter_idx += 1

        # Return samples
        return self.data[idx]

Run the experiment

def main():
    """
    ## Run the experiment

    Trains the autoregressive transformer on Tiny Shakespeare and logs a
    greedy sample after each epoch.
    """

    # Create experiment
    experiment.create(name='jax')
    # Create PRNG key and split it so the dataset and the model never consume
    # the same key (JAX keys should not be reused)
    rnd_key = jax.random.PRNGKey(0)
    data_key, model_key = jax.random.split(rnd_key)
    # Create dataset
    dataset = TinyShakespeare(data_key, seq_len=32, batch_size=128)

    # Create the model
    model = AutoregressiveTransformer(model_key, dataset.n_tokens,
                                      d_model=128, n_layers=3, heads=8, d_ff=512)
    # Get model parameters
    params = model.get_params()

    # JAX compiled pure sampling function
    pure_sample_fn = jax.jit(model.purify(model.sample))

    # JAX compiled pure function to get logits for a batch (not used in the
    # training loop below; shown as an example). `model.__call__` becomes a pure
    # function of (parameters, input sequence), then `jax.vmap` vectorizes it over
    # a batch: `in_axes=(None, 0)` shares the parameters and maps the inputs over
    # their first axis; `out_axes=0` stacks the results along the first axis.
    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
                                       in_axes=(None, 0), out_axes=0))
    # Similarly we vectorize loss computation
    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
                                    in_axes=(None, 0), out_axes=0))

    # A function to get mean loss
    def get_loss(params, seq):
        return pure_loss_fn(params, seq).mean()

    # Loss and gradients w.r.t. the first argument (parameters) from a single
    # forward/backward pass, instead of a separate forward pass for logging
    loss_and_grad_fn = jax.jit(jax.value_and_grad(get_loss, argnums=0))

    # Create optimizer
    optimizer = Adam(params)

    # Start the experiment
    with experiment.start():
        # Iterate for 32 epochs
        for _ in monit.loop(32):
            # Iterate through batches
            for data in monit.iterate('Train', dataset):
                # Compute the loss and gradients, and log the loss
                loss, grads = loss_and_grad_fn(params, data)
                tracker.save('loss', np.asarray(loss))
                # Update parameters
                params = optimizer.step(params, grads)

            tracker.new_line()

            # Log a sample after each epoch
            prompt = [dataset.stoi[c] for c in 'It ']
            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
            # Convert to Python ints once instead of indexing with 0-d arrays
            sampled = ''.join(dataset.itos[i] for i in np.asarray(sampled).tolist())
            sampled = sampled.replace('\n', '\\n')
            logger.log(('It ', Text.meta), (sampled, Text.value))


if __name__ == '__main__':
    main()