Training U-Net

This trains a U-Net model on the Carvana dataset. Download instructions for the dataset are available on Kaggle.

Save the training images in the `carvana/train` folder and the masks in the `carvana/train_masks` folder, both inside the labml data directory (`lab.get_data_path()`).

For simplicity, we do not split the data into training and validation sets; the whole dataset is used for training.

19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet

Configurations

class Configs(BaseConfigs):
    """
    ## Configurations

    Hyper-parameters and components for training a U-Net on the Carvana dataset.
    """

    # Device to train the model on.
    # `DeviceConfigs` picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model
    model: UNet

    # Number of channels in the image; 3 for RGB.
    image_channels: int = 3
    # Number of channels in the output mask; 1 for a binary mask.
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Loss function: binary cross-entropy computed on probabilities
    # (the training loop applies `sigmoid` to the logits before calling it).
    # NOTE(review): PyTorch documents `nn.BCEWithLogitsLoss` on raw logits as
    # more numerically stable than `Sigmoid` followed by `BCELoss`. It is not
    # switched here because `loss_func` is an overridable config and the
    # training loop passes probabilities to it.
    loss_func = nn.BCELoss()

    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """
        ### Initialize the dataset, model, data loader and optimizer
        """
        # Initialize the Carvana dataset from the images and masks folders
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model and move it to the training device
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create the data loader.
        # Pinned (page-locked) host memory only speeds up host-to-CUDA copies,
        # so request it only when training on a CUDA device. Requesting it
        # unconditionally wastes page-locked memory when training on CPU, and
        # PyTorch warns about it on CPU-only machines.
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True,
                                                       pin_memory=self.device.type == 'cuda')
        # Create the optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Enable image logging for the "sample" indicator
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        ### Sample images

        Runs the model on a randomly chosen training image. It then logs the
        image multiplied by the predicted mask probabilities. Gradients are
        disabled for this method.

        * `idx` is not used by this method. It lets `train` drive this
          method through `monit.mix`.
        """
        # Get a random sample (the ground-truth mask is discarded)
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)
        # Get the predicted mask: add a batch dimension with `x[None, :]`,
        # then squash the logits to probabilities
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the spatial size (height, width) of the mask.
        # The mask can be smaller if the U-Net convolutions use no padding.
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log the masked image.
        # NOTE(review): `x` is presumably (channels, H, W) and `mask` is
        # (1, mask_channels, H, W), so the product broadcasts to a batch of 1.
        tracker.save('sample', x * mask)

    def train(self):
        """
        ### Train for an epoch
        """
        # Iterate through the dataset.
        # `monit.mix` also calls `sample` so that samples are logged
        # 50 times per epoch.
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. The logits are
            # smaller if the U-Net convolutional layers use no padding.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate the loss on the predicted probabilities
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
def main():
    """
    ### Entry point

    Create the `unet` experiment, configure and initialize it, register the
    model for checkpointing, and run the training loop.
    """
    # Create experiment
    experiment.create(name='unet')

    # Build the configurations; put values in `overrides` to replace defaults
    conf = Configs()
    overrides = {}
    experiment.configs(conf, overrides)

    # Initialize the dataset, model, data loader and optimizer
    conf.init()

    # Register the model so it is included when saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment and run training inside it
    with experiment.start():
        conf.run()

# Run training when this file is executed as a script (not when imported)
if __name__ == '__main__':
    main()