これは、論文「U-Net: 生物医学画像セグメンテーションのための畳み込みネットワーク」のU-Netモデルの実装です。
U-Netは収縮経路と拡張経路で構成されています。収縮経路は一連の畳み込み層とプーリング層であり、特徴マップの解像度は徐々に低下します。エクスパンシブパスとは、フィーチャマップの解像度が徐々に上がっていく一連のアップサンプリングレイヤーと畳み込みレイヤーのことです
。拡張パスの各ステップで、縮小パスからの対応するフィーチャマップが現在のフィーチャマップと連結されます。
27import torch
28import torchvision.transforms.functional
29from torch import nn
収縮経路と膨張経路の各ステップには、 2つの畳み込み層があり、その後にReLU活性化が続きます。
U-Netの論文ではパディングを使用していましたが、最終的なフィーチャマップがトリミングされないようにパディングを使用しています。
32class DoubleConvolution(nn.Module):
in_channels
は入力チャンネル数out_channels
は出力チャンネル数43 def __init__(self, in_channels: int, out_channels: int):
48 super().__init__()
最初の畳み込み層
51 self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
52 self.act1 = nn.ReLU()
2 番目の畳み込み層
54 self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
55 self.act2 = nn.ReLU()
57 def forward(self, x: torch.Tensor):
2 つのコンボリューションレイヤーとアクティベーションを適用します。
59 x = self.first(x)
60 x = self.act1(x)
61 x = self.second(x)
62 return self.act2(x)
65class DownSample(nn.Module):
73 def __init__(self):
74 super().__init__()
最大プーリング層
76 self.pool = nn.MaxPool2d(2)
78 def forward(self, x: torch.Tensor):
79 return self.pool(x)
82class UpSample(nn.Module):
89 def __init__(self, in_channels: int, out_channels: int):
90 super().__init__()
アップコンボリューション
93 self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
95 def forward(self, x: torch.Tensor):
96 return self.up(x)
99class CropAndConcat(nn.Module):
x
エクスパンシブパス内の現在のフィーチャマップcontracting_x
契約経路からの対応する機能マップ106 def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):
フィーチャマップを縮小パスから現在のフィーチャマップのサイズにトリミングします
113 contracting_x = torchvision.transforms.functional.center_crop(contracting_x, [x.shape[2], x.shape[3]])
フィーチャマップを連結する
115 x = torch.cat([x, contracting_x], dim=1)
117 return x
120class UNet(nn.Module):
in_channels
入力画像のチャンネル数out_channels
結果フィーチャマップのチャネル数124 def __init__(self, in_channels: int, out_channels: int):
129 super().__init__()
収縮経路用の二重畳み込み層。から始まる各ステップで機能の数が 2 倍になります
。133 self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
134 [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])
収縮経路用のダウンサンプリング層
136 self.down_sample = nn.ModuleList([DownSample() for _ in range(4)])
最も低い解像度の 2 つの畳み込み層 (U の下部)。
139 self.middle_conv = DoubleConvolution(512, 1024)
広大な経路のサンプリングレイヤーをアップサンプリング。アップサンプリングを行うと、機能の数は半分になります
。143 self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in
144 [(1024, 512), (512, 256), (256, 128), (128, 64)]])
膨張経路用の二重畳み込み層。それらの入力は、現在のフィーチャマップと縮小パスからのフィーチャマップを連結したものです。したがって、入力フィーチャの数は、アップサンプリングによるフィーチャ数の2倍になります
。149 self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
150 [(1024, 512), (512, 256), (256, 128), (128, 64)]])
レイヤーをトリミングして連結すると、広大なパスが作れます。
152 self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])
出力を生成する最後の畳み込み層
154 self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
x
入力画像156 def forward(self, x: torch.Tensor):
縮小パスの出力を収集して、後で拡張パスと連結できるようにします。
161 pass_through = []
契約経路
163 for i in range(len(self.down_conv)):
2 つの畳み込み層
165 x = self.down_conv[i](x)
アウトプットを収集
167 pass_through.append(x)
ダウンサンプル
169 x = self.down_sample[i](x)
U-Netの下部にある2つの畳み込み層
172 x = self.middle_conv(x)
広大な道
175 for i in range(len(self.up_conv)):
アップサンプル
177 x = self.up_sample[i](x)
コントラクトパスの出力を連結する
179 x = self.concat[i](x, pass_through.pop())
2 つの畳み込み層
181 x = self.up_conv[i](x)
最終畳み込み層
184 x = self.final_conv(x)
187 return x