This is a PyTorch implementation of the paper Deep Residual Learning for Image Recognition.
ResNets train layers as residual functions to overcome the degradation problem. The degradation problem is that the accuracy of deep neural networks degrades when the number of layers becomes very high: as layers are added, the accuracy first increases, then saturates, and then starts to degrade.
The paper argues that deeper models should perform at least as well as shallower models because the extra layers can just learn to perform an identity mapping.
If $\mathcal{H}(x)$ is the mapping that needs to be learned by a few layers, they train the residual function
$$\mathcal{F}(x) = \mathcal{H}(x) - x$$
instead. And the original function becomes $\mathcal{F}(x) + x$.
In this case, learning the identity mapping for $\mathcal{H}(x)$ is equivalent to learning $\mathcal{F}(x)$ to be $0$, which is easier to learn.
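A minimal sketch of this idea, assuming a hypothetical `ToyResidual` module whose residual branch $\mathcal{F}$ is two $3 \times 3$ convolutions. Once the branch's last convolution is zeroed, $\mathcal{F}(x) = 0$ and the block passes its input through unchanged, so an "extra" block costs nothing in accuracy if it learns to do this.

```python
import torch
from torch import nn


class ToyResidual(nn.Module):
    """Hypothetical toy block computing F(x) + x, where F is two 3x3 convolutions."""

    def __init__(self, channels: int):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # H(x) = F(x) + x
        return self.f(x) + x


block = ToyResidual(channels=4)
# Drive the residual branch F to zero by zeroing the last convolution's weights and bias
nn.init.zeros_(block.f[2].weight)
nn.init.zeros_(block.f[2].bias)

x = torch.randn(2, 4, 8, 8)
y = block(x)
print(torch.equal(y, x))  # True: the block is now exactly the identity
```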
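In the parameterized form this can be written as,
$$\mathcal{F}(x, \{W_i\}) + x$$
and when the feature map sizes of $\mathcal{F}(x, \{W_i\})$ and $x$ are different the paper suggests doing a linear projection, with learned weights $W_s$.
$$\mathcal{F}(x, \{W_i\}) + W_s x$$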
The paper experimented with zero padding instead of linear projections and found linear projections to work better. Also, when the feature map sizes match, they found identity mapping to be better than linear projections.
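To make the two options concrete, here is a sketch of both shortcuts for a stage that halves the feature map and doubles the channels. The helper `zero_pad_shortcut` is hypothetical: it assumes the zero-padding variant subsamples with the same stride and fills the new channels with zeros. The projection $W_s$ is sketched as a strided $1 \times 1$ convolution.

```python
import torch
import torch.nn.functional as F
from torch import nn


def zero_pad_shortcut(x: torch.Tensor, out_channels: int, stride: int) -> torch.Tensor:
    """Parameter-free shortcut (a sketch of the zero-padding option).

    Subsamples spatially by taking every `stride`-th position, then appends
    all-zero channels until there are `out_channels` channels.
    """
    x = x[:, :, ::stride, ::stride]
    extra = out_channels - x.shape[1]
    # F.pad pads the last dims first: (w_left, w_right, h_top, h_bottom, c_front, c_back)
    return F.pad(x, (0, 0, 0, 0, 0, extra))


# Learned projection W_s x, sketched as a strided 1x1 convolution
projection = nn.Conv2d(16, 32, kernel_size=1, stride=2)

x = torch.randn(2, 16, 32, 32)
print(zero_pad_shortcut(x, 32, 2).shape)  # torch.Size([2, 32, 16, 16])
print(projection(x).shape)                # torch.Size([2, 32, 16, 16])
```

Both produce a tensor that can be added to $\mathcal{F}(x, \{W_i\})$; the difference is that the zero-padded channels carry no information from $x$, while $W_s$ learns a mixing for every output channel.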
$\mathcal{F}$ should have more than one layer; otherwise the sum $\mathcal{F}(x, \{W_i\}) + W_s x$ also won't have non-linearities and will behave like a single linear layer.
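A quick numerical check of this point, using `nn.Linear` as a stand-in for a single-layer $\mathcal{F}$ (an assumption for brevity; a single convolution behaves the same way): $\mathcal{F}(x) + x = Wx + b + x = (W + I)x + b$, which one linear layer with weight $W + I$ reproduces.

```python
import torch
from torch import nn

torch.manual_seed(0)
d = 8
f = nn.Linear(d, d)  # a residual branch with a single layer and no non-linearity

# The equivalent single linear layer: weight W + I, same bias b
merged = nn.Linear(d, d)
with torch.no_grad():
    merged.weight.copy_(f.weight + torch.eye(d))
    merged.bias.copy_(f.bias)

x = torch.randn(4, d)
print(torch.allclose(f(x) + x, merged(x), atol=1e-6))  # True
```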
Here is the code for training a ResNet on CIFAR-10.
from typing import List, Optional

import torch
from torch import nn

from labml_helpers.module import Module
Linear projection for the shortcut connection. This does the $W_s x$ projection described above.
class ShortcutProjection(Module):
in_channels is the number of channels in $x$
out_channels is the number of channels in $\mathcal{F}(x, \{W_i\})$
stride is the stride length in the convolution operation for $\mathcal{F}$. We use the same stride on the shortcut connection to match the feature-map size.
    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
Convolution layer for the linear projection $W_s x$
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
The paper suggests adding batch normalization after each convolution operation
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor):
Convolution and batch normalization
        return self.bn(self.conv(x))
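Because the projection uses a $1 \times 1$ kernel, a stride of $2$ simply reads every second row and column before mixing channels. The sketch below checks this by copying the weights into an unstrided convolution and subsampling the input by hand; the layer sizes are arbitrary choices for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
strided = nn.Conv2d(16, 32, kernel_size=1, stride=2)
unstrided = nn.Conv2d(16, 32, kernel_size=1, stride=1)
# Share the same weights so the two layers differ only in stride
unstrided.load_state_dict(strided.state_dict())

x = torch.randn(2, 16, 8, 8)
a = strided(x)
b = unstrided(x[:, :, ::2, ::2])  # keep every second row and column, then mix channels
print(a.shape)                           # torch.Size([2, 32, 4, 4])
print(torch.allclose(a, b, atol=1e-6))  # True
```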
This implements the residual block described in the paper. It has two $3 \times 3$ convolution layers.
The first convolution layer maps from in_channels to out_channels, where out_channels is higher than in_channels when we reduce the feature map size with a stride length greater than $1$.
The second convolution layer maps from out_channels to out_channels and always has a stride length of $1$.
Both convolution layers are followed by batch normalization.
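Why `padding=1` with a $3 \times 3$ kernel: the sketch below applies the output-size formula given in the PyTorch `Conv2d` documentation. The helper `conv_out_size` is ours, not part of PyTorch. With stride $1$ the size is preserved; with stride $2$ the main path and the $1 \times 1$ shortcut projection produce the same size, even for odd inputs, so the two can be added.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution (formula from the PyTorch Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# 3x3 convolution with padding 1, as in ResidualBlock
print(conv_out_size(32, kernel_size=3, stride=1, padding=1))  # 32: size preserved
print(conv_out_size(32, kernel_size=3, stride=2, padding=1))  # 16: size halved
# 1x1 projection on the shortcut with the same stride
print(conv_out_size(32, kernel_size=1, stride=2, padding=0))  # 16: matches the main path
# Odd input sizes still line up
print(conv_out_size(7, kernel_size=3, stride=2, padding=1))   # 4
print(conv_out_size(7, kernel_size=1, stride=2, padding=0))   # 4
```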
class ResidualBlock(Module):
in_channels is the number of channels in $x$
out_channels is the number of output channels
stride is the stride length in the convolution operation.
    def __init__(self, in_channels: int, out_channels: int, stride: int):
        super().__init__()
First $3 \times 3$ convolution layer; this maps to out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(out_channels)
First activation function (ReLU)
        self.act1 = nn.ReLU()
Second $3 \times 3$ convolution layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(out_channels)
The shortcut connection should be a projection if the stride length is not $1$ or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
Identity $x$
            self.shortcut = nn.Identity()
Second activation function (ReLU), applied after adding the shortcut
        self.act2 = nn.ReLU()

x is the input of shape [batch_size, in_channels, height, width]
    def forward(self, x: torch.Tensor):
Get the shortcut connection
        shortcut = self.shortcut(x)
First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
Second convolution
        x = self.bn2(self.conv2(x))
Activation function after adding the shortcut
        return self.act2(x + shortcut)
This implements the bottleneck block described in the paper. It has $1 \times 1$, $3 \times 3$, and $1 \times 1$ convolution layers.
The first convolution layer maps from in_channels to bottleneck_channels with a $1 \times 1$ convolution, where bottleneck_channels is lower than in_channels.
The second, $3 \times 3$ convolution layer maps from bottleneck_channels to bottleneck_channels. This can have a stride length greater than $1$ when we want to compress the feature map size.
The third and final, $1 \times 1$ convolution layer maps to out_channels. out_channels is higher than in_channels if the stride length is greater than $1$; otherwise, out_channels is equal to in_channels.
bottleneck_channels is less than in_channels, and the $3 \times 3$ convolution is performed on this shrunk space (hence the bottleneck). The two $1 \times 1$ convolutions decrease and then increase the number of channels.
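A rough way to see why the bottleneck is cheaper: count the convolution weights, ignoring biases and batch-norm parameters. The widths 256 and 64 are illustrative choices, and `conv_weights` is a hypothetical helper.

```python
def conv_weights(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Number of weights in a Conv2d layer, ignoring bias and batch-norm parameters."""
    return kernel_size * kernel_size * in_channels * out_channels


# Two 3x3 convolutions at 256 channels (a ResidualBlock of that width)
basic = 2 * conv_weights(256, 256, 3)
# 1x1 (256 -> 64), 3x3 (64 -> 64), 1x1 (64 -> 256), as in a BottleneckResidualBlock
bottleneck = conv_weights(256, 64, 1) + conv_weights(64, 64, 3) + conv_weights(64, 256, 1)

print(basic)       # 1179648
print(bottleneck)  # 69632
```

At the same input and output width, the bottleneck needs roughly 17 times fewer weights, which is what makes much deeper stacks affordable.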
class BottleneckResidualBlock(Module):
in_channels is the number of channels in $x$
bottleneck_channels is the number of channels for the $3 \times 3$ convolution
out_channels is the number of output channels
stride is the stride length in the $3 \times 3$ convolution operation.
    def __init__(self, in_channels: int, bottleneck_channels: int, out_channels: int, stride: int):
        super().__init__()
First $1 \times 1$ convolution layer; this maps to bottleneck_channels
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, stride=1)
Batch normalization after the first convolution
        self.bn1 = nn.BatchNorm2d(bottleneck_channels)
First activation function (ReLU)
        self.act1 = nn.ReLU()
Second $3 \times 3$ convolution layer
        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=stride, padding=1)
Batch normalization after the second convolution
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
Second activation function (ReLU)
        self.act2 = nn.ReLU()
Third $1 \times 1$ convolution layer; this maps to out_channels
        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, stride=1)
Batch normalization after the third convolution
        self.bn3 = nn.BatchNorm2d(out_channels)
The shortcut connection should be a projection if the stride length is not $1$ or if the number of channels changes
        if stride != 1 or in_channels != out_channels:
Projection $W_s x$
            self.shortcut = ShortcutProjection(in_channels, out_channels, stride)
        else:
Identity $x$
            self.shortcut = nn.Identity()
Third activation function (ReLU), applied after adding the shortcut
        self.act3 = nn.ReLU()

x is the input of shape [batch_size, in_channels, height, width]
    def forward(self, x: torch.Tensor):
Get the shortcut connection
        shortcut = self.shortcut(x)
First convolution and activation
        x = self.act1(self.bn1(self.conv1(x)))
Second convolution and activation
        x = self.act2(self.bn2(self.conv2(x)))
Third convolution
        x = self.bn3(self.conv3(x))
Activation function after adding the shortcut
        return self.act3(x + shortcut)
This is the base of the ResNet model, without the final linear layer and softmax for classification.
The ResNet is made of stacked residual blocks or bottleneck residual blocks. The feature map size is halved after a few blocks with a block of stride length $2$. The number of channels is increased when the feature map size is reduced. Finally, the feature map is average pooled to get a vector representation.
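A sketch of how the spatial size evolves through the model, assuming the stride-2 initial convolution used below and, as the constructor's comment states, that every stage except the first starts with a stride-2 block. `stage_sizes` is a hypothetical helper; both convolution types involved round odd sizes up.

```python
from typing import List


def stage_sizes(img_size: int, n_stages: int) -> List[int]:
    """Side length after the initial convolution, then after each stage.

    Assumes a stride-2 initial convolution with padding kernel_size // 2 (odd kernel),
    and a stride-2, padding-1, 3x3 convolution at the start of every stage but the first.
    Both compute ceil(size / 2).
    """
    size = (img_size + 1) // 2  # initial convolution
    sizes = [size]
    for stage in range(n_stages):
        if stage > 0:
            size = (size + 1) // 2  # first block of the stage has stride 2
        sizes.append(size)
    return sizes


print(stage_sizes(32, n_stages=4))   # [16, 16, 8, 4, 2]
print(stage_sizes(224, n_stages=4))  # [112, 112, 56, 28, 14]
```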
class ResNetBase(Module):
n_blocks is a list of the number of blocks for each feature map size.
n_channels is the number of channels for each feature map size.
bottlenecks is the number of channels in the bottlenecks for each feature map size. If this is None, residual blocks are used.
img_channels is the number of channels in the input.
first_kernel_size is the kernel size of the initial convolution layer.
    def __init__(self, n_blocks: List[int], n_channels: List[int],
                 bottlenecks: Optional[List[int]] = None,
                 img_channels: int = 3, first_kernel_size: int = 7):
        super().__init__()
Number of blocks and number of channels for each feature map size
        assert len(n_blocks) == len(n_channels)
If bottleneck residual blocks are used, the number of channels in bottlenecks should be provided for each feature map size
        assert bottlenecks is None or len(bottlenecks) == len(n_channels)
The initial convolution layer maps from img_channels to the number of channels in the first residual block (n_channels[0])
        self.conv = nn.Conv2d(img_channels, n_channels[0],
                              kernel_size=first_kernel_size, stride=2, padding=first_kernel_size // 2)
Batch norm after the initial convolution
        self.bn = nn.BatchNorm2d(n_channels[0])
List of blocks
        blocks = []
Number of channels from the previous layer (or block)
        prev_channels = n_channels[0]
Loop through each feature map size
        for i, channels in enumerate(n_channels):
The first block for the new feature map size will have a stride length of $2$, except for the very first block
            stride = 1 if len(blocks) == 0 else 2

            if bottlenecks is None:
Residual block that maps from prev_channels to channels
                blocks.append(ResidualBlock(prev_channels, channels, stride=stride))
            else:
Bottleneck residual block that maps from prev_channels to channels
                blocks.append(BottleneckResidualBlock(prev_channels, bottlenecks[i], channels,
                                                      stride=stride))
Change the number of channels
            prev_channels = channels
Add the rest of the blocks; no change in feature map size or channels
            for _ in range(n_blocks[i] - 1):
                if bottlenecks is None:
                    blocks.append(ResidualBlock(channels, channels, stride=1))
                else:
                    blocks.append(BottleneckResidualBlock(channels, bottlenecks[i], channels, stride=1))
Stack the blocks
        self.blocks = nn.Sequential(*blocks)
x has shape [batch_size, img_channels, height, width]
    def forward(self, x: torch.Tensor):
Initial convolution and batch normalization
        x = self.bn(self.conv(x))
Residual (or bottleneck) blocks
        x = self.blocks(x)
Change x from shape [batch_size, channels, h, w] to [batch_size, channels, h * w]
        x = x.view(x.shape[0], x.shape[1], -1)
Global average pooling
        return x.mean(dim=-1)
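The `view` followed by `mean` is global average pooling. Below is a quick check against PyTorch's `F.adaptive_avg_pool2d`, followed by a hypothetical 10-class linear head (for example, for CIFAR-10) of the kind this base model leaves out. The tensor sizes are arbitrary stand-ins for the output of the blocks.

```python
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(2, 64, 7, 7)  # stand-in for the blocks' output: [batch_size, channels, h, w]

# Flatten the spatial dimensions and average, as in ResNetBase.forward
pooled = x.view(x.shape[0], x.shape[1], -1).mean(dim=-1)
# PyTorch's built-in global average pooling, for comparison
reference = F.adaptive_avg_pool2d(x, 1).flatten(1)
print(pooled.shape)                                  # torch.Size([2, 64])
print(torch.allclose(pooled, reference, atol=1e-6))  # True

# Hypothetical classification head: 64 features -> 10 class logits
head = nn.Linear(64, 10)
logits = head(pooled)
print(logits.shape)                                  # torch.Size([2, 10])
```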