トレーニング U-ネット

これにより、CarvanaデータセットでU-Netモデルをトレーニングします。ダウンロード手順は Kaggle で確認できます

carvana/train トレーニング画像をフォルダー内に保存し、carvana/train_masks マスクをフォルダーに保存します。

わかりやすくするために、トレーニングと検証の分割は行っていません。

19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet

コンフィギュレーション

32class Configs(BaseConfigs):

モデルをトレーニングするデバイス。DeviceConfigs 使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します

39    device: torch.device = DeviceConfigs()
42    model: UNet

画像内のチャンネル数。RGB 用です。

45    image_channels: int = 3

出力マスクのチャンネル数。バイナリマスク用。

47    mask_channels: int = 1

バッチサイズ

50    batch_size: int = 1

学習率

52    learning_rate: float = 2.5e-4

トレーニングエポックの数

55    epochs: int = 4

データセット

58    dataset: CarvanaDataset

データローダー

60    data_loader: torch.utils.data.DataLoader

損失関数

63    loss_func = nn.BCELoss()

バイナリ分類用のシグモイド関数

65    sigmoid = nn.Sigmoid()

アダム・オプティマイザー

68    optimizer: torch.optim.Adam
70    def init(self):

Carvana データセットを初期化します

72        self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
73                                      lab.get_data_path() / 'carvana' / 'train_masks')

モデルを初期化

75        self.model = UNet(self.image_channels, self.mask_channels).to(self.device)

データローダーの作成

78        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
79                                                       shuffle=True, pin_memory=True)

オプティマイザーを作成

81        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

画像ロギング

84        tracker.set_image("sample", True)

サンプル画像

86    @torch.no_grad()
87    def sample(self, idx=-1):

ランダムサンプルを入手

93        x, _ = self.dataset[np.random.randint(len(self.dataset))]

データをデバイスに移動

95        x = x.to(self.device)

予測マスクを取得

98        mask = self.sigmoid(self.model(x[None, :]))

画像をマスクのサイズにトリミングします

100        x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])

ログサンプル

102        tracker.save('sample', x * mask)

一時代を拓く列車

104    def train(self):

データセットを繰り返し処理します。mix エポックあたりのサンプリング時間に使用します

112        for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):

グローバルステップをインクリメント

114            tracker.add_global_step()

データをデバイスに移動

116            image, mask = image.to(self.device), mask.to(self.device)

グラデーションをゼロにする

119            self.optimizer.zero_grad()

予測されたマスクロジットの取得

121            logits = self.model(image)

ターゲットマスクをロジットのサイズにトリミングします。U-Netの畳み込み層にパディングを使わないと、ロジットのサイズは小さくなります

124            mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])

損失の計算

126            loss = self.loss_func(self.sigmoid(logits), mask)

勾配の計算

128            loss.backward()

最適化の一歩を踏み出す

130            self.optimizer.step()

損失をトラッキング

132            tracker.save('loss', loss)

トレーニングループ

134    def run(self):
138        for _ in monit.loop(self.epochs):

モデルのトレーニング

140            self.train()

コンソールの新しい行

142            tracker.new_line()

モデルを保存する

144            experiment.save_checkpoint()
147def main():

実験を作成

149    experiment.create(name='unet')

構成の作成

152    configs = Configs()

構成を設定します。ディクショナリに値を渡すことでデフォルトをオーバーライドできます。

155    experiment.configs(configs, {})

[初期化]

158    configs.init()

保存および読み込み用のモデルを設定する

161    experiment.add_pytorch_models({'model': configs.model})

トレーニングループを開始して実行する

164    with experiment.start():
165        configs.run()

169if __name__ == '__main__':
170    main()