这会在 Carvana 数据集上训练一个 U-Net 模型。你可以在 Kaggle 上找到下载说明。
将训练图像保存在carvana/train
文件夹中,将蒙版保存在carvana/train_masks
文件夹中。
为简单起见,我们不进行训练和验证拆分。
19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet
32class Configs(BaseConfigs):
用于训练模型的设备。DeviceConfigs
选择可用的 CUDA 设备或默认为 CPU。
39 device: torch.device = DeviceConfigs()
图像中的通道数。对于 RGB。
45 image_channels: int = 3
输出掩码中的声道数。用于二进制掩码。
47 mask_channels: int = 1
批量大小
50 batch_size: int = 1
学习率
52 learning_rate: float = 2.5e-4
训练周期的数量
55 epochs: int = 4
数据集
58 dataset: CarvanaDataset
数据加载器
60 data_loader: torch.utils.data.DataLoader
亏损函数
63 loss_func = nn.BCELoss()
二进制分类的 Sigmoid 函数
65 sigmoid = nn.Sigmoid()
Adam 优化器
68 optimizer: torch.optim.Adam
70 def init(self):
初始化 C arvana 数据集
72 self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
73 lab.get_data_path() / 'carvana' / 'train_masks')
初始化模型
75 self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
创建数据加载器
78 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
79 shuffle=True, pin_memory=True)
创建优化器
81 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
图像日志记录
84 tracker.set_image("sample", True)
86 @torch.no_grad()
87 def sample(self, idx=-1):
随机获取样本
93 x, _ = self.dataset[np.random.randint(len(self.dataset))]
将数据移动到设备
95 x = x.to(self.device)
获取预测的口罩
98 mask = self.sigmoid(self.model(x[None, :]))
将图像裁剪为蒙版的大小
100 x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
日志样本
102 tracker.save('sample', x * mask)
104 def train(self):
112 for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
递增全局步长
114 tracker.add_global_step()
将数据移动到设备
116 image, mask = image.to(self.device), mask.to(self.device)
将渐变设为零
119 self.optimizer.zero_grad()
获取预测的掩码日志
121 logits = self.model(image)
将目标蒙版裁剪为 logits 的大小。如果我们不在 U-Net 的卷积层中使用填充,logits 的大小会变小。
124 mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
计算损失
126 loss = self.loss_func(self.sigmoid(logits), mask)
计算梯度
128 loss.backward()
采取优化步骤
130 self.optimizer.step()
追踪损失
132 tracker.save('loss', loss)
134 def run(self):
138 for _ in monit.loop(self.epochs):
训练模型
140 self.train()
控制台中的新行
142 tracker.new_line()
保存模型
144 experiment.save_checkpoint()
147def main():
创建实验
149 experiment.create(name='unet')
创建配置
152 configs = Configs()
设置配置。您可以通过在字典中传递值来覆盖默认值。
155 experiment.configs(configs, {})
初始化
158 configs.init()
设置用于保存和加载的模型
161 experiment.add_pytorch_models({'model': configs.model})
启动并运行训练循环
164 with experiment.start():
165 configs.run()
169if __name__ == '__main__':
170 main()