Distilling the Knowledge in a Neural Network

This is a PyTorch implementation/tutorial of the paper Distilling the Knowledge in a Neural Network.

It's a way of training a small network using the knowledge in a trained larger network; i.e. distilling the knowledge from the large network.

A large model with regularization, or an ensemble of models (using dropout), generalizes better than a small model when both are trained directly on the data and labels. However, a small model can be trained to generalize better with the help of a large model. Smaller models are better in production: they are faster and need less compute and less memory.

The output probabilities of a trained model carry more information than the labels because the model assigns non-zero probabilities to incorrect classes as well. These probabilities tell us that a sample has some chance of belonging to certain other classes. For instance, when classifying digits, given an image of the digit 7, a model that generalizes well will give a high probability to 7 and a small but non-zero probability to 2, while assigning almost zero probability to the other digits. Distillation uses this information to train a small model better.

Soft Targets

The probabilities are usually computed with a softmax operation,

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

where $q_i$ is the probability for class $i$ and $z_i$ is the logit.
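
To make the softmax concrete, here is a minimal sketch in plain Python (no PyTorch), written so the arithmetic is visible. The function name `softmax` and the logit values are illustrative assumptions, not part of the tutorial's code.

```python
import math


def softmax(logits):
    """Convert a list of logits z_i into probabilities q_i."""
    # Subtracting the maximum logit leaves the result unchanged
    # but avoids overflow in exp
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three classes
probs = softmax([2.0, 1.0, 0.1])
print(probs)
```

The outputs are positive, sum to one, and keep the ordering of the logits.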

We train the small model to minimize the cross entropy or KL divergence between its output probability distribution and the large network's output probability distribution (the soft targets).

One of the problems here is that the probabilities assigned to incorrect classes by the large network are often very small and don't contribute to the loss. So the paper softens the probabilities by applying a temperature $T$,

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

where higher values for $T$ will produce softer probabilities.
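
The effect of the temperature can be seen in a small sketch. The logits below are hypothetical values chosen so that one class dominates. The temperature 5 matches the default used later in this tutorial's configuration.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits: one dominant class and two unlikely ones
logits = torch.tensor([8.0, 2.0, 0.0])

# Temperature 1 is the ordinary softmax
sharp = F.softmax(logits / 1.0, dim=-1)
# A higher temperature flattens the distribution
soft = F.softmax(logits / 5.0, dim=-1)

print(sharp)
print(soft)
```

At temperature 1 the two unlikely classes receive well under 1% of the probability mass. At temperature 5 they receive roughly 0.20 and 0.13, so they contribute far more to the loss.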


The paper suggests adding a second loss term for predicting the actual labels when training the small model. We calculate the composite loss as the weighted sum of the two loss terms: soft targets and actual labels.
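
The composite loss can be sketched as a standalone function. This is a minimal illustration, not the tutorial's training step: the function name `distillation_loss`, the default weights, and the `batchmean` reduction are assumptions made for this example.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=5.0, soft_weight=1.0, label_weight=0.5):
    """Weighted sum of a soft-target KL term and a true-label cross-entropy term."""
    # Log-probabilities of both models at the same temperature
    student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_log_prob = F.log_softmax(teacher_logits / temperature, dim=-1)
    # KL(teacher || student), summed over classes and averaged over the batch
    soft_loss = F.kl_div(student_log_prob, teacher_log_prob,
                         reduction="batchmean", log_target=True)
    # Ordinary cross entropy against the true labels, at temperature 1
    label_loss = F.cross_entropy(student_logits, labels)
    return soft_weight * soft_loss + label_weight * label_loss


# Random stand-ins for a batch of 4 samples and 10 classes
torch.manual_seed(0)
student_logits = torch.randn(4, 10, requires_grad=True)
teacher_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))

loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
print(loss.item())
```

Only the student logits receive gradients here. The teacher logits are treated as fixed targets.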

The dataset for distillation is called the transfer set, and the paper suggests using the same training data.

Our experiment

We train on the CIFAR-10 dataset. A large model trained with dropout reaches an accuracy of 85% on the validation set, while a small model trained directly on the data reaches 80%.

We then train the small model with distillation from the large model, and it gives an accuracy of 82%, a 2 percentage point increase in accuracy.


import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_helpers.train_valid import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs


This extends CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and a training loop.

class Configs(CIFAR10Configs):

The small model

    model: SmallModel

The large model

    large: LargeModel

KL Divergence loss for soft targets

    kl_div_loss = nn.KLDivLoss(log_target=True)

Cross entropy loss for true label loss

    loss_func = nn.CrossEntropyLoss()

Temperature, $T$

    temperature: float = 5.

Weight for soft targets loss.

The gradients produced by soft targets get scaled by $\frac{1}{T^2}$. To compensate for this, the paper suggests scaling the soft targets loss by a factor of $T^2$.

    soft_targets_weight: float = 100.

Weight for true label cross entropy loss

    label_loss_weight: float = 0.5

Training/validation step

We define a custom training/validation step to include distillation

    def step(self, batch: any, batch_idx: BatchIndex):

Training/Evaluation mode for the small model

        self.model.train(self.mode.is_train)

Large model in evaluation mode

        self.large.eval()

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(len(data))

Get the output logits, $v_i$, from the large model

        with torch.no_grad():
            large_logits = self.large(data)

Get the output logits, $z_i$, from the small model

        output = self.model(data)

Soft targets: temperature-softened log-probabilities of the large model (log-probabilities because the KL divergence loss is created with log_target=True)

        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)

Temperature-adjusted log-probabilities of the small model

        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)

Calculate the soft targets loss

        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)

Calculate the true label loss

        label_loss = self.loss_func(output, target)

Weighted sum of the two losses

        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss

Log the losses

        tracker.add({"loss.kl_div.": soft_targets_loss,
                     "loss.nll": label_loss,
                     "loss.": loss})

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Create large model

@option(Configs.large)
def _large_model(c: Configs):
    return LargeModel().to(c.device)

Create small model

@option(Configs.model)
def _small_student_model(c: Configs):
    return SmallModel().to(c.device)

Load the trained large model

def get_saved_model(run_uuid: str, checkpoint: int):
    from labml_nn.distillation.large import Configs as LargeConfigs

In evaluation mode (no recording)

    experiment.evaluate()

Initialize configs of the large model training experiment

    conf = LargeConfigs()

Load saved configs

    experiment.configs(conf, experiment.load_configs(run_uuid))

Set models for saving/loading

    experiment.add_pytorch_models({'model': conf.model})

Set which run and checkpoint to load

    experiment.load(run_uuid, checkpoint)

Start the experiment - this will load the model, and prepare everything

    experiment.start()

Return the model

    return conf.model

Train a small model with distillation

def main(run_uuid: str, checkpoint: int):

Load saved model

    large_model = get_saved_model(run_uuid, checkpoint)

Create experiment

    experiment.create(name='distillation', comment='cifar10')

Create configurations

    conf = Configs()

Set the loaded large model

    conf.large = large_model

Load configurations

    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'model': '_small_student_model',
    })

Set model for saving/loading

    experiment.add_pytorch_models({'model': conf.model})

Start experiment from scratch

    experiment.load(None, None)

Start the experiment and run the training loop

    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)
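
The note on the soft targets weight says that the soft-target gradients shrink as the temperature grows, by a factor of about $\frac{1}{T^2}$. That can be checked numerically with a small standalone sketch. The helper `soft_target_grad_norm` and the random logits are assumptions for illustration. The temperatures are larger than the tutorial's default because the $\frac{1}{T^2}$ scaling is a high-temperature approximation.

```python
import torch
import torch.nn.functional as F


def soft_target_grad_norm(student_logits, teacher_logits, temperature):
    """Norm of the soft-target loss gradient with respect to the student logits."""
    # Fresh leaf tensor so each call gets its own gradient
    z = student_logits.clone().requires_grad_(True)
    student_log_prob = F.log_softmax(z / temperature, dim=-1)
    teacher_log_prob = F.log_softmax(teacher_logits / temperature, dim=-1)
    loss = F.kl_div(student_log_prob, teacher_log_prob,
                    reduction="sum", log_target=True)
    loss.backward()
    return z.grad.norm().item()


torch.manual_seed(0)
student_logits = torch.randn(10)  # hypothetical student logits
teacher_logits = torch.randn(10)  # hypothetical teacher logits

norm_t = soft_target_grad_norm(student_logits, teacher_logits, temperature=20.0)
norm_2t = soft_target_grad_norm(student_logits, teacher_logits, temperature=40.0)

# Doubling T should shrink the gradient by roughly a factor of 4
ratio = norm_t / norm_2t
print(ratio)
```

A ratio close to 4 is consistent with the $\frac{1}{T^2}$ scaling, which is why the soft targets loss is multiplied by a large weight before it is added to the label loss.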