| import torch |
| import torch.nn as nn |
| from typing import Optional, Callable, Union, Tuple, Any |
| import torch |
| from torch import nn, Tensor |
| import numpy as np |
| from typing import Optional |
| import math |
| from torch import nn |
|
|
def makeDivisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Round ``v`` to a multiple of ``divisor``.

    ``v`` is rounded to the nearest multiple of ``divisor`` (halves round up),
    clamped from below by ``min_value`` (``divisor`` when omitted). If that
    rounding lost more than 10% of ``v``, one extra ``divisor`` is added.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = (int(v + divisor / 2) // divisor) * divisor
    candidate = max(lower_bound, nearest_multiple)
    # Never let rounding shrink the value by more than 10%.
    return candidate + divisor if candidate < 0.9 * v else candidate
def callMethod(self, ElementName):
    """Return the attribute of ``self`` named ``ElementName``.

    This is a plain ``getattr`` lookup; nothing is invoked. A missing name
    raises ``AttributeError``.
    """
    found = getattr(self, ElementName)
    return found
def setMethod(self, ElementName, ElementValue):
    """Assign ``ElementValue`` to the attribute ``ElementName`` of ``self``.

    Delegates to ``setattr``, so any custom ``__setattr__`` (e.g. the one
    ``nn.Module`` uses to register submodules) is honoured. Returns ``None``.
    """
    outcome = setattr(self, ElementName, ElementValue)
    return outcome
def shuffleTensor(Feature: Union[Tensor, list], Mode: int = 1) -> list:
    """Randomly permute the spatial positions of one or more 4-D tensors.

    One permutation is drawn when the first tensor is processed and is reused
    for every following tensor, so all inputs are shuffled identically. The
    inputs therefore need compatible spatial sizes: equal ``H * W`` for
    ``Mode == 1``, equal ``H`` and ``W`` otherwise.

    Args:
        Feature: a tensor of shape (B, C, H, W), or a list/tuple of such tensors.
        Mode: ``1`` permutes all ``H * W`` positions jointly. Any other value
            permutes rows (H) and columns (W) independently.

    Returns:
        A list with one shuffled tensor per input, each with its input's shape.
        A bare tensor input still yields a one-element list.
    """
    if isinstance(Feature, Tensor):
        Feature = [Feature]
    Indexs = None
    Output = []
    for f in Feature:
        B, C, H, W = f.shape
        if Mode == 1:
            f = f.flatten(2)
            if Indexs is None:
                Indexs = torch.randperm(f.shape[-1], device=f.device)
            # .to() lets later features live on a different device than the first.
            f = f[:, :, Indexs.to(f.device)]
            f = f.reshape(B, C, H, W)
        else:
            if Indexs is None:
                Indexs = [torch.randperm(H, device=f.device),
                          torch.randperm(W, device=f.device)]
            f = f[:, :, Indexs[0].to(f.device)]
            f = f[:, :, :, Indexs[1].to(f.device)]
        Output.append(f)
    return Output
class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """``nn.AdaptiveAvgPool2d`` with a ``profileModule`` hook reporting zero cost."""

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]] = 1):
        # `int or tuple` (the previous annotation) evaluates to just `int`;
        # a Union states the accepted forms correctly.
        super().__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        """Run ``forward`` on ``Input`` and return ``(output, 0.0, 0.0)``.

        The two 0.0 values are fixed placeholders. They are presumably
        parameter/MAC counts for an external profiler; confirm with the caller.
        """
        Output = self.forward(Input)
        return Output, 0.0, 0.0
|
|
class AdaptiveMaxPool2d(nn.AdaptiveMaxPool2d):
    """``nn.AdaptiveMaxPool2d`` with a ``profileModule`` hook reporting zero cost."""

    def __init__(self, output_size: Union[int, Tuple[Optional[int], Optional[int]]] = 1):
        # `int or tuple` (the previous annotation) evaluates to just `int`;
        # a Union states the accepted forms correctly.
        super().__init__(output_size=output_size)

    def profileModule(self, Input: Tensor):
        """Run ``forward`` on ``Input`` and return ``(output, 0.0, 0.0)``.

        The two 0.0 values are fixed placeholders. They are presumably
        parameter/MAC counts for an external profiler; confirm with the caller.
        """
        Output = self.forward(Input)
        return Output, 0.0, 0.0
class BaseConv2d(nn.Module):
    """``nn.Conv2d`` followed by an optional ``BatchNorm2d`` and an optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, groups, dilation:
            forwarded to ``nn.Conv2d``.
        padding: defaults to ``(kernel_size - 1) // 2 * dilation``, which keeps
            H and W unchanged for odd kernels at stride 1.
        bias: defaults to ``not BNorm`` (the conv bias is dropped when a
            BatchNorm follows).
        BNorm: if True, append ``BatchNorm2d(out_channels, eps=0.001,
            momentum=Momentum)``; otherwise an ``nn.Identity``.
        ActLayer: activation factory (e.g. ``nn.ReLU``). It is built with
            ``inplace=True`` when the factory accepts that keyword and is not
            ``nn.Sigmoid``; otherwise it is built with no arguments. ``None``
            means no activation.
        Momentum: BatchNorm momentum.
        **kwargs: extra keyword arguments forwarded to ``nn.Conv2d``.

    All submodules are initialized with ``initWeight`` via ``self.apply``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: Optional[int] = 1,
        padding: Optional[int] = None,
        groups: Optional[int] = 1,
        bias: Optional[bool] = None,
        BNorm: bool = False,
        ActLayer: Optional[Callable[..., nn.Module]] = None,
        dilation: int = 1,
        Momentum: Optional[float] = 0.1,
        **kwargs: Any
    ) -> None:
        super().__init__()
        if padding is None:
            padding = int((kernel_size - 1) // 2 * dilation)

        if bias is None:
            bias = not BNorm

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.bias = bias

        self.Conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, dilation, groups, bias, **kwargs)

        self.Bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=Momentum) if BNorm else nn.Identity()

        self.Act = self._buildActivation(ActLayer)

        self.apply(initWeight)

    @staticmethod
    def _buildActivation(ActLayer: Optional[Callable[..., nn.Module]]) -> Optional[nn.Module]:
        """Instantiate ``ActLayer``, preferring ``inplace=True`` when it is supported."""
        if ActLayer is None:
            return None
        probe = ActLayer()
        if isinstance(probe, nn.Sigmoid):
            return probe
        try:
            return ActLayer(inplace=True)
        except TypeError:
            # Activations such as nn.GELU, nn.Tanh or nn.PReLU take no `inplace`
            # keyword; previously this path crashed, so fall back to a plain instance.
            return probe

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv, then BatchNorm (or identity), then the activation if any."""
        x = self.Conv(x)
        x = self.Bn(x)
        if self.Act is not None:
            x = self.Act(x)
        return x
|
|
# Normalization layer types whose affine parameters initWeight resets to
# weight = 1 and bias = 0 (when those parameters exist).
NormLayerTuple = (
    nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm,
    nn.LayerNorm,
    nn.InstanceNorm1d, nn.InstanceNorm2d,
    nn.GroupNorm,
    nn.BatchNorm3d,
)


def initWeight(Module):
    """Initialize ``Module``'s parameters in place.

    - Conv2d / Conv3d / ConvTranspose2d and Linear: Kaiming-uniform weights
      (``a=sqrt(5)``). The bias is drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.
      For convs the bias is left untouched when ``fan_in == 0``; for Linear
      the bound collapses to 0.
    - Types in ``NormLayerTuple``: weight set to ones, bias set to zeros,
      where present.
    - ``nn.Sequential`` / ``nn.ModuleList``: recurse over their items.
    - Any other module with children: recurse over ``children()``.
    - ``None`` and leaf modules of other types are ignored.

    NOTE: when passed to ``nn.Module.apply`` (which already visits every
    submodule), containers' children end up initialized more than once.
    """
    if Module is None:
        return

    if isinstance(Module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(Module.weight)[0]
        if fan_in == 0:
            return
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(Module.bias, -limit, limit)
        return

    if isinstance(Module, NormLayerTuple):
        if Module.weight is not None:
            nn.init.ones_(Module.weight)
        if Module.bias is not None:
            nn.init.zeros_(Module.bias)
        return

    if isinstance(Module, nn.Linear):
        nn.init.kaiming_uniform_(Module.weight, a=math.sqrt(5))
        if Module.bias is not None:
            fan_in = nn.init._calculate_fan_in_and_fan_out(Module.weight)[0]
            limit = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(Module.bias, -limit, limit)
        return

    # Containers are iterated directly; every other module recurses via children().
    if isinstance(Module, (nn.Sequential, nn.ModuleList)):
        members = Module
    else:
        members = list(Module.children())
    for member in members:
        initWeight(member)
class Attention(nn.Module):
    """Channel attention driven by a stochastically pooled descriptor.

    In training mode a pooling resolution is drawn at random from ``PoolRes``.
    The input is optionally spatially shuffled first (``MoCOrder``), then
    average-pooled to that resolution, and one random bin of the pooled map
    becomes the per-channel descriptor. In eval mode plain global average
    pooling (``Pool1``) is used. The descriptor passes through a two-layer
    1x1-conv bottleneck (``SELayer``) whose output multiplies ``x``.

    Args:
        InChannels: number of input (and output) channels.
        HidChannels: bottleneck width. When ``None`` it becomes
            ``max(makeDivisible(InChannels // SqueezeFactor, 8), 32)``.
        SqueezeFactor: channel reduction, used only when ``HidChannels`` is ``None``.
        PoolRes: candidate pooling resolutions sampled during training.
            Defaults to ``[1, 2, 3]``. The sequence is copied, so the module
            never aliases the caller's list or a shared default.
        Act: activation factory for the first bottleneck conv.
        ScaleAct: activation factory for the second (gating) conv.
        MoCOrder: if True, shuffle spatial positions before pooling in training.
        **kwargs: accepted for compatibility and ignored.
    """

    def __init__(
        self,
        InChannels: int,
        HidChannels: Optional[int] = None,
        SqueezeFactor: int = 4,
        PoolRes: Optional[list] = None,
        Act: Callable[..., nn.Module] = nn.ReLU,
        ScaleAct: Callable[..., nn.Module] = nn.Sigmoid,
        MoCOrder: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if HidChannels is None:
            HidChannels = max(makeDivisible(InChannels // SqueezeFactor, 8), 32)

        # A list literal as the default would be shared by every instance through
        # self.PoolRes; build a fresh copy per instance instead.
        PoolRes = [1, 2, 3] if PoolRes is None else list(PoolRes)

        # Pool1 is always registered because eval mode relies on it.
        AllPoolRes = PoolRes + [1] if 1 not in PoolRes else PoolRes
        for k in AllPoolRes:
            Pooling = AdaptiveAvgPool2d(k)
            setMethod(self, 'Pool%d' % k, Pooling)

        self.SELayer = nn.Sequential(
            BaseConv2d(InChannels, HidChannels, 1, ActLayer=Act),
            BaseConv2d(HidChannels, InChannels, 1, ActLayer=ScaleAct),
        )

        self.PoolRes = PoolRes
        self.MoCOrder = MoCOrder

    def RandomSample(self, x: Tensor) -> Tensor:
        """Return a pooled channel descriptor of ``x``.

        Assumes ``x`` is (B, C, H, W); TODO confirm for callers other than GLFA.
        The result is 4-D with 1x1 spatial size. In training, a single bin index
        is drawn and shared across the whole batch and all channels.
        """
        if self.training:
            PoolKeep = np.random.choice(self.PoolRes)
            x1 = shuffleTensor(x)[0] if self.MoCOrder else x
            AttnMap: Tensor = callMethod(self, 'Pool%d' % PoolKeep)(x1)
            if AttnMap.shape[-1] > 1:
                # Keep one randomly chosen spatial bin, then restore the 1x1 dims.
                AttnMap = AttnMap.flatten(2)
                AttnMap = AttnMap[:, :, torch.randperm(AttnMap.shape[-1])[0]]
                AttnMap = AttnMap[:, :, None, None]
        else:
            AttnMap: Tensor = callMethod(self, 'Pool%d' % 1)(x)

        return AttnMap

    def forward(self, x: Tensor) -> Tensor:
        """Rescale ``x`` by the bottleneck output of its sampled descriptor."""
        AttnMap = self.RandomSample(x)
        return x * self.SELayer(AttnMap)
|
|
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` (ShuffleNet-style shuffle).

    Args:
        x: 4-D tensor (B, C, H, W).
        groups: number of groups. It must divide C, otherwise ``view`` raises
            a RuntimeError.

    Returns:
        A tensor of the same shape in which channel ``g * (C // groups) + j``
        moves to position ``j * groups + g``.
    """
    # Read the shape from the tensor itself; `.data` is not needed just to
    # inspect sizes and is discouraged because it bypasses autograd tracking.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(batchsize, -1, height, width)
class GLFA(nn.Module):
    """Multi-dilation 3x3 conv block with channel attention and a residual connection.

    Four parallel Conv3x3-BN-ReLU branches use dilations 1-4 with padding equal
    to the dilation, so H and W are preserved. Their outputs are concatenated,
    channel-shuffled in 4 groups, and fused back to ``in_channels`` by a 1x1
    Conv-BN-ReLU. The result is reweighted by ``Attention`` and added to the
    input. The output has the same number of channels as the input.
    """

    def __init__(self, in_channels):
        super(GLFA, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        # Attribute names conv_1..conv_4 are kept so state_dict keys are unchanged.
        self.conv_1 = self._dilated_branch(in_channels, 1)
        self.conv_2 = self._dilated_branch(in_channels, 2)
        self.conv_3 = self._dilated_branch(in_channels, 3)
        self.conv_4 = self._dilated_branch(in_channels, 4)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_channels * 4, in_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.mca = Attention(InChannels=in_channels, HidChannels=16)

    @staticmethod
    def _dilated_branch(channels, dilation):
        """Build Conv3x3(dilation, padding=dilation) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, padding=dilation, kernel_size=3, dilation=dilation),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        identity = x
        branches = [self.conv_1(x), self.conv_2(x), self.conv_3(x), self.conv_4(x)]
        cat = torch.cat(branches, dim=1)
        # Mix channels across the four dilation branches before the 1x1 fuse.
        cat = channel_shuffle(cat, groups=4)
        fused = self.fuse(cat)
        return self.mca(fused) + identity
|
|