ダウンロード手順は Kaggle で確認できます。
carvana/train
トレーニング画像をフォルダー内に保存し、carvana/train_masks
マスクをフォルダーに保存します。
16from torch import nn
17from pathlib import Path
18
19import torch.utils.data
20import torchvision.transforms.functional
21from PIL import Image
22
23from labml import lab
26class CarvanaDataset(torch.utils.data.Dataset):
image_path
画像へのパスですmask_path
マスクへの道です31 def __init__(self, image_path: Path, mask_path: Path):
ID で画像の辞書を取得
37 self.images = {p.stem: p for p in image_path.iterdir()}
ID でマスクの辞書を取得
39 self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}
画像IDリスト
42 self.ids = list(self.images.keys())
トランスフォーメーション
45 self.transforms = torchvision.transforms.Compose([
46 torchvision.transforms.Resize(572),
47 torchvision.transforms.ToTensor(),
48 ])
50 def __getitem__(self, idx: int):
画像 ID を取得
58 id_ = self.ids[idx]
画像を読み込む
60 image = Image.open(self.images[id_])
画像を変換して PyTorch テンソルに変換する
62 image = self.transforms(image)
ロードマスク
64 mask = Image.open(self.masks[id_])
マスクを変換して PyTorch テンソルに変換する
66 mask = self.transforms(mask)
マスク値はなかったので、適切にスケーリングしました。
69 mask = mask / mask.max()
画像とマスクを返す
72 return image, mask
74 def __len__(self):
78 return len(self.ids)
テストコード
82if __name__ == '__main__':
83 ds = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks')