具有权重标准化的 2D 卷积层

这是带有权重标准化（Weight Standardization）的二维卷积层的实现。

13import torch
14import torch.nn as nn
15from torch.nn import functional as F
16
17from labml_nn.normalization.weight_standardization import weight_standardization

2D 卷积层

该层扩展了标准的 2D 卷积层（`nn.Conv2d`），在每次执行卷积之前先对权重进行标准化。

class Conv2d(nn.Conv2d):
    """2D convolution layer whose kernel is weight-standardized on every forward pass.

    Construction is identical to :class:`torch.nn.Conv2d` plus an extra ``eps``
    argument. On each call to :meth:`forward`, ``self.weight`` is passed through
    ``weight_standardization`` and the result is used as the convolution kernel.
    The stored parameter ``self.weight`` is not modified in place.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 eps: float = 1e-5):
        """Create the layer.

        All arguments except ``eps`` are forwarded unchanged to
        :class:`torch.nn.Conv2d`.

        Args:
            eps: value passed to ``weight_standardization`` on every forward
                call. Presumably a small constant added for numerical stability
                when normalizing the weights; confirm against that helper.
        """
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         groups=groups,
                         bias=bias,
                         padding_mode=padding_mode)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the standardized weights (plus ``self.bias`` if set)."""
        weight = weight_standardization(self.weight, self.eps)
        if self.padding_mode != 'zeros':
            # F.conv2d can only zero-pad. For 'reflect' / 'replicate' / 'circular',
            # pad explicitly first and then convolve with no implicit padding.
            # This is the same strategy torch.nn.Conv2d uses in its own forward.
            x = F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            return F.conv2d(x, weight, self.bias, self.stride,
                            0, self.dilation, self.groups)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

验证张量大小的简单测试

def _test():
    """Smoke test: build a small layer and check the size of its output.

    A 5x5 kernel with the default stride 1 and padding 0 shrinks each spatial
    dimension by 4. A (10, 10, 100, 100) input should therefore produce a
    (10, 20, 96, 96) output with 20 output channels.
    """
    from labml.logger import inspect

    conv2d = Conv2d(10, 20, 5)
    inspect(conv2d.weight)
    out = conv2d(torch.zeros(10, 10, 100, 100))
    inspect(out)
    # Actually verify the size instead of only printing it.
    assert out.shape == (10, 20, 96, 96), out.shape
if __name__ == '__main__':
    # Run the smoke test only when this file is executed as a script,
    # not when it is imported as a module.
    _test()