from functools import partial
from typing import Dict, NamedTuple, Tuple, Any, Callable
from typing import List, TypeVar, Generic
from typing import Union, Optional

import jax
import jax.numpy as jnp
import numpy as np

from labml import lab, monit, experiment, tracker
from labml import logger
from labml.logger import Text
from labml.utils.download import download_file

This is a base class for all modules. It handles parameters and transforms methods to pure functions for JAX to compile and differentiate.
You can skip these modules and get into the models directly.

The module stores parameters and sub-modules separately. When we want to transform a method into a pure function, we pass the parameters of the module and its sub-modules as an argument and assign the passed values back to the module before invoking the method.

This is based on a blog post: From PyTorch to JAX: towards neural net frameworks that purify stateful code.
class Module:
    # Store all parameters and sub-modules in dictionaries
    _submodules: Dict[str, 'Module']
    _params: Dict[str, jnp.ndarray]

    # Initialize
    def __init__(self):
        self._params = {}
        self._submodules = {}

We override the get-attribute operation, so when you reference an attribute with model.attribute this function gets called. Read this guide if you are not familiar with Python magic methods.
    def __getattr__(self, attr_name: str):
        # If the attribute is a parameter
        if attr_name in self._params:
            return self._params[attr_name]
        # If the attribute is a sub-module
        elif attr_name in self._submodules:
            return self._submodules[attr_name]
        # Otherwise fall back to normal attributes, which Python stores in __dict__
        else:
            return self.__dict__[attr_name]

We override the set-attribute operation, so when you assign an attribute with model.attribute = value this function gets called.
    def __setattr__(self, key: str, value: Any):
        # If the value is also a module
        if isinstance(value, Module):
            self._submodules[key] = value
        # If the value is a JAX array
        elif isinstance(value, jnp.ndarray):
            self._params[key] = value
        # Otherwise add it to __dict__
        else:
            self.__dict__[key] = value

This clears out all the parameters. It is used when a method is called as a pure function: we first clear out all the parameters and then assign the parameters passed to the pure function.
    def _clear_params(self):
        # Clear parameters of the module
        self._params = {}
        # Recursively clear parameters of sub-modules
        for sm in self._submodules.values():
            sm._clear_params()

This recursively collects all the parameters of the module and its sub-modules into a dictionary.
    def get_params(self) -> Dict[str, jnp.ndarray]:
        # Parameters of the module
        params = self._params.copy()
        # Parameters of the sub-modules
        for sm_name, sm in self._submodules.items():
            for name, value in sm.get_params().items():
                # The dictionary keys are of the form module_name/module_name/param_name
                params[sm_name + "/" + name] = value

        return params

    def _set_params(self, params: Dict[str, jnp.ndarray]):
        # Iterate through the parameters. Their names have the form module_name/module_name/param_name
        for name, value in params.items():
            # Split to get the module names and the parameter name
            self._set_param(name.split("/"), value)

    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        # No module names; i.e. a parameter of this module
        if len(param_path) == 1:
            self._params[param_path[0]] = value
        # Parameter of a sub-module
        else:
            self._submodules[param_path[0]]._set_param(param_path[1:], value)
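As a quick illustration of the key naming, here is a minimal sketch; the Inner and Outer modules are made up for this example and assume the Module class above:

import jax.numpy as jnp

# Hypothetical modules, only to show how the parameter keys are built
class Inner(Module):
    def __init__(self):
        super().__init__()
        self.weight = jnp.ones((2, 2))

class Outer(Module):
    def __init__(self):
        super().__init__()
        self.bias = jnp.zeros((2,))
        self.inner = Inner()

# Keys are built from the attribute names along the path
print(sorted(Outer().get_params().keys()))  # ['bias', 'inner/weight']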
This transforms a member method to a pure function that accepts a dictionary of parameters as an argument. For example,

params = model.get_params()
pure_function = model.purify(model.calculate_loss)
output = pure_function(params, data)

    def purify(self, method: Callable) -> Callable:
        def pure_method(params: Dict[str, jnp.array], *args):
            # Clear the parameters in the object
            self._clear_params()
            # Assign the passed parameters
            self._set_params(params)
            # Invoke the method
            result = method(*args)
            # Return the result
            return result

        return pure_method
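Since the pure function takes the parameter dictionary as its first argument, it can be passed straight to jax.grad or jax.jit. A minimal sketch; the Scale module here is made up for illustration:

import jax
import jax.numpy as jnp

# Hypothetical module with a single scalar parameter
class Scale(Module):
    def __init__(self):
        super().__init__()
        self.w = jnp.array(2.0)

    def loss(self, x: jnp.ndarray):
        return (self.w * x).sum()

scale = Scale()
pure_loss = scale.purify(scale.loss)

# Gradient with respect to the parameter dictionary
grads = jax.grad(pure_loss)(scale.get_params(), jnp.array([1.0, 2.0, 3.0]))
print(grads['w'])  # 6.0, i.e. the sum of the inputs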
Type for generics in the module list class:

M = TypeVar('M', bound=Module)

This stores a list of modules. We need it for the transformer decoder, to hold the list of transformer layers.
class ModuleList(Module, Generic[M]):
    # For a list of modules
    _submodules: List[M]

    # Initialize with a list of modules
    def __init__(self, modules: List[M]):
        super().__init__()
        self._submodules = modules

    # Get the idx-th module
    def __getitem__(self, idx: int) -> M:
        return self._submodules[idx]

    # Setting items is not supported
    def __setitem__(self, key, value):
        raise NotImplementedError

    def __len__(self):
        return len(self._submodules)

    # Override __getattr__ of Module
    def __getattr__(self, item):
        return self.__dict__[item]

    # Override __setattr__ of Module
    def __setattr__(self, key, value):
        self.__dict__[key] = value

    def _clear_params(self):
        self._params = {}
        for sm in self._submodules:
            sm._clear_params()

    def get_params(self):
        # Copy so that collecting sub-module parameters doesn't mutate _params
        params = self._params.copy()
        for i, sm in enumerate(self._submodules):
            for name, value in sm.get_params().items():
                params[f'{i}/{name}'] = value
        return params

    def _set_param(self, param_path: List[str], value: jnp.ndarray):
        self._submodules[int(param_path[0])]._set_param(param_path[1:], value)
class Embedding(Module):

rnd_key is the PRNG state, n_embeddings is the number of embeddings, and n_dim is the size of an embedding.

    def __init__(self, rnd_key: jax.random.PRNGKey, n_embeddings: int, n_dim: int):
        super().__init__()
        # Embeddings are initialized from the standard normal distribution N(0, 1)
        self.embeddings = jax.random.normal(rnd_key, (n_embeddings, n_dim))

    # Return the embeddings for the given ids
    def __call__(self, ids: jnp.ndarray):
        return self.embeddings[ids, :]
This is based on our PyTorch implementation.

class EmbeddingsWithLearnedPositionalEncoding(Module):

rnd_key is the PRNG state, n_vocab is the vocabulary size, d_model is the embedding size, and max_len is the maximum sequence length (to initialize the positional encodings).

    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, max_len: int = 4096):
        super().__init__()
        # Embeddings
        self.embeddings = Embedding(rnd_key, n_vocab, d_model)
        # Embedding scaling coefficient 1 / sqrt(d_model)
        self.pe_coef = 1 / d_model ** 0.5
        # Positional encodings initialized to zeros
        self.positional_encodings = jnp.zeros((max_len, d_model))

    def __call__(self, x: jnp.ndarray):
        # Get positional encodings
        pe = self.positional_encodings[:x.shape[0]]
        # Get embeddings and add positional encodings
        return self.embeddings(x) * self.pe_coef + pe
class Linear(Module):

rnd_key is the PRNG state, in_features is the number of features in the input, and out_features is the number of features in the output.

    def __init__(self, rnd_key: jax.random.PRNGKey, in_features: int, out_features: int):
        super().__init__()
        # Initialize the weights uniformly from [-1/sqrt(in_features), 1/sqrt(in_features)]
        rnd_range = 1 / in_features ** 0.5
        self.weight = jax.random.uniform(rnd_key, (in_features, out_features),
                                         minval=-rnd_range, maxval=rnd_range)
        # Initialize the biases to 0
        self.bias = jnp.zeros((out_features,))

    def __call__(self, x: jnp.ndarray):
        # Multiply by the weights and add the bias
        return jnp.matmul(x, self.weight) + self.bias
This implements layer normalization from the paper Layer Normalization.

When the input $X \in \mathbb{R}^{L \times C}$ is a sequence of embeddings, where $C$ is the number of channels and $L$ is the length of the sequence,

$$\text{LN}(X) = \gamma \frac{X - \mathbb{E}[X]}{\sqrt{\mathrm{Var}[X] + \epsilon}} + \beta$$

with gain $\gamma \in \mathbb{R}^{C}$ and bias $\beta \in \mathbb{R}^{C}$.

This is based on our PyTorch implementation.

class LayerNorm(Module):

normalized_shape $S$ is the shape of the elements (except the batch); the input should then have shape $[*, S[0], S[1], ..., S[n]]$. eps is $\epsilon$, used in $\sqrt{\mathrm{Var}[X] + \epsilon}$ for numerical stability. elementwise_affine is whether to scale and shift the normalized value.

    def __init__(self, normalized_shape: Union[Tuple[int], List[int]], *,
                 eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine
        self.normalized_shape = tuple(normalized_shape)

        # Create parameters gamma and beta for gain and bias
        if elementwise_affine:
            self.gain = jnp.ones(normalized_shape)
            self.bias = jnp.zeros(normalized_shape)

    def __call__(self, x: jnp.ndarray):
        # Sanity check to make sure the shapes match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]

        # The axes to calculate the mean and variance on
        axes = [-(i + 1) for i in range(len(self.normalized_shape))]

        # Calculate the mean of all elements, i.e. E[X]
        mean = x.mean(axis=axes, keepdims=True)
        # Calculate the mean of the squared elements, i.e. E[X^2]
        mean_2 = (x ** 2).mean(axis=axes, keepdims=True)
        # Variance of all elements: Var[X] = E[X^2] - E[X]^2
        var = mean_2 - mean ** 2

        # Normalize
        x_norm = (x - mean) / (var + self.eps) ** 0.5
        # Scale and shift
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias

        return x_norm
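As a quick sanity check of the implementation above (a minimal sketch; the shapes are arbitrary), each normalized row should end up with roughly zero mean and unit variance when the gain is one and the bias is zero:

import jax
import jax.numpy as jnp

ln = LayerNorm([4])
x = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
y = ln(x)

# With the default gain of ones and bias of zeros, y is just the normalized input
print(jnp.allclose(y.mean(axis=-1), 0.0, atol=1e-5))  # True
print(jnp.allclose(y.var(axis=-1), 1.0, atol=1e-3))   # True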
This computes scaled multi-head attention from the paper Attention Is All You Need for given query, key and value vectors:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

In simple terms, it finds keys that match the query and gets the values of those keys.

It uses the dot-product of query and key as the indicator of how well they match. Before taking the softmax, the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$. This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large.

Softmax is calculated along the sequence (or time) axis of the keys.

This is based on our PyTorch implementation.

class MultiHeadAttention(Module):

rnd_key is the PRNG state, heads is the number of heads, and d_model is the number of features in the query, key and value vectors.

    def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
        super().__init__()
        # Split the PRNG state
        _, *rnd_keys = jax.random.split(rnd_key, 5)

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the query, key and value vectors for multi-head attention
        self.query = Linear(rnd_keys[0], d_model, d_model)
        self.key = Linear(rnd_keys[1], d_model, d_model)
        self.value = Linear(rnd_keys[2], d_model, d_model)

        # Output layer
        self.output = Linear(rnd_keys[3], d_model, d_model)
        # Scaling factor 1 / sqrt(d_k) applied before the softmax
        self.scale = 1 / self.d_k ** 0.5
query, key and value are the tensors that store collections of query, key and value vectors. They have shape [seq_len, d_model].

mask has shape [seq_len, seq_len] and mask[i, j] indicates whether the query at position i can see the key-value at position j.

    def __call__(self, *,
                 query: jnp.ndarray,
                 key: jnp.ndarray,
                 value: jnp.ndarray,
                 mask: Optional[jnp.ndarray] = None):
        # Get the sequence length
        seq_len = len(query)

        if mask is not None:
            # Check the mask shape
            assert mask.shape[0] == query.shape[0]
            assert mask.shape[1] == key.shape[0]

            # Same mask applied to all heads
            mask = mask[:, :, None]

        # Apply the linear transformations
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Reshape to split into heads.
        # The input has shape [seq_len, d_model]; we split the last dimension into heads and d_k.
        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)

        # Compute attention scores Q K^T. This gives a tensor of shape [seq_len, seq_len, heads].
        scores = jnp.einsum('ihd,jhd->ijh', query, key)

        # Scale the scores by 1 / sqrt(d_k)
        scores *= self.scale

        # Apply the mask: keep the score where the mask is set, otherwise use -inf so that
        # softmax gives it zero weight. (Adding (mask == 0) * -inf would turn the unmasked
        # entries into NaN, since 0 * inf is NaN.)
        if mask is not None:
            scores = jnp.where(mask, scores, float('-inf'))

        # Softmax attention along the key sequence dimension
        attn = jax.nn.softmax(scores, axis=1)

        # Multiply by the values
        x = jnp.einsum("ijh,jhd->ihd", attn, value)

        # Concatenate the heads
        x = x.reshape(seq_len, -1)

        # Output layer
        return self.output(x)
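A small usage sketch to check the shapes; the sizes here are arbitrary and assume the class above:

import jax
import jax.numpy as jnp

mha = MultiHeadAttention(jax.random.PRNGKey(0), heads=4, d_model=32)
x = jax.random.normal(jax.random.PRNGKey(1), (10, 32))

# Causal mask: the query at position i can only see key-values at positions j <= i
mask = jnp.tril(jnp.ones((10, 10), bool))

out = mha(query=x, key=x, value=x, mask=mask)
print(out.shape)  # (10, 32)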
class FeedForward(Module):

rnd_key is the PRNG state, d_model is the number of features in a token embedding, d_ff is the number of features in the hidden layer of the FFN, and activation is the activation function.

    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, d_ff: int,
                 activation=jax.nn.relu):
        super().__init__()
        # Split the PRNG state
        _, *rnd_keys = jax.random.split(rnd_key, 5)

        # Layer one, parameterized by weight W1 and bias b1
        self.layer1 = Linear(rnd_keys[0], d_model, d_ff)
        # Layer two, parameterized by weight W2 and bias b2
        self.layer2 = Linear(rnd_keys[1], d_ff, d_model)
        # Activation function
        self.activation = activation

    def __call__(self, x: jnp.ndarray):
        x = self.activation(self.layer1(x))
        return self.layer2(x)

This is a transformer layer with multi-head attention and a position-wise feed-forward layer. We use pre-layer normalization.
class TransformerLayer(Module):

d_model is the token embedding size, self_attn is the self-attention module, and feed_forward is the feed-forward module.

    def __init__(self,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward):
        super().__init__()
        self.size = d_model
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm_self_attn = LayerNorm([d_model])
        self.norm_ff = LayerNorm([d_model])

    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray):
        # Normalize the vectors before doing self-attention
        z = self.norm_self_attn(x)
        # Run through self-attention, i.e. keys and values are from self
        self_attn = self.self_attn(query=z, key=z, value=z, mask=mask)
        # Add the self-attention results (residual connection)
        x = x + self_attn

        # Normalize for the feed-forward network
        z = self.norm_ff(x)
        # Pass through the feed-forward network
        ff = self.feed_forward(z)
        # Add the feed-forward results (residual connection)
        x = x + ff

        return x
class CrossEntropyLoss(Module):
    def __init__(self):
        super().__init__()
        # Use jax.vmap to vectorize the loss function
        self._loss_vmap = jax.vmap(self._loss, in_axes=(0, 0,))

    def _loss(self, output: jnp.ndarray, target: jnp.ndarray):
        return -jax.nn.log_softmax(output)[target]

output is the model output of shape [seq_len, n_vocab], and target is the target of shape [seq_len].

    def __call__(self, output: jnp.ndarray, target: jnp.ndarray):
        # Use the vectorized loss function and calculate the mean.
        # We could have used a for loop to calculate the losses, but using vmap is about 10X faster.
        return self._loss_vmap(output, target).mean()
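To see what jax.vmap is doing here, the per-token loss can be written with an explicit loop or with vmap; both give the same result (a minimal sketch with arbitrary toy values):

import jax
import jax.numpy as jnp

logits = jax.random.normal(jax.random.PRNGKey(0), (5, 10))  # [seq_len, n_vocab]
targets = jnp.array([1, 4, 2, 9, 0])                        # [seq_len]

def token_loss(output, target):
    return -jax.nn.log_softmax(output)[target]

# Explicit loop over the sequence positions
loop_loss = jnp.stack([token_loss(logits[i], targets[i]) for i in range(5)]).mean()

# vmap maps token_loss over the leading axis of both arguments
vmap_loss = jax.vmap(token_loss, in_axes=(0, 0))(logits, targets).mean()

print(jnp.allclose(loop_loss, vmap_loss))  # True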
class AutoregressiveTransformer(Module):
    layers: ModuleList[TransformerLayer]

rnd_key is the PRNG state, n_vocab is the vocabulary size, d_model is the number of features in a token embedding, n_layers is the number of transformer layers, heads is the number of attention heads, and d_ff is the number of features in the hidden layer of the FFN.

    def __init__(self, rnd_key: jax.random.PRNGKey, n_vocab: int, d_model: int, n_layers: int, heads: int, d_ff: int):
        super().__init__()
        self.n_vocab = n_vocab
        self.d_model = d_model
        self.loss_func = CrossEntropyLoss()

        # For the transformer layers
        layers = []
        for i in range(n_layers):
            # Split the PRNG state
            rnd_key, mha_key, ffn_key = jax.random.split(rnd_key, 3)
            # Create a transformer layer
            attn = MultiHeadAttention(mha_key, heads, d_model)
            ffn = FeedForward(ffn_key, d_model, d_ff)
            layers.append(TransformerLayer(d_model, attn, ffn))
        # Make a module list
        self.layers = ModuleList(layers)

        # Split the PRNG state
        rnd_key, emb_key, out_key = jax.random.split(rnd_key, 3)
        # Create the embedding layer
        self.embeddings = EmbeddingsWithLearnedPositionalEncoding(emb_key, n_vocab, d_model)
        # Final normalization and output layer
        self.norm = LayerNorm([d_model])
        self.output = Linear(out_key, d_model, n_vocab)

    def __call__(self, x: jnp.ndarray):
        # Get the sequence length
        seq_len = len(x)
        # A mask for attention so that a token can only see the tokens before it
        mask = jnp.tril(jnp.ones((seq_len, seq_len), bool))
        # Get embeddings with positional encodings
        x = self.embeddings(x)
        # Apply the transformer layers
        for i in range(len(self.layers)):
            x = self.layers[i](x, mask)

        # Final normalization and linear transformation to get the logits
        return self.output(self.norm(x))

    def get_loss(self, x: jnp.ndarray):
        # Get the model outputs
        output = self(x)
        # Cross entropy loss
        return self.loss_func(output[:-1], x[1:])

    def sample(self, seq: jnp.ndarray, length: int = 20):
        for i in range(length):
            # Sample the highest probability token
            idx = jnp.argmax(self(seq)[-1])
            # Add it to the sequence
            seq = jnp.concatenate((seq, idx[None]))

        # Return the sampled sequence
        return seq
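For example, greedy sampling from a freshly initialized (untrained) model just to check the shapes; the hyper-parameters here are arbitrary and the output will be gibberish:

import jax
import jax.numpy as jnp

model = AutoregressiveTransformer(jax.random.PRNGKey(0),
                                  n_vocab=65, d_model=32, n_layers=1, heads=4, d_ff=64)

prompt = jnp.array([1, 2, 3])
generated = model.sample(prompt, length=5)
print(generated.shape)  # (8,): the 3 prompt tokens plus 5 greedily sampled tokens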
This is a named tuple for storing the Adam optimizer state for a parameter.

class AdamState(NamedTuple):
    m: jnp.ndarray
    v: jnp.ndarray

This is from the paper Adam: A Method for Stochastic Optimization.

For parameter $\theta_t$ and gradient $g_t$ at step $t$, the Adam update is

$$\begin{aligned}
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
\hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t} \\
\hat{v}_t &\leftarrow \frac{v_t}{1 - \beta_2^t} \\
\theta_t &\leftarrow \theta_{t-1} - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}$$

where $\alpha$, $\beta_1$, $\beta_2$ and $\epsilon$ are scalar hyper-parameters. $m_t$ and $v_t$ are the first and second order moments, and $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments. $\epsilon$ is used as a fix for division by zero, but also acts as a form of a hyper-parameter that acts against variance in gradients.
class Adam:

params is the tree-map of parameters, lr is the learning rate $\alpha$, betas is a tuple of ($\beta_1$, $\beta_2$), and eps is $\epsilon$.

    def __init__(self, params: Dict,
                 lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-16, ):
        super().__init__()
        self.lr = lr
        self.betas = betas
        self.eps = eps

        # States for each parameter
        self.states = jax.tree.map(self._init_state, params)
        # Optimized step function
        self._step_jit = jax.jit(self._step)
        # Number of steps taken
        self._n_steps = 0
        # Optimized update state function
        self._update_state_jit = jax.jit(self._update_state)

    # Initialize the state for a given parameter
    def _init_state(self, param: jnp.ndarray):
        return AdamState(jnp.zeros_like(param), jnp.zeros_like(param))

    def step(self, params: Dict, grads: Dict):
        # Increment the step t
        self._n_steps += 1
        # Update the states for each parameter
        self.states = jax.tree.map(self._update_state_jit, grads, self.states)
        # Return the updated parameters
        return jax.tree.map(partial(self._step_jit, self._n_steps), params, self.states)

    def _step(self, n_steps: int, param: jnp.ndarray, state: AdamState):
        # Bias corrections 1 - beta1^t for m_hat and 1 - beta2^t for v_hat
        bias_correction = [1 - beta ** n_steps for beta in self.betas]
        # Uncorrected first and second moments m_t and v_t
        m, v = state

        # Step size alpha * sqrt(1 - beta2^t) / (1 - beta1^t); together with den below this is
        # an equivalent formulation of the update above that avoids computing m_hat and v_hat
        # explicitly (eps is added to the uncorrected sqrt(v_t))
        step_size = self.lr * (bias_correction[1] ** 0.5) / bias_correction[0]
        # Denominator sqrt(v_t) + eps
        den = (v ** 0.5) + self.eps

        # theta_t <- theta_{t-1} - step_size * m_t / den
        return param - step_size * m / den

    def _update_state(self, grad, state: AdamState):
        # Uncorrected first and second moments m_{t-1} and v_{t-1}
        m, v = state
        # Clip the gradients
        grad = jnp.clip(grad, -1, 1)
        # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        m = self.betas[0] * m + grad * (1 - self.betas[0])
        # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        v = self.betas[1] * v + (grad ** 2) * (1 - self.betas[1])

        # Return the new state
        return AdamState(m, v)
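A single step on a toy parameter, checked against the update rule above (a minimal sketch; since the bias corrections cancel at $t = 1$ and eps is tiny, the first update is approximately $-\alpha \, \mathrm{sign}(g)$):

import jax.numpy as jnp

params = {'w': jnp.array([1.0])}
grads = {'w': jnp.array([0.5])}

opt = Adam(params, lr=0.1)
new_params = opt.step(params, grads)

# m_1 = (1 - beta1) g and v_1 = (1 - beta2) g^2, so after bias correction
# m_hat / sqrt(v_hat) = g / |g| = 1 and the update is about -lr
print(new_params['w'])  # approximately [0.9]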
class TinyShakespeare:

rnd_key is the PRNG state, seq_len is the sequence length of a sample, and batch_size is the batch size.

    def __init__(self, rnd_key: jax.random.PRNGKey, seq_len: int, batch_size: int):
        self.batch_size = batch_size
        # PRNG key for shuffling the samples
        _, self.rnd_key = jax.random.split(rnd_key)

        # Local path of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
        # Download it if it doesn't exist
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        if not path.exists():
            download_file(url, path)

        # Read the file
        with open(str(path), 'r') as f:
            self.text = f.read()

        # Get the characters/tokens
        tokens = sorted(list(set(self.text)))

        # Number of tokens
        self.n_tokens = len(tokens)
        # Map tokens to ids
        self.stoi = {t: i for i, t in enumerate(tokens)}
        # Id to token/character
        self.itos = tokens

        # The text as a list of token ids
        data = jnp.array([self.stoi[s] for s in list(self.text)])
        # Number of batches
        self.n_batches = len(data) // (seq_len * batch_size)
        # Truncate
        data = data[:self.n_batches * seq_len * batch_size]
        # Reshape into samples (better to use random offsets, but let's ignore that here)
        self.data = data.reshape((-1, seq_len))
        # List of sample indexes
        self.idx = jnp.arange(len(self.data))

    # Set up for iteration
    def __iter__(self):
        # Iteration step
        self._iter_idx = 0
        # Split the PRNG key
        self.rnd_key, rnd_key = jax.random.split(self.rnd_key)
        # Shuffle the sample indexes
        self.idx = jax.random.permutation(rnd_key, self.idx)

        return self

    # Number of batches
    def __len__(self):
        return self.n_batches

    # Get the next batch
    def __next__(self):
        # Stop iteration after iterating through all the batches
        if self._iter_idx >= self.n_batches:
            raise StopIteration()

        # Sample indexes for the batch
        idx = self.idx[self._iter_idx * self.batch_size:(self._iter_idx + 1) * self.batch_size]
        # Increment the iteration step
        self._iter_idx += 1

        # Return the samples
        return self.data[idx]
def main():
    # Create the experiment
    experiment.create(name='jax')
    # Create the PRNG key
    rnd_key = jax.random.PRNGKey(0)
    # Create the dataset
    dataset = TinyShakespeare(rnd_key, seq_len=32, batch_size=128)

    # Create the model
    model = AutoregressiveTransformer(rnd_key, dataset.n_tokens,
                                      d_model=128, n_layers=3, heads=8, d_ff=512)
    # Get the model parameters
    params = model.get_params()

    # JAX-compiled pure sampling function
    pure_sample_fn = jax.jit(model.purify(model.sample))

    # JAX-compiled pure function to get the logits for a batch.
    # First we transform model.__call__ to a pure function which accepts two arguments:
    # the parameters and the input sequence. Next we vectorize the function to process a
    # batch of samples. in_axes specifies which arguments to parallelize and along which
    # axis; (None, 0) means we use the same parameters but parallelize the inputs across
    # the first axis. out_axes specifies along which axis to merge the results.
    pure_forward_fn = jax.jit(jax.vmap(model.purify(model.__call__),
                                       in_axes=(None, 0), out_axes=0))
    # Similarly we vectorize the loss computation
    pure_loss_fn = jax.jit(jax.vmap(model.purify(model.get_loss),
                                    in_axes=(None, 0), out_axes=0))

    # A function to get the mean loss
    def get_loss(params, seq):
        return pure_loss_fn(params, seq).mean()

    # A function to compute the gradients with respect to the first argument (the parameters)
    grad_loss_fn = jax.jit(jax.grad(get_loss, argnums=0))

    # Create the optimizer
    optimizer = Adam(params)

    # Start the experiment
    with experiment.start():
        # Iterate for 32 epochs
        for epoch in monit.loop(32):
            # Iterate through the batches
            for data in monit.iterate('Train', dataset):
                # Compute and log the loss
                loss = get_loss(params, data)
                tracker.save('loss', np.asarray(loss))

                # Get the gradients
                grads = grad_loss_fn(params, data)
                # Update the parameters
                params = optimizer.step(params, grads)

            tracker.new_line()

            # Log a sample after each epoch
            prompt = [dataset.stoi[c] for c in 'It ']
            sampled = pure_sample_fn(params, jnp.array(prompt))[len(prompt):]
            sampled = ''.join([dataset.itos[i] for i in sampled])
            sampled = sampled.replace('\n', '\\n')
            logger.log(('It ', Text.meta), (sampled, Text.value))


if __name__ == '__main__':
    main()