画像認識のための深層残差学習 (ResNet)

これは、「画像認識のための深層残差学習という論文をPyTorchで実装したものです

ResNetは劣化の問題を克服するために層を残差関数として学習させます。劣化の問題は、層の数が非常に多くなると、ディープニューラルネットワークの精度が低下することです。レイヤーの数が増えると精度が上がり、飽和し、劣化が始まります

この論文では、余分な層はアイデンティティマッピングの実行方法を学習するだけでよいため、より深いモデルは少なくとも浅いモデルと同様に機能すべきだと主張しています。

残余学習

数層で学習する必要があるマッピングであれば、残差関数をトレーニングします

代わりに。そして本来の機能は次のようになります

この場合、アイデンティティマッピングの学習は、「なりたい」ことを学ぶことと同じで、学習しやすくなります。

パラメータ化された形式では、次のように記述できます。

また、特徴マップのサイズとが異なる場合は、学習した重みを使用して線形投影を行うことを提案しています。

Paper では、線形投影の代わりにゼロパディングを試したところ、線形投影の方が効果的であることがわかりました。また、フィーチャマップのサイズが一致する場合、線形投影よりもアイデンティティマッピングの方が優れていることがわかりました

複数のレイヤーが必要です。そうでない場合、合計にも非線形性がなく、線形レイヤーのようになります。

CIFAR-10でResNetをトレーニングするためのトレーニングコードは次のとおりです

55from typing import List, Optional
56
57import torch
58from torch import nn
59
60from labml_helpers.module import Module

ショートカット接続用の線形投影

これは上記の投影を行います。

63class ShortcutProjection(Module):
  • in_channels は内のチャンネル数
  • out_channels は内のチャンネル数
  • stride はの畳み込み演算におけるストライドの長さです。フィーチャマップのサイズに合わせて、ショートカット接続でも同じ手順を実行します
70    def __init__(self, in_channels: int, out_channels: int, stride: int):
77        super().__init__()

線形投影用のコンボリューションレイヤー

80        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

論文では、各畳み込み演算の後にバッチ正規化を追加することを提案しています

82        self.bn = nn.BatchNorm2d(out_channels)
84    def forward(self, x: torch.Tensor):

コンボリューションとバッチ正規化

86        return self.bn(self.conv(x))

残余ブロック

これは、論文で説明した残留ブロックを実装したものです。畳み込み層が 2 つあります

Residual Block

out_channels 最初の畳み込み層はout_channelsin_channels からにマッピングされます。この方が、in_channels 特徴マップのサイズを小さくしてストライドの長さがより大きくなる場合よりも大きくなります。

2 out_channels out_channels 番目の畳み込み層はからにマップされ、ストライドの長さは常に 1 です。

両方の畳み込み層の後にバッチ正規化が行われます。

89class ResidualBlock(Module):
  • in_channels は内のチャンネル数
  • out_channels は出力チャンネル数
  • stride はコンボリューション演算のストライドの長さです。
  • 110    def __init__(self, in_channels: int, out_channels: int, stride: int):
    116        super().__init__()

    最初の畳み込みレイヤー、これは次のようにマッピングされます out_channels

    119        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

    最初の畳み込み後のバッチ正規化

    121        self.bn1 = nn.BatchNorm2d(out_channels)

    最初のアクティベーション機能 (ReLU)

    123        self.act1 = nn.ReLU()

    2 番目の畳み込み層

    126        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    2 回目の畳み込み後のバッチ正規化

    128        self.bn2 = nn.BatchNorm2d(out_channels)

    ショートカットの接続は、ストライドの長さが合わない場合はプロジェクションにしてください。チャンネル数が変わると

    132        if stride != 1 or in_channels != out_channels:

    プロジェクション

    134            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
    135        else:

    アイデンティティ

    137            self.shortcut = nn.Identity()

    2 回目の起動機能 (ReLU) (ショートカット追加後)

    140        self.act2 = nn.ReLU()
    • x 形状の入力です [batch_size, in_channels, height, width]
    142    def forward(self, x: torch.Tensor):

    ショートカット接続を取得

    147        shortcut = self.shortcut(x)

    最初のコンボリューションとアクティベーション

    149        x = self.act1(self.bn1(self.conv1(x)))

    2 回目の畳み込み

    151        x = self.bn2(self.conv2(x))

    ショートカット追加後のアクティベーション機能

    153        return self.act2(x + shortcut)

    ボトルネック残留ブロック

    これにより、論文で説明されているボトルネックブロックが実装されます。および畳み込み層があります

    Bottlenext Block

    最初の畳み込み層は、in_channels bottleneck_channels からへの畳み込みでマッピングされます。ここで、bottleneck_channels はよりも低くなります。in_channels

    2 bottleneck_channels 番目の畳み込み層はからにマップされます。bottleneck_channels これにより、フィーチャマップのサイズを圧縮したい場合よりもストライドの長さが大きくなる可能性があります

    3 番目の最後の畳み込み層はにマッピングされます。out_channels out_channels in_channels ストライドの長さがより大きい場合よりも大きく、それ以外の場合はと等しい

    in_channels

    bottleneck_channels in_channels がよりも小さく、畳み込みがこの縮小されたスペースで実行されます(したがってボトルネックになります)。2 つの畳み込みによってチャネル数が減少し、増加します

    156class BottleneckResidualBlock(Module):
    • in_channels は内のチャンネル数
    • bottleneck_channels はコンボリューションのチャネル数です
    • out_channels は出力チャンネル数
  • stride はコンボリューション演算のストライドの長さです。
  • 184    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
    191        super().__init__()

    最初の畳み込みレイヤー、これは次のようにマッピングされます bottleneck_channels

    194        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)

    最初の畳み込み後のバッチ正規化

    196        self.bn1 = nn.BatchNorm2d(bottleneck_channels)

    最初のアクティベーション機能 (ReLU)

    198        self.act1 = nn.ReLU()

    2 番目の畳み込み層

    201        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1)

    2 回目の畳み込み後のバッチ正規化

    203        self.bn2 = nn.BatchNorm2d(bottleneck_channels)

    2 番目のアクティベーション機能 (ReLU)

    205        self.act2 = nn.ReLU()

    3 番目の畳み込み層、これはにマップされます。out_channels

    208        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)

    2 回目の畳み込み後のバッチ正規化

    210        self.bn3 = nn.BatchNorm2d(out_channels)

    ショートカットの接続は、ストライドの長さが合わない場合はプロジェクションにしてください。チャンネル数が変わると

    214        if stride != 1 or in_channels != out_channels:

    プロジェクション

    216            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
    217        else:

    アイデンティティ

    219            self.shortcut = nn.Identity()

    2 回目の起動機能 (ReLU) (ショートカット追加後)

    222        self.act3 = nn.ReLU()
    • x 形状の入力です [batch_size, in_channels, height, width]
    224    def forward(self, x: torch.Tensor):

    ショートカット接続を取得

    229        shortcut = self.shortcut(x)

    最初のコンボリューションとアクティベーション

    231        x = self.act1(self.bn1(self.conv1(x)))

    2 回目のコンボリューションとアクティベーション

    233        x = self.act2(self.bn2(self.conv2(x)))

    3 番目の畳み込み

    235        x = self.bn3(self.conv3(x))

    ショートカット追加後のアクティベーション機能

    237        return self.act3(x + shortcut)

    リネットモデル

    これは最後の線形層と分類用のソフトマックスを含まないresnetモデルの基本です。

    再ネットは、積み重ねられた残留ブロックまたはボトルネックの残留ブロックで構成されています。フィーチャマップのサイズは、ストライドの長さのブロックで数ブロック進むと半分になります。フィーチャマップのサイズを小さくすると、チャネル数が増えます。最後に、特徴マップを平均してベクトル表現を求めます

    240class ResNetBase(Module):
    • n_blocks は、各フィーチャマップサイズのブロック数のリストです。
    • n_channels は、各フィーチャマップサイズのチャネル数です。
    • bottlenecks ボトルネックとなるチャネル数です。その場合None残留ブロックが使用されます
    • img_channels は入力のチャンネル数です。
    • first_kernel_size は初期畳み込み層のカーネルサイズ
    254    def __init__(self, n_blocks: List[int], n_channels: List[int],
    255                 bottlenecks: Optional[List[int]] = None,
    256                 img_channels: int = 3, first_kernel_size: int = 7):
    265        super().__init__()

    各フィーチャマップサイズのブロック数とチャネル数

    268        assert len(n_blocks) == len(n_channels)

    ボトルネックの残留ブロックを使用する場合は、機能マップのサイズごとにボトルネックのチャネル数を指定する必要があります。

    271        assert bottlenecks is None or len(bottlenecks) == len(n_channels)

    初期畳み込み層は、img_channels 最初の残差ブロックのチャネル数からにマッピングされます () n_channels[0]

    275        self.conv = nn.Conv2d(img_channels, n_channels[0],
    276                              kernel_size=first_kernel_size, stride=2, padding=first_kernel_size // 2)

    初期畳み込み後のバッチノルム

    278        self.bn = nn.BatchNorm2d(n_channels[0])

    ブロック一覧

    281        blocks = []

    前のレイヤー (またはブロック) のチャンネル数

    283        prev_channels = n_channels[0]

    各フィーチャマップサイズをループスループ

    285        for i, channels in enumerate(n_channels):

    新しいフィーチャマップサイズの最初のブロックのストライドの長さは、一番最初のブロックを除いてです。

    288            stride = 2 if len(blocks) == 0 else 1
    289
    290            if bottlenecks is None:
    292                blocks.append(ResidualBlock(prev_channels, channels, stride=stride))
    293            else:
    296                blocks.append(BottleneckResidualBlock(prev_channels, bottlenecks[i], channels,
    297                                                      stride=stride))

    チャンネル数を変更

    300            prev_channels = channels

    残りのブロックを追加-フィーチャマップのサイズやチャンネルは変更なし

    302            for _ in range(n_blocks[i] - 1):
    303                if bottlenecks is None:
    305                    blocks.append(ResidualBlock(channels, channels, stride=1))
    306                else:
    308                    blocks.append(BottleneckResidualBlock(channels, bottlenecks[i], channels, stride=1))

    ブロックを積み重ねよう

    311        self.blocks = nn.Sequential(*blocks)
    • x 形がある [batch_size, img_channels, height, width]
    313    def forward(self, x: torch.Tensor):

    初期畳み込みとバッチ正規化

    319        x = self.bn(self.conv(x))

    残り (またはボトルネック) ブロック

    321        x = self.blocks(x)

    x [batch_size, channels, h, w] 形状を次のように変更 [batch_size, channels, h * w]

    323        x = x.view(x.shape[0], x.shape[1], -1)

    グローバルアベレージプーリング

    325        return x.mean(dim=-1)