13import torch
14import torch.nn as nn
15from torch.nn import functional as F
16
17from labml_nn.normalization.weight_standardization import weight_standardization
class Conv2d(nn.Conv2d):
    """
    ## 2D Convolution Layer with Weight Standardization

    A drop-in replacement for `nn.Conv2d` that passes the kernel through
    `weight_standardization` on every forward pass before convolving.
    Constructor arguments are the same as `nn.Conv2d`, plus `eps`.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 eps: float = 1e-5):
        """
        * `in_channels`, `out_channels`, `kernel_size`, `stride`, `padding`,
          `dilation`, `groups`, `bias` and `padding_mode` are forwarded
          unchanged to `nn.Conv2d`
        * `eps` is forwarded to `weight_standardization` (presumably a small
          constant for numerical stability — see that function)
        """
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride,
                         padding=padding,
                         dilation=dilation,
                         groups=groups,
                         bias=bias,
                         padding_mode=padding_mode)
        self.eps = eps

    def _conv(self, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Convolve `x` with `weight`, using this layer's bias, stride, dilation,
        groups, padding and `padding_mode`.
        """
        if self.padding_mode != 'zeros':
            # `F.conv2d` can only zero-pad. Apply the requested padding mode
            # explicitly, then convolve with no extra padding. This matches
            # what `nn.Conv2d` does internally.
            x = F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            return F.conv2d(x, weight, self.bias, self.stride, 0, self.dilation, self.groups)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self, x: torch.Tensor):
        """Convolve `x` with the weight-standardized kernel."""
        return self._conv(x, weight_standardization(self.weight, self.eps))
# A simple test to verify the tensor sizes
def _test():
    """
    Smoke test: print the layer's weight and the result of a forward pass
    on an all-zero input using `labml.logger.inspect`.
    """
    # `labml` is imported inside the function, so it is only required when
    # this test is actually run.
    from labml.logger import inspect

    conv2d = Conv2d(10, 20, 5)
    inspect(conv2d.weight)
    # With kernel 5, stride 1 and no padding, the output should be (10, 20, 96, 96).
    inspect(conv2d(torch.zeros(10, 10, 100, 100)))
57
58
if __name__ == '__main__':
    # Run the smoke test when this file is executed directly as a script.
    _test()