| |
| |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.nn.modules.utils import _pair |
|
|
|
|
class MultiScaleTridentConv(nn.Module):
    """Convolution whose single shared weight is applied at several strides/paddings.

    Every branch reuses the same ``weight`` (and optional ``bias``) but runs
    ``F.conv2d`` with its own stride and padding, so one set of parameters
    produces one output per branch.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size (int or pair).
        stride: Stored as ``self.stride``; NOT used by :meth:`forward`, which
            uses the per-branch ``strides`` instead.
        strides: Per-branch stride(s); an int is repeated for every branch.
        paddings: Per-branch padding(s); an int is repeated for every branch.
        dilations: Per-branch dilation(s); normalised and stored as
            ``self.dilations`` but NOT used by :meth:`forward`.
        dilation: Dilation passed to ``F.conv2d`` for every branch.
        groups: Number of convolution groups.
        num_branch: Number of branches; must equal ``len(strides)`` and
            ``len(paddings)`` after broadcasting.
        test_branch_idx: In eval mode, index of the single branch to run.
            ``-1`` makes eval behave like training (all branches run).
        bias: Whether to create a bias parameter (initialised to zero).
        norm: Optional callable applied to each branch output.
        activation: Optional callable applied to each branch output after
            ``norm``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        strides=1,
        paddings=0,
        dilations=1,
        dilation=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)  # kept for attribute compatibility; unused in forward
        self.groups = groups
        self.with_bias = bias
        self.dilation = dilation  # the dilation actually used by every branch

        # Scalar arguments are broadcast to one entry per branch.
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        if isinstance(strides, int):
            strides = [strides] * self.num_branch
        self.paddings = [_pair(p) for p in paddings]
        self.dilations = [_pair(d) for d in dilations]
        self.strides = [_pair(s) for s in strides]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1, (
            "num_branch, len(paddings) and len(strides) must all be equal"
        )

        # One weight (and bias) shared by every branch.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, inputs):
        """Apply the shared convolution to each active branch's input.

        Args:
            inputs: Sequence of tensors accepted by ``F.conv2d``, one per
                active branch: ``num_branch`` of them when training or when
                ``test_branch_idx == -1``, otherwise exactly one.

        Returns:
            List of output tensors, one per active branch, with ``norm`` and
            then ``activation`` applied when they are set.

        Raises:
            AssertionError: If the number of inputs does not match the number
                of active branches.
        """
        run_all_branches = self.training or self.test_branch_idx == -1
        num_branch = self.num_branch if run_all_branches else 1
        assert len(inputs) == num_branch, (
            f"expected {num_branch} input(s), got {len(inputs)}"
        )

        if run_all_branches:
            outputs = [
                F.conv2d(x, self.weight, self.bias, stride, padding, self.dilation, self.groups)
                for x, stride, padding in zip(inputs, self.strides, self.paddings)
            ]
        else:
            # test_branch_idx is never -1 on this path, so index it directly
            # to run exactly the requested branch's stride/padding.
            outputs = [
                F.conv2d(
                    inputs[0],
                    self.weight,
                    self.bias,
                    self.strides[self.test_branch_idx],
                    self.paddings[self.test_branch_idx],
                    self.dilation,
                    self.groups,
                )
            ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs
|
|