This is a PyTorch implementation/tutorial of the paper Distilling the Knowledge in a Neural Network.

It's a way of training a small network using the knowledge in a trained larger network; i.e. distilling the knowledge from the large network.

A large model with regularization or an ensemble of models (using dropout) generalizes better than a small model when trained directly on the data and labels. However, a small model can be trained to generalize better with the help of a large model. Smaller models are better in production: they are faster and need less compute and memory.

The output probabilities of a trained model give more information than the labels because the model assigns non-zero probabilities to incorrect classes as well. These probabilities tell us that a sample has a chance of belonging to certain classes. For instance, when classifying digits, given an image of the digit *7*, a model that generalizes well will give a high probability to 7 and a small but non-zero probability to 2, while assigning almost zero probability to the other digits. Distillation uses this information to train a small model better.

The probabilities are usually computed with a softmax operation,

$q_{i} = \frac{\exp(z_{i})}{\sum_{j} \exp(z_{j})}$

where $q_{i}$ is the probability for class $i$ and $z_{i}$ is the logit.

We train the small model to minimize the Cross entropy or KL Divergence between its output probability distribution and the large network's output probability distribution (soft targets).

One of the problems here is that the probabilities assigned to incorrect classes by the large network are often very small and contribute little to the loss. So the paper softens the probabilities by applying a temperature $T$,

$q_{i} = \frac{\exp(z_{i} / T)}{\sum_{j} \exp(z_{j} / T)}$

where higher values for $T$ will produce softer probabilities.
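To make the effect of $T$ concrete, here is a minimal sketch in PyTorch. The helper name `softmax_with_temperature` and the example logits are illustrative assumptions, not part of the implementation in this tutorial.

```python
import torch


def softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Softmax over the last dimension after dividing the logits by `temperature`."""
    return torch.softmax(logits / temperature, dim=-1)


# Hypothetical logits for a single sample with three classes
logits = torch.tensor([6.0, 2.0, -1.0])

# T = 1 is the ordinary softmax; T = 5 is a higher temperature
sharp = softmax_with_temperature(logits, 1.0)
soft = softmax_with_temperature(logits, 5.0)

print(sharp)
print(soft)
```

The ranking of the classes is unchanged. Only the distribution becomes flatter as $T$ grows, so the incorrect classes receive a larger share of the probability mass.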

The paper suggests adding a second loss term for predicting the actual labels when training the small model. We calculate the composite loss as the weighted sum of the two loss terms: soft targets and actual labels.

The dataset for distillation is called *the transfer set*, and the paper suggests using the same training data.

We train on the CIFAR-10 dataset. We train a large model that has $14,728,266$ parameters with dropout, and it gives an accuracy of 85% on the validation set. A small model with $437,034$ parameters gives an accuracy of 80%.

We then train the small model with distillation from the large model, and it gives an accuracy of 82%, an increase of 2 percentage points.

```
import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs
```

This extends from `CIFAR10Configs`, which defines all the dataset-related configurations, the optimizer, and a training loop.

`class Configs(CIFAR10Configs):`

The small model

`model: SmallModel`

The large model

`large: LargeModel`

KL Divergence loss for soft targets

`kl_div_loss = nn.KLDivLoss(log_target=True)`
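`nn.KLDivLoss` takes the model's log-probabilities as its first argument, and with `log_target=True` it expects the target in log space as well. The sketch below uses made-up logits to check the loss against the definition $\sum_i p_i (\log p_i - \log q_i)$ and to show how the reduction affects the scale: the default `reduction='mean'` averages over every element, while `'batchmean'` divides the sum by the batch size only.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Made-up logits: a batch of 4 samples over 10 classes
teacher_logits = torch.randn(4, 10)
student_logits = torch.randn(4, 10)

# Both arguments are log-probabilities because `log_target=True`
log_p = torch.log_softmax(teacher_logits, dim=-1)  # target (teacher)
log_q = torch.log_softmax(student_logits, dim=-1)  # input (student)

# KL(p || q) from the definition, summed over classes and averaged over the batch
manual_kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()

# Sum over all elements divided by the batch size
batch_mean_kl = nn.KLDivLoss(reduction='batchmean', log_target=True)(log_q, log_p)
# Default reduction: mean over all batch_size * n_classes elements
element_mean_kl = nn.KLDivLoss(log_target=True)(log_q, log_p)
```

With the default reduction the value is smaller than the per-sample KL divergence by a factor equal to the number of classes. This is a constant factor that the soft targets weight can absorb.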

Cross entropy loss for true label loss

`loss_func = nn.CrossEntropyLoss()`

Temperature, $T$

`temperature: float = 5.`

Weight for soft targets loss.

The gradients produced by soft targets get scaled by $\frac{1}{T^{2}}$. To compensate for this, the paper suggests scaling the soft targets loss by a factor of $T^{2}$.

`soft_targets_weight: float = 100.`
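The scaling of the gradients follows from the derivative of the soft targets loss $C$ with respect to a small-model logit $z_i$. The following is a sketch of the argument from the paper, where $p_i$ and $q_i$ are the temperature-softened probabilities of the large and small models, $v_i$ are the large-model logits, and $N$ is the number of classes. It assumes that $T$ is large compared to the magnitude of the logits and that the logits of each model have zero mean.

```latex
\begin{align}
\frac{\partial C}{\partial z_i}
  &= \frac{1}{T}\left(q_i - p_i\right)
   = \frac{1}{T}\left(
       \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}
       - \frac{\exp(v_i/T)}{\sum_j \exp(v_j/T)}
     \right) \\
  &\approx \frac{1}{T}\left(
       \frac{1 + z_i/T}{N + \sum_j z_j/T}
       - \frac{1 + v_i/T}{N + \sum_j v_j/T}
     \right) \\
  &\approx \frac{1}{N T^2}\left(z_i - v_i\right)
\end{align}
```

The second line uses $\exp(x) \approx 1 + x$ for small $x$, and the third uses the zero-mean assumption so that both sums vanish. Multiplying the soft targets loss by $T^2$ keeps its gradient on a scale comparable to the true label loss when the temperature changes.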

Weight for true label cross entropy loss

`label_loss_weight: float = 0.5`

`def step(self, batch: any, batch_idx: BatchIndex):`

Training/Evaluation mode for the small model

`self.model.train(self.mode.is_train)`

Large model in evaluation mode

`self.large.eval()`

Move data to the device

`data, target = batch[0].to(self.device), batch[1].to(self.device)`

Update global step (number of samples processed) when in training mode

```
if self.mode.is_train:
    tracker.add_global_step(len(data))
```

Get the output logits, $v_{i}$, from the large model

```
with torch.no_grad():
    large_logits = self.large(data)
```
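A small sketch of why both `eval()` and `torch.no_grad()` are used for the large model. The `teacher` network here is a hypothetical stand-in, not `LargeModel`: with dropout disabled the outputs are deterministic, and under `no_grad` no computation graph is built for them.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the large model, with a dropout layer
teacher = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 3))
x = torch.randn(2, 4)

# Evaluation mode turns dropout into the identity
teacher.eval()

with torch.no_grad():
    first = teacher(x)
    second = teacher(x)

# Same input gives the same logits, and they carry no gradient history
print(torch.equal(first, second), first.requires_grad)
```

Because the teacher's logits do not require gradients, the backward pass of the distillation loss only updates the small model.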

Get the output logits, $z_{i}$, from the small model

`output = self.model(data)`

Soft targets $p_{i} = \frac{\exp(v_{i} / T)}{\sum_{j} \exp(v_{j} / T)}$, kept as log-probabilities because the KL divergence loss is created with `log_target=True`

`soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)`

Temperature adjusted probabilities of the small model $q_{i} = \frac{\exp(z_{i} / T)}{\sum_{j} \exp(z_{j} / T)}$, also as log-probabilities

`soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)`

Calculate the soft targets loss

`soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)`

Calculate the true label loss

`label_loss = self.loss_func(output, target)`

Weighted sum of the two losses

`loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss`
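The loss computation can also be collected into a single function. This is a minimal sketch, not the code of the `Configs` class: the name `distillation_loss`, its default arguments (copied from the values used in this tutorial), and the random logits are assumptions for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      target: torch.Tensor,
                      temperature: float = 5.,
                      soft_targets_weight: float = 100.,
                      label_loss_weight: float = 0.5) -> torch.Tensor:
    # Log of the soft targets from the large model
    soft_targets = F.log_softmax(teacher_logits / temperature, dim=-1)
    # Log of the temperature adjusted probabilities of the small model
    soft_prob = F.log_softmax(student_logits / temperature, dim=-1)
    # KL divergence; both arguments are in log space
    soft_targets_loss = nn.KLDivLoss(log_target=True)(soft_prob, soft_targets)
    # Cross entropy with the true labels, on the unscaled logits
    label_loss = F.cross_entropy(student_logits, target)
    # Weighted sum of the two terms
    return soft_targets_weight * soft_targets_loss + label_loss_weight * label_loss


torch.manual_seed(0)
# Made-up batch: 8 samples, 10 classes
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
target = torch.randint(0, 10, (8,))

loss = distillation_loss(student_logits, teacher_logits, target)
loss.backward()
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted true label loss remains, which is a quick sanity check for a function like this.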

Log the losses

```
tracker.add({"loss.kl_div.": soft_targets_loss,
             "loss.nll": label_loss,
             "loss.": loss})
```

Calculate and log accuracy

```
self.accuracy(output, target)
self.accuracy.track()
```

Train the model

`if self.mode.is_train:`

Calculate gradients

`loss.backward()`

Take optimizer step

`self.optimizer.step()`

Log the model parameters and gradients on the last batch of every epoch

```
if batch_idx.is_last:
    tracker.add('model', self.model)
```

Clear the gradients

`self.optimizer.zero_grad()`

Save the tracked metrics

`tracker.save()`

```
@option(Configs.large)
def _large_model(c: Configs):
    return LargeModel().to(c.device)
```

```
@option(Configs.model)
def _small_student_model(c: Configs):
    return SmallModel().to(c.device)
```

`def get_saved_model(run_uuid: str, checkpoint: int):`

`from labml_nn.distillation.large import Configs as LargeConfigs`

In evaluation mode (no recording)

`experiment.evaluate()`

Initialize configs of the large model training experiment

`conf = LargeConfigs()`

Load saved configs

`experiment.configs(conf, experiment.load_configs(run_uuid))`

Set models for saving/loading

`experiment.add_pytorch_models({'model': conf.model})`

Set which run and checkpoint to load

`experiment.load(run_uuid, checkpoint)`

Start the experiment; this will load the model and prepare everything

`experiment.start()`

Return the model

`return conf.model`

Train a small model with distillation

`def main(run_uuid: str, checkpoint: int):`

Load saved model

`large_model = get_saved_model(run_uuid, checkpoint)`

Create experiment

`experiment.create(name='distillation', comment='cifar10')`

Create configurations

`conf = Configs()`

Set the loaded large model

`conf.large = large_model`

Load configurations

```
experiment.configs(conf, {
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 2.5e-4,
    'model': '_small_student_model',
})
```

Set model for saving/loading

`experiment.add_pytorch_models({'model': conf.model})`

Start experiment from scratch

`experiment.load(None, None)`

Start the experiment and run the training loop

```
with experiment.start():
    conf.run()
```

```
if __name__ == '__main__':
    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)
```