你可以在 Kaggle 上找到下载说明。
将训练图像保存在carvana/train
文件夹中,将蒙版保存在carvana/train_masks
文件夹中。
16from torch import nn
17from pathlib import Path
18
19import torch.utils.data
20import torchvision.transforms.functional
21from PIL import Image
22
23from labml import lab
26class CarvanaDataset(torch.utils.data.Dataset):
image_path
是图像的路径mask_path
是通往口罩的路径31 def __init__(self, image_path: Path, mask_path: Path):
按 id 获取图像词典
37 self.images = {p.stem: p for p in image_path.iterdir()}
通过 id 获取口罩字典
39 self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}
映像 ID 列表
42 self.ids = list(self.images.keys())
转换
45 self.transforms = torchvision.transforms.Compose([
46 torchvision.transforms.Resize(572),
47 torchvision.transforms.ToTensor(),
48 ])
50 def __getitem__(self, idx: int):
获取图片编号
58 id_ = self.ids[idx]
加载图片
60 image = Image.open(self.images[id_])
变换图像并将其转换为 PyTorch 张量
62 image = self.transforms(image)
装载掩码
64 mask = Image.open(self.masks[id_])
变换掩码并将其转换为 PyTorch 张量
66 mask = self.transforms(mask)
掩码值不是,因此我们对其进行了适当的缩放。
69 mask = mask / mask.max()
返回图像和蒙版
72 return image, mask
74 def __len__(self):
78 return len(self.ids)
测试代码
82if __name__ == '__main__':
83 ds = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks')