| |
| """Modified from https://github.com/MichaelFan01/STDC-Seg.""" |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from mmcv.cnn import ConvModule |
| from mmengine.model import BaseModule, ModuleList, Sequential |
|
|
| from mmseg.registry import MODELS |
| from ..utils import resize |
| from .bisenetv1 import AttentionRefinementModule |
|
|
|
|
class STDCModule(BaseModule):
    """STDCModule.

    A Short-Term Dense Concatenate module: a pointwise conv followed by a
    chain of 3x3 convs whose channel widths halve step by step; the
    intermediate outputs are fused either by concatenation ('cat') or by a
    residual addition ('add').

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the module. If 2, the spatial size is
            halved via a depthwise stride-2 conv (and a strided skip path).
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
            Default: None.
        num_convs (int): Numbers of conv layers, must be larger than 1.
            Default: 4.
        fusion_type (str): Type of fusion operation, 'add' or 'cat'.
            Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        # Only stride-2 modules downsample spatially.
        self.with_downsample = self.stride == 2
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            # Depthwise 3x3 stride-2 conv that halves H and W.
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                # Residual branch: depthwise stride-2 conv + pointwise conv
                # to match both the halved spatial size and out_channels.
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                # For 'cat' fusion the first branch output is average-pooled
                # so all concatenated features share one spatial size.
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)

        for i in range(1, num_convs):
            # Widths halve each step; the last conv keeps the previous
            # width so that all branch widths sum exactly to out_channels:
            # out/2 + out/4 + ... + out/2^(n-1) + out/2^(n-1) == out.
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        """Dispatch to the fusion-specific forward pass."""
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        """Concatenate all intermediate features, then add the skip path."""
        layer_outputs = []
        # No clone needed: the layers do not modify `inputs` in place
        # (each conv produces a fresh output tensor), so the original
        # tensor stays valid for the skip connection below.
        x = inputs
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        """Concatenate intermediate features (first one pooled if strided)."""
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                # The second conv consumes the (optionally downsampled)
                # output of the first pointwise conv.
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            # Replace the full-resolution x0 with its pooled counterpart so
            # every entry has the downsampled spatial size before concat.
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)
|
|
|
class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module is different from FeatureFusionModule
    in BiSeNetV1: its `self.attention` is a squeeze-and-excitation style pair
    of ConvModules whose bottleneck width is `out_channels // scale_factor`,
    whereas the BiSeNetV1 version uses a single ConvModule in
    `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Bottleneck width of the channel-attention branch.
        reduced_channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # Squeeze-and-excitation: global pool -> reduce -> expand -> sigmoid.
        squeeze = nn.AdaptiveAvgPool2d((1, 1))
        reduce_conv = ConvModule(
            out_channels,
            reduced_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=act_cfg)
        expand_conv = ConvModule(
            reduced_channels,
            out_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=None)
        self.attention = nn.Sequential(squeeze, reduce_conv, expand_conv,
                                       nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        """Fuse the two inputs and reweight channels with attention."""
        fused = self.conv0(
            torch.cat([spatial_inputs, context_inputs], dim=1))
        gate = self.attention(fused)
        # Attention-weighted features plus an identity residual.
        return fused * gate + fused
|
|
|
@MODELS.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (int): The type of backbone structure,
            `STDCNet1` and`STDCNet2` denotes two main backbones in paper,
            whose FLOPs is 813M and 1446M, respectively.
        in_channels (int): The num of input_channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module type, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layer at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether add a conv layer at the Module output.
            Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    # Per-stage stride patterns: each tuple is one stage, each entry the
    # stride of one STDCModule (the leading 2 downsamples the stage input).
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        # Two stem convs + three STDC stages each consume one entry.
        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        # Fixed attribute-name typo (was `self.prtrained`).
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        # Stem: two stride-2 3x3 convs (1/4 resolution after both).
        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.stages` contains stem convs and STDC stages; the stem
        # outputs are dropped from `forward` results (see below).
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            # Each new stage maps channels[idx] -> channels[idx + 1].
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # `with_final_conv` adds a 1x1 conv widening the last feature map
        # to at least 1024 channels (used for pretraining classifiers).
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        """Build one stage as a sequence of STDCModules, one per stride."""
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        """Run all stages and return only the deep (non-stem) feature maps."""
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        # Drop the stem conv outputs; only STDC stage features are returned.
        outs = outs[self.num_shallow_features:]
        return tuple(outs)
|
|
|
|
@MODELS.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three feature
    maps from deep to shallow, whose height and width is from small to big,
    respectively. The biggest feature map of `outs` is outputted for
    `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth.
    The other two feature maps are used for Attention Refinement Module,
    respectively. Besides, the biggest feature map of `outs` and the last
    output of Attention Refinement Module are concatenated for Feature Fusion
    Module. Then, this fusion feature map `feat_fuse` would be outputted for
    `decode_head`. More details please refer to Figure 4 of original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)), The number of channels of last
            two feature maps from stdc backbone. Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (str): align_corners argument of F.interpolate. It
            must be `None` if upsample_mode is ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): The tuple of list of output feature map for
            auxiliary heads and decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        # One ARM + refinement conv per deep feature map (deepest first).
        for arm_channels in last_in_channels:
            self.arms.append(
                AttentionRefinementModule(arm_channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        # Projects the global-average-pooled deepest feature map.
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        """Refine backbone features top-down and fuse with the shallow map."""
        outs = list(self.backbone(x))
        # Global context: pool the deepest map to 1x1, project, then
        # broadcast back up to the deepest map's spatial size.
        global_context = self.conv_avg(F.adaptive_avg_pool2d(outs[-1], 1))
        up_feat = resize(
            global_context,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)

        arms_out = []
        # Walk from the deepest map towards shallower ones: refine with the
        # ARM, add the running context, upsample to the next level, smooth.
        for level, (arm, conv) in enumerate(zip(self.arms, self.convs)):
            refined = arm(outs[-1 - level]) + up_feat
            up_feat = resize(
                refined,
                size=outs[-2 - level].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            up_feat = conv(up_feat)
            arms_out.append(up_feat)

        # Fuse the shallowest backbone map with the last refined feature.
        feat_fuse = self.ffm(outs[0], arms_out[1])

        # Outputs: shallow map (for STDCHead / Detail Loss), the two
        # ARM outputs (auxiliary heads), and the fused map (decode head).
        return tuple([outs[0], *arms_out, feat_fuse])
|
|