ユーネット

これは、論文「U-Net: 生物医学画像セグメンテーションのための畳み込みネットワーク」のU-Netモデルの実装です

U-Netは収縮経路と拡張経路で構成されています。収縮経路は一連の畳み込み層とプーリング層であり、特徴マップの解像度は徐々に低下します。エクスパンシブパスとは、フィーチャマップの解像度が徐々に上がっていく一連のアップサンプリングレイヤーと畳み込みレイヤーのことです

拡張パスの各ステップで、縮小パスからの対応するフィーチャマップが現在のフィーチャマップと連結されます。

U-Net diagram from paper

これは、CarvanaデータセットでU-Netをトレーニングする実験のトレーニングコードです

27import torch
28import torchvision.transforms.functional
29from torch import nn

2 つのコンボリューションレイヤー

収縮経路と膨張経路の各ステップには、 2つの畳み込み層があり、その後にReLU活性化が続きます。

U-Netの論文ではパディングを使用していましたが、最終的なフィーチャマップがトリミングされないようにパディングを使用しています。

32class DoubleConvolution(nn.Module):
  • in_channels は入力チャンネル数
  • out_channels は出力チャンネル数
43    def __init__(self, in_channels: int, out_channels: int):
48        super().__init__()

最初の畳み込み層

51        self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
52        self.act1 = nn.ReLU()

2 番目の畳み込み層

54        self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
55        self.act2 = nn.ReLU()
57    def forward(self, x: torch.Tensor):

2 つのコンボリューションレイヤーとアクティベーションを適用します。

59        x = self.first(x)
60        x = self.act1(x)
61        x = self.second(x)
62        return self.act2(x)

ダウンサンプル

収縮パスの各ステップは、最大プーリングレイヤーを使用して特徴マップをダウンサンプリングします。

65class DownSample(nn.Module):
73    def __init__(self):
74        super().__init__()

最大プーリング層

76        self.pool = nn.MaxPool2d(2)
78    def forward(self, x: torch.Tensor):
79        return self.pool(x)

アップサンプル

広大な経路の各ステップは、アップコンボリューションで特徴マップをアップサンプリングします。

82class UpSample(nn.Module):
89    def __init__(self, in_channels: int, out_channels: int):
90        super().__init__()

アップコンボリューション

93        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
95    def forward(self, x: torch.Tensor):
96        return self.up(x)

フィーチャマップのトリミングと連結

拡張パスの各ステップで、縮小パスからの対応するフィーチャマップが現在のフィーチャマップと連結されます。

99class CropAndConcat(nn.Module):
  • x エクスパンシブパス内の現在のフィーチャマップ
  • contracting_x 契約経路からの対応する機能マップ
106    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):

フィーチャマップを縮小パスから現在のフィーチャマップのサイズにトリミングします

113        contracting_x = torchvision.transforms.functional.center_crop(contracting_x, [x.shape[2], x.shape[3]])

フィーチャマップを連結する

115        x = torch.cat([x, contracting_x], dim=1)

117        return x

ユーネット

120class UNet(nn.Module):
  • in_channels 入力画像のチャンネル数
  • out_channels 結果フィーチャマップのチャネル数
124    def __init__(self, in_channels: int, out_channels: int):
129        super().__init__()

収縮経路用の二重畳み込み層。から始まる各ステップで機能の数が 2 倍になります

133        self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
134                                        [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])

収縮経路用のダウンサンプリング層

136        self.down_sample = nn.ModuleList([DownSample() for _ in range(4)])

最も低い解像度の 2 つの畳み込み層 (U の下部)。

139        self.middle_conv = DoubleConvolution(512, 1024)

広大な経路のサンプリングレイヤーをアップサンプリング。アップサンプリングを行うと、機能の数は半分になります

143        self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in
144                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]])

膨張経路用の二重畳み込み層。それらの入力は、現在のフィーチャマップと縮小パスからのフィーチャマップを連結したものです。したがって、入力フィーチャの数は、アップサンプリングによるフィーチャ数の2倍になります

149        self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
150                                      [(1024, 512), (512, 256), (256, 128), (128, 64)]])

レイヤーをトリミングして連結すると、広大なパスが作れます。

152        self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])

出力を生成する最後の畳み込み層

154        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
  • x 入力画像
156    def forward(self, x: torch.Tensor):

縮小パスの出力を収集して、後で拡張パスと連結できるようにします。

161        pass_through = []

契約経路

163        for i in range(len(self.down_conv)):

2 つの畳み込み層

165            x = self.down_conv[i](x)

アウトプットを収集

167            pass_through.append(x)

ダウンサンプル

169            x = self.down_sample[i](x)

U-Netの下部にある2つの畳み込み層

172        x = self.middle_conv(x)

広大な道

175        for i in range(len(self.up_conv)):

アップサンプル

177            x = self.up_sample[i](x)

コントラクトパスの出力を連結する

179            x = self.concat[i](x, pass_through.pop())

2 つの畳み込み層

181            x = self.up_conv[i](x)

最終畳み込み層

184        x = self.final_conv(x)

187        return x