# U-Net

U-Net 由一条收缩路径和一条扩展路径组成。收缩路径是一系列卷积图层和池化图层，其中要素地图的分辨率会逐渐降低。扩展路径是一系列向上采样图层和卷积图层，其中要素地图的分辨率会逐渐提高。

27import torch
28import torchvision.transforms.functional
29from torch import nn

### 两个卷积层

32class DoubleConvolution(nn.Module):
• in_channels 是输入声道的数量
• out_channels 是输出声道的数量
43    def __init__(self, in_channels: int, out_channels: int):
48        super().__init__()

51        self.first = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
52        self.act1 = nn.ReLU()

54        self.second = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
55        self.act2 = nn.ReLU()
57    def forward(self, x: torch.Tensor):

59        x = self.first(x)
60        x = self.act1(x)
61        x = self.second(x)
62        return self.act2(x)

### 向下采样

65class DownSample(nn.Module):
73    def __init__(self):
74        super().__init__()

76        self.pool = nn.MaxPool2d(2)
78    def forward(self, x: torch.Tensor):
79        return self.pool(x)

### 向上采样

82class UpSample(nn.Module):
89    def __init__(self, in_channels: int, out_channels: int):
90        super().__init__()

93        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
95    def forward(self, x: torch.Tensor):
96        return self.up(x)

### 裁剪并连接要素地图

99class CropAndConcat(nn.Module):
• x 扩展路径中的当前要素地图
• contracting_x 收缩路径中的相应要素地图
106    def forward(self, x: torch.Tensor, contracting_x: torch.Tensor):

113        contracting_x = torchvision.transforms.functional.center_crop(contracting_x, [x.shape[2], x.shape[3]])

115        x = torch.cat([x, contracting_x], dim=1)
117        return x

## U-Net

120class UNet(nn.Module):
• in_channels 输入图像中的通道数
• out_channels 结果特征图中的信道数
124    def __init__(self, in_channels: int, out_channels: int):
129        super().__init__()

133        self.down_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
134                                        [(in_channels, 64), (64, 128), (128, 256), (256, 512)]])

136        self.down_sample = nn.ModuleList([DownSample() for _ in range(4)])

139        self.middle_conv = DoubleConvolution(512, 1024)

143        self.up_sample = nn.ModuleList([UpSample(i, o) for i, o in
144                                        [(1024, 512), (512, 256), (256, 128), (128, 64)]])

149        self.up_conv = nn.ModuleList([DoubleConvolution(i, o) for i, o in
150                                      [(1024, 512), (512, 256), (256, 128), (128, 64)]])

152        self.concat = nn.ModuleList([CropAndConcat() for _ in range(4)])

154        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
• x 输入图像
156    def forward(self, x: torch.Tensor):

161        pass_through = []

163        for i in range(len(self.down_conv)):

165            x = self.down_conv[i](x)

167            pass_through.append(x)

169            x = self.down_sample[i](x)

U-Net 底部有两个卷积层

172        x = self.middle_conv(x)

175        for i in range(len(self.up_conv)):

177            x = self.up_sample[i](x)

179            x = self.concat[i](x, pass_through.pop())

181            x = self.up_conv[i](x)

184        x = self.final_conv(x)
187        return x