| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as cp |
| from mmcv.cnn import build_conv_layer, build_norm_layer |
| from mmengine.model import BaseModule |
| from torch.nn.modules.utils import _pair |
|
|
| from mmdet.models.backbones.resnet import Bottleneck, ResNet |
| from mmdet.registry import MODELS |
|
|
|
|
class TridentConv(BaseModule):
    """Trident Convolution Module.

    One convolution weight shared by ``len(trident_dilations)`` parallel
    branches; each branch applies the shared weight with its own dilation.
    The per-branch padding is set equal to the dilation, so a 3x3 kernel
    preserves spatial size at stride 1.

    Args:
        in_channels (int): Number of channels in input.
        out_channels (int): Number of channels in output.
        kernel_size (int): Size of convolution kernel.
        stride (int, optional): Convolution stride. Default: 1.
        trident_dilations (tuple[int, int, int], optional): Dilations of
            different trident branch. Default: (1, 2, 3).
        test_branch_idx (int, optional): In inference, all 3 branches will
            be used if `test_branch_idx==-1`, otherwise only branch with
            index `test_branch_idx` will be used. Default: 1.
        bias (bool, optional): Whether to use bias in convolution or not.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 trident_dilations=(1, 2, 3),
                 test_branch_idx=1,
                 bias=False,
                 init_cfg=None):
        super(TridentConv, self).__init__(init_cfg)
        self.num_branch = len(trident_dilations)
        self.with_bias = bias
        self.test_branch_idx = test_branch_idx
        self.stride = _pair(stride)
        self.kernel_size = _pair(kernel_size)
        # Per-branch padding equals the branch dilation.
        self.paddings = _pair(trident_dilations)
        self.dilations = trident_dilations
        self.in_channels = in_channels
        self.out_channels = out_channels

        # NOTE: the weight is allocated uninitialized here; actual values are
        # expected to come from ``init_cfg`` (e.g. the Kaiming init passed in
        # by TridentBottleneck).
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None
        # (fixed) Removed a dead ``self.bias = bias`` assignment that was
        # immediately overwritten by the Parameter/None above.

    def extra_repr(self):
        """Return the extra information string shown by ``repr``."""
        tmpstr = f'in_channels={self.in_channels}'
        tmpstr += f', out_channels={self.out_channels}'
        tmpstr += f', kernel_size={self.kernel_size}'
        tmpstr += f', num_branch={self.num_branch}'
        tmpstr += f', test_branch_idx={self.test_branch_idx}'
        tmpstr += f', stride={self.stride}'
        tmpstr += f', paddings={self.paddings}'
        tmpstr += f', dilations={self.dilations}'
        tmpstr += f', bias={self.bias}'
        return tmpstr

    def forward(self, inputs):
        """Apply the shared convolution to every active branch.

        Args:
            inputs (list[Tensor]): One input tensor per active branch.

        Returns:
            list[Tensor]: One output tensor per active branch.
        """
        if self.training or self.test_branch_idx == -1:
            # All branches active: the i-th input is convolved with the i-th
            # (dilation, padding) pair, sharing self.weight / self.bias.
            outputs = [
                F.conv2d(input, self.weight, self.bias, self.stride, padding,
                         dilation) for input, dilation, padding in zip(
                             inputs, self.dilations, self.paddings)
            ]
        else:
            # Inference with a single selected branch.
            assert len(inputs) == 1
            outputs = [
                F.conv2d(inputs[0], self.weight, self.bias, self.stride,
                         self.paddings[self.test_branch_idx],
                         self.dilations[self.test_branch_idx])
            ]

        return outputs
|
|
|
|
| |
| |
class TridentBottleneck(Bottleneck):
    """BottleBlock for TridentResNet.

    Compared with :class:`Bottleneck`, ``conv2`` is replaced by a
    :class:`TridentConv`, so the block processes a *list* of branch tensors
    with shared weights but different dilations.

    Args:
        trident_dilations (tuple[int, int, int]): Dilations of different
            trident branch.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        concat_output (bool): Whether to concat the output list to a Tensor.
            `True` only in the last Block.
    """

    def __init__(self, trident_dilations, test_branch_idx, concat_output,
                 **kwargs):
        super(TridentBottleneck, self).__init__(**kwargs)
        self.trident_dilations = trident_dilations
        self.num_branch = len(trident_dilations)
        self.concat_output = concat_output
        self.test_branch_idx = test_branch_idx
        self.conv2 = TridentConv(
            self.planes,
            self.planes,
            kernel_size=3,
            stride=self.conv2_stride,
            bias=False,
            trident_dilations=self.trident_dilations,
            test_branch_idx=test_branch_idx,
            init_cfg=dict(
                type='Kaiming',
                distribution='uniform',
                mode='fan_in',
                override=dict(name='conv2')))

    def forward(self, x):
        """Forward one tensor (first block of the stage) or a list of
        per-branch tensors through the trident bottleneck.

        Returns:
            list[Tensor] | Tensor: Per-branch outputs, or a single tensor
            concatenated along the batch dim when ``concat_output`` is True.
        """

        def _inner_forward(x):
            # All branches are active in training (or when all branches are
            # kept at test time); otherwise only one branch runs.
            num_branch = (
                self.num_branch
                if self.training or self.test_branch_idx == -1 else 1)
            identity = x
            if not isinstance(x, list):
                # First block of the stage receives a single tensor:
                # replicate it (by reference) once per branch.
                x = (x, ) * num_branch
                identity = x
                if self.downsample is not None:
                    identity = [self.downsample(b) for b in x]

            out = [self.conv1(b) for b in x]
            out = [self.norm1(b) for b in out]
            out = [self.relu(b) for b in out]

            if self.with_plugins:
                for k in range(len(out)):
                    out[k] = self.forward_plugin(out[k],
                                                 self.after_conv1_plugin_names)

            # conv2 is the TridentConv: it consumes and returns the whole
            # list of branches at once.
            out = self.conv2(out)
            out = [self.norm2(b) for b in out]
            out = [self.relu(b) for b in out]
            if self.with_plugins:
                for k in range(len(out)):
                    out[k] = self.forward_plugin(out[k],
                                                 self.after_conv2_plugin_names)

            out = [self.conv3(b) for b in out]
            out = [self.norm3(b) for b in out]

            if self.with_plugins:
                for k in range(len(out)):
                    out[k] = self.forward_plugin(out[k],
                                                 self.after_conv3_plugin_names)

            # Residual add, branch by branch.
            out = [
                out_b + identity_b for out_b, identity_b in zip(out, identity)
            ]
            return out

        # BUGFIX: ``x`` is a plain tensor only for the first block of the
        # stage; later blocks receive a list, on which ``x.requires_grad``
        # used to raise AttributeError when ``with_cp`` was enabled.
        # Checkpointing is now applied only to tensor inputs (checkpointing a
        # list argument would not track its gradients anyway).
        if self.with_cp and torch.is_tensor(x) and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = [self.relu(b) for b in out]
        if self.concat_output:
            # Last block of the stage stacks branch outputs along batch dim.
            out = torch.cat(out, dim=0)
        return out
|
|
|
|
def make_trident_res_layer(block,
                           inplanes,
                           planes,
                           num_blocks,
                           stride=1,
                           trident_dilations=(1, 2, 3),
                           style='pytorch',
                           with_cp=False,
                           conv_cfg=None,
                           norm_cfg=dict(type='BN'),
                           dcn=None,
                           plugins=None,
                           test_branch_idx=-1):
    """Build Trident Res Layers.

    Stacks ``num_blocks`` instances of ``block`` into an ``nn.Sequential``.
    Only the first block gets the stride and (if needed) the downsample
    shortcut; only the last block concatenates its branch outputs.
    """
    out_channels = planes * block.expansion

    # A 1x1 conv + norm shortcut is needed whenever the first block changes
    # spatial resolution or channel count.
    downsample = None
    if stride != 1 or inplanes != out_channels:
        downsample = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                inplanes,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False),
            build_norm_layer(norm_cfg, out_channels)[1])

    blocks = []
    for idx in range(num_blocks):
        is_first = idx == 0
        is_last = idx == num_blocks - 1
        blocks.append(
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride if is_first else 1,
                trident_dilations=trident_dilations,
                downsample=downsample if is_first else None,
                style=style,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=plugins,
                test_branch_idx=test_branch_idx,
                concat_output=is_last))
        inplanes = out_channels
    return nn.Sequential(*blocks)
|
|
|
|
@MODELS.register_module()
class TridentResNet(ResNet):
    """ResNet variant used by TridentNet.

    The stem and the first two stages are identical to :class:`ResNet`. The
    last (third) stage is rebuilt with :class:`TridentBottleneck`, whose
    branches share convolution weights but use different dilations to yield
    multi-scale output.

        / stage3(b0) \
    x - stem - stage1 - stage2 - stage3(b1) - output
        \ stage3(b2) /

    Args:
        depth (int): Depth of resnet, from {50, 101, 152}.
        num_branch (int): Number of branches in TridentNet.
        test_branch_idx (int): In inference, all 3 branches will be used
            if `test_branch_idx==-1`, otherwise only branch with index
            `test_branch_idx` will be used.
        trident_dilations (tuple[int]): Dilations of different trident branch.
            len(trident_dilations) should be equal to num_branch.
    """

    def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
                 **kwargs):
        assert num_branch == len(trident_dilations)
        assert depth in (50, 101, 152)
        super(TridentResNet, self).__init__(depth, **kwargs)
        assert self.num_stages == 3
        self.test_branch_idx = test_branch_idx
        self.num_branch = num_branch

        # Rebuild only the last stage with trident blocks.
        last_stage_idx = self.num_stages - 1
        last_stride = self.strides[last_stage_idx]
        last_dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
        stage_plugins = (
            self.make_stage_plugins(self.plugins, last_stage_idx)
            if self.plugins is not None else None)
        stage_planes = self.base_channels * 2**last_stage_idx
        # Input channels of the stage = output channels of the previous one.
        stage_inplanes = (
            self.block.expansion * self.base_channels *
            2**(last_stage_idx - 1))
        trident_layer = make_trident_res_layer(
            TridentBottleneck,
            inplanes=stage_inplanes,
            planes=stage_planes,
            num_blocks=self.stage_blocks[last_stage_idx],
            stride=last_stride,
            trident_dilations=trident_dilations,
            style=self.style,
            with_cp=self.with_cp,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=last_dcn,
            plugins=stage_plugins,
            test_branch_idx=self.test_branch_idx)

        # Register the rebuilt stage under the same attribute name so the
        # parent class forward logic picks it up instead of the original.
        layer_name = f'layer{last_stage_idx + 1}'
        setattr(self, layer_name, trident_layer)
        self.res_layers[last_stage_idx] = layer_name

        self._freeze_stages()
|
|