これにより、CarvanaデータセットでU-Netモデルをトレーニングします。ダウンロード手順は Kaggle で確認できます
。carvana/train
トレーニング画像をフォルダー内に保存し、carvana/train_masks
マスクをフォルダーに保存します。
わかりやすくするために、トレーニングと検証の分割は行っていません。
19import numpy as np
20import torch
21import torch.utils.data
22import torchvision.transforms.functional
23from torch import nn
24
25from labml import lab, tracker, experiment, monit
26from labml.configs import BaseConfigs
27from labml_helpers.device import DeviceConfigs
28from labml_nn.unet.carvana import CarvanaDataset
29from labml_nn.unet import UNet
32class Configs(BaseConfigs):
モデルをトレーニングするデバイス。DeviceConfigs
使用可能な CUDA デバイスを選択するか、デフォルトで CPU に設定します
39 device: torch.device = DeviceConfigs()
画像内のチャンネル数。RGB 用です。
45 image_channels: int = 3
出力マスクのチャンネル数。バイナリマスク用。
47 mask_channels: int = 1
バッチサイズ
50 batch_size: int = 1
学習率
52 learning_rate: float = 2.5e-4
トレーニングエポックの数
55 epochs: int = 4
データセット
58 dataset: CarvanaDataset
データローダー
60 data_loader: torch.utils.data.DataLoader
損失関数
63 loss_func = nn.BCELoss()
バイナリ分類用のシグモイド関数
65 sigmoid = nn.Sigmoid()
アダム・オプティマイザー
68 optimizer: torch.optim.Adam
70 def init(self):
72 self.dataset = CarvanaDataset(lab.get_data_path() / 'carvana' / 'train',
73 lab.get_data_path() / 'carvana' / 'train_masks')
モデルを初期化
75 self.model = UNet(self.image_channels, self.mask_channels).to(self.device)
データローダーの作成
78 self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
79 shuffle=True, pin_memory=True)
オプティマイザーを作成
81 self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
画像ロギング
84 tracker.set_image("sample", True)
86 @torch.no_grad()
87 def sample(self, idx=-1):
ランダムサンプルを入手
93 x, _ = self.dataset[np.random.randint(len(self.dataset))]
データをデバイスに移動
95 x = x.to(self.device)
予測マスクを取得
98 mask = self.sigmoid(self.model(x[None, :]))
画像をマスクのサイズにトリミングします
100 x = torchvision.transforms.functional.center_crop(x, [mask.shape[2], mask.shape[3]])
ログサンプル
102 tracker.save('sample', x * mask)
104 def train(self):
112 for _, (image, mask) in monit.mix(('Train', self.data_loader), (self.sample, list(range(50)))):
グローバルステップをインクリメント
114 tracker.add_global_step()
データをデバイスに移動
116 image, mask = image.to(self.device), mask.to(self.device)
グラデーションをゼロにする
119 self.optimizer.zero_grad()
予測されたマスクロジットの取得
121 logits = self.model(image)
ターゲットマスクをロジットのサイズにトリミングします。U-Netの畳み込み層にパディングを使わないと、ロジットのサイズは小さくなります
。124 mask = torchvision.transforms.functional.center_crop(mask, [logits.shape[2], logits.shape[3]])
損失の計算
126 loss = self.loss_func(self.sigmoid(logits), mask)
勾配の計算
128 loss.backward()
最適化の一歩を踏み出す
130 self.optimizer.step()
損失をトラッキング
132 tracker.save('loss', loss)
134 def run(self):
138 for _ in monit.loop(self.epochs):
モデルのトレーニング
140 self.train()
コンソールの新しい行
142 tracker.new_line()
モデルを保存する
144 experiment.save_checkpoint()
147def main():
実験を作成
149 experiment.create(name='unet')
構成の作成
152 configs = Configs()
構成を設定します。ディクショナリに値を渡すことでデフォルトをオーバーライドできます。
155 experiment.configs(configs, {})
[初期化]
158 configs.init()
保存および読み込み用のモデルを設定する
161 experiment.add_pytorch_models({'model': configs.model})
トレーニングループを開始して実行する
164 with experiment.start():
165 configs.run()
169if __name__ == '__main__':
170 main()