Training U-Net

This trains a U-Net model on the Carvana dataset. You can find the download instructions on Kaggle.

Save the training images inside the carvana/train folder and the masks inside the carvana/train_masks folder, both under the labml data path (the code reads them from lab.get_data_path()).

For simplicity, we do not do a training and validation split.

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.functional
from torch import nn

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml_helpers.device import DeviceConfigs
from labml_nn.unet import UNet
from labml_nn.unet.carvana import CarvanaDataset


class Configs(BaseConfigs):
    """
    Configurations for training a U-Net on the Carvana dataset.

    Holds the hyper-parameters together with the dataset, data loader, model
    and optimizer, and implements the sampling and training loops.
    Call `init()` after the configuration values have been set and before
    `run()`.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model
    model: UNet

    # Number of channels in the image; 3 for RGB
    image_channels: int = 3
    # Number of channels in the output mask; 1 for a binary mask
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Data loader
    data_loader: torch.utils.data.DataLoader

    # Loss function; it is applied to sigmoid outputs, not to raw logits
    loss_func = nn.BCELoss()
    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Create the dataset, model, data loader and optimizer."""
        # Initialize the Carvana dataset from the images and masks folders
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model and move it to the training device
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create the data loader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create the optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Register "sample" as an image indicator for image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        Log the predicted mask applied to a randomly chosen image.

        :param idx: not used by this method; presumably the value handed over
            by `monit.mix` in `train` (TODO confirm against labml).
        """
        # Get a random sample; the target mask is not needed
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)

        # Get the predicted mask; `x[None, :]` adds a leading batch dimension
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log the image weighted by the predicted mask
        tracker.save('sample', x * mask)

    def train(self):
        """Train for one epoch."""
        # Iterate through the dataset.
        # Use `mix` to sample 50 times per epoch.
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. The logits will be
            # smaller if the U-Net does not use padding in its convolutional layers.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.add('loss', loss)

    def run(self):
        """Run the training loop for `epochs` epochs, saving a checkpoint after each."""
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
def main():
    """Create the experiment, set up the configurations and run the training."""
    # Create experiment
    experiment.create(name='unet')

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {})

    # Initialize the dataset, model, data loader and optimizer
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': configs.model})

    # Start and run the training loop
    with experiment.start():
        configs.run()
# Run the experiment when this file is executed as a script
if __name__ == '__main__':
    main()