训练 U-Net

这会在 Carvana 数据集上训练一个 U-Net 模型。你可以在 Kaggle 上找到下载说明。

将训练图像保存在carvana/train 文件夹中,将蒙版保存在carvana/train_masks 文件夹中。

为简单起见,我们不进行训练和验证拆分。

19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet

配置

32class Configs(BaseConfigs):

用于训练模型的设备。DeviceConfigs 选择可用的 CUDA 设备或默认为 CPU。

39    device: torch.device = DeviceConfigs()

U-Net 模型

42    model: UNet

图像中的通道数。对于 RGB。

45    image_channels: int = 3

输出掩码中的声道数。用于二进制掩码。

47    mask_channels: int = 1

批量大小

50    batch_size: int = 1

学习率

52    learning_rate: float = 2.5e-4

训练周期的数量

55    epochs: int = 4

数据集

58    dataset: CarvanaDataset

数据加载器

60    data_loader: torch.utils.data.DataLoader

亏损函数

63    loss_func = nn.BCELoss()

二进制分类的 Sigmoid 函数

65    sigmoid = nn.Sigmoid()

Adam 优化器

68    optimizer: torch.optim.Adam
70    def init(self):

初始化 C arvana 数据集

72        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
73                                      lab.get_data_path() / 'carvana' / 'train_masks')

初始化模型

75        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

创建数据加载器

78        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
79                                                       shuffle=True, pin_memory=True)

创建优化器

81        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

图像日志记录

84        tracker.set_image("sample", True)

样本图片

86    @torch.no_grad()
87    def sample(self, idx=-1):

随机获取样本

93        x, _ = self.dataset[np.random.randint(len(self.dataset))]

将数据移动到设备

95        x = x.to(self.device)

获取预测的口罩

98        mask = self.sigmoid(self.model(x[None, :]))

将图像裁剪为蒙版的大小

100        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])

日志样本

102        tracker.save('sample', x * mask)

训练一个时代

104    def train(self):

遍历数据集。用于mix 对每个纪元的采样次数。

112        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):

递增全局步长

114            tracker.add_global_step()

将数据移动到设备

116            image, mask = image.to(self.device), mask.to(self.device)

将渐变设为零

119            self.optimizer.zero_grad()

获取预测的掩码日志

121            logits = self.model(image)

将目标蒙版裁剪为 logits 的大小。如果我们不在 U-Net 的卷积层中使用填充,logits 的大小会变小。

124            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])

计算损失

126            loss = self.loss_func(self.sigmoid(logits), mask)

计算梯度

128            loss.backward()

采取优化步骤

130            self.optimizer.step()

追踪损失

132            tracker.save('loss', loss)

训练循环

134    def run(self):
138        for _ in monit.loop(self.epochs):

训练模型

140            self.train()

控制台中的新行

142            tracker.new_line()

保存模型

144            experiment.save_checkpoint()
147def main():

创建实验

149    experiment.create(name='unet')

创建配置

152    configs = Configs()

设置配置。您可以通过在字典中传递值来覆盖默认值。

155    experiment.configs(configs, {})

初始化

158    configs.init()

设置用于保存和加载的模型

161    experiment.add_pytorch_models({'model': configs.model})

启动并运行训练循环

164    with experiment.start():
165        configs.run()

169if __name__ == '__main__':
170    main()