U-Net 实验的 Carvana 数据集

你可以在 Kaggle 上找到下载说明。

将训练图像保存在carvana/train 文件夹中,将蒙版保存在carvana/train_masks 文件夹中。

16from torch import nn
17from pathlib import Path
18
19import torch.utils.data
20import torchvision.transforms.functional
21from PIL import Image
22
23from labml import lab

Carvana 数据集

26class CarvanaDataset(torch.utils.data.Dataset):
  • image_path 是图像的路径
  • mask_path 是通往口罩的路径
31    def __init__(self, image_path: Path, mask_path: Path):

按 id 获取图像词典

37        self.images = {p.stem: p for p in image_path.iterdir()}

通过 id 获取口罩字典

39        self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}

映像 ID 列表

42        self.ids = list(self.images.keys())

转换

45        self.transforms = torchvision.transforms.Compose([
46            torchvision.transforms.Resize(572),
47            torchvision.transforms.ToTensor(),
48        ])

获取图像及其蒙版。

  • idx 是图像的索引
50    def __getitem__(self, idx: int):

获取图片编号

58        id_ = self.ids[idx]

加载图片

60        image = Image.open(self.images[id_])

变换图像并将其转换为 PyTorch 张量

62        image = self.transforms(image)

装载掩码

64        mask = Image.open(self.masks[id_])

变换掩码并将其转换为 PyTorch 张量

66        mask = self.transforms(mask)

掩码值不是,因此我们对其进行了适当的缩放。

69        mask = mask / mask.max()

返回图像和蒙版

72        return image, mask

数据集的大小

74    def __len__(self):
78        return len(self.ids)

测试代码

82if __name__ == '__main__':
83    ds = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks')