U-Net実験用のカルバナデータセット

ダウンロード手順は Kaggle で確認できます

carvana/train トレーニング画像をフォルダー内に保存し、carvana/train_masks マスクをフォルダーに保存します。

16from torch import nn
17from pathlib import Path
18
19import torch.utils.data
20import torchvision.transforms.functional
21from PIL import Image
22
23from labml import lab

カーバナデータセット

26class CarvanaDataset(torch.utils.data.Dataset):
  • image_path 画像へのパスです
  • mask_path マスクへの道です
31    def __init__(self, image_path: Path, mask_path: Path):

ID で画像の辞書を取得

37        self.images = {p.stem: p for p in image_path.iterdir()}

ID でマスクの辞書を取得

39        self.masks = {p.stem[:-5]: p for p in mask_path.iterdir()}

画像IDリスト

42        self.ids = list(self.images.keys())

トランスフォーメーション

45        self.transforms = torchvision.transforms.Compose([
46            torchvision.transforms.Resize(572),
47            torchvision.transforms.ToTensor(),
48        ])

画像とそのマスクを入手してください。

  • idx は画像のインデックスです
50    def __getitem__(self, idx: int):

画像 ID を取得

58        id_ = self.ids[idx]

画像を読み込む

60        image = Image.open(self.images[id_])

画像を変換して PyTorch テンソルに変換する

62        image = self.transforms(image)

ロードマスク

64        mask = Image.open(self.masks[id_])

マスクを変換して PyTorch テンソルに変換する

66        mask = self.transforms(mask)

マスク値はなかったので、適切にスケーリングしました。

69        mask = mask / mask.max()

画像とマスクを返す

72        return image, mask

データセットのサイズ

74    def __len__(self):
78        return len(self.ids)

テストコード

82if __name__ == '__main__':
83    ds = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train', lab.get_data_path() / 'carvana' / 'train_masks')