This trains a U-Net model on the Carvana dataset. You can find the download instructions on Kaggle.
Save the training images inside the carvana/train folder and the masks inside the carvana/train_masks folder.
For simplicity, we do not do a training and validation split.
19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet
class Configs(BaseConfigs):
    """
    Configurations and training loop for the U-Net experiment on the Carvana dataset.

    Attributes declared without a value (`model`, `dataset`, `data_loader`
    and `optimizer`) are created by `init()`, so `init()` must be called
    before `run()`.
    """

    # Device to train the model on. `DeviceConfigs`
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model; created in `init()`. Declared here so it is listed with
    # the other attributes that `init()` populates.
    model: UNet

    # Number of channels in the image; 3 for RGB.
    image_channels: int = 3
    # Number of channels in the output mask; 1 for a binary mask.
    mask_channels: int = 1

    # Batch size
    batch_size: int = 1
    # Learning rate
    learning_rate: float = 2.5e-4

    # Number of training epochs
    epochs: int = 4

    # Dataset
    dataset: CarvanaDataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Loss function (binary cross-entropy, applied to the sigmoid of the logits)
    loss_func = nn.BCELoss()
    # Sigmoid function for binary classification
    sigmoid = nn.Sigmoid()

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Create the dataset, model, data loader and optimizer, and enable image logging."""
        # Initialize the Carvana dataset from `<data path>/carvana/train`
        # and `<data path>/carvana/train_masks`
        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
                                      lab.get_data_path() / 'carvana' / 'train_masks')
        # Initialize the model and move it to the configured device
        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    @torch.no_grad()
    def sample(self, idx=-1):
        """
        Log the predicted mask applied to a randomly chosen dataset image.

        :param idx: unused; presumably receives the values of the list that
            `train()` passes to `monit.mix` alongside this method.
        """
        # Get a random sample (the target mask is not needed)
        x, _ = self.dataset[np.random.randint(len(self.dataset))]
        # Move data to device
        x = x.to(self.device)

        # Get predicted mask; `x[None, :]` adds a leading batch dimension
        mask = self.sigmoid(self.model(x[None, :]))
        # Crop the image to the size of the mask
        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
        # Log samples: the image multiplied by the predicted mask
        tracker.save('sample', x * mask)

    def train(self):
        """
        Run one pass over the data loader, taking an optimization step per batch.

        The pass is mixed (through `monit.mix`) with calls to `sample()`
        for the 50 values in `range(50)`.
        """
        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            image, mask = image.to(self.device), mask.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Get predicted mask logits
            logits = self.model(image)
            # Crop the target mask to the size of the logits. Size of the logits will be
            # smaller if we don't use padding in convolutional layers in the U-Net.
            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
            # Calculate loss
            loss = self.loss_func(self.sigmoid(logits), mask)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """Train for `epochs` epochs, saving a checkpoint after each one."""
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
def main():
    """Create the `unet` experiment, build its configurations and run the training loop."""
    # Create experiment
    experiment.create(name='unet')

    # Create configurations
    conf = Configs()

    # Set configurations. The dictionary holds overrides for the defaults;
    # it is empty here, so the defaults are used.
    experiment.configs(conf, {})

    # Initialize
    conf.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start and run the training loop
    with experiment.start():
        conf.run()
# Run the experiment only when this file is executed as a script
if __name__ == '__main__':
    main()