| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from models.modules import ConvNextBlock, Decoder, LayerNorm, SimDecoder, UpSampleConvnext |
| import torch.distributed as dist |
| from models.revcol_function import ReverseFunction |
| from timm.models.layers import trunc_normal_ |
|
|
class Fusion(nn.Module):
    """Combine the two inputs of one level into a single feature map.

    Sub-modules:
        down: applied to the first input. For levels 1, 2 and 3 it is a
            kernel-2 / stride-2 convolution from ``channels[level - 1]`` to
            ``channels[level]`` followed by a channels-first ``LayerNorm``;
            for every other level it is an identity.
        up: only created when ``first_col`` is false. For levels 0, 1 and 2
            it is ``UpSampleConvnext(1, channels[level + 1], channels[level])``;
            otherwise an identity.

    Args:
        level: index of the level this fusion belongs to.
        channels: per-level channel counts, indexed by level.
        first_col: when true, no ``up`` branch is built and ``forward``
            returns the ``down`` branch alone.
    """

    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col

        if level in (1, 2, 3):
            self.down = nn.Sequential(
                nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
                LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
            )
        else:
            self.down = nn.Identity()

        # The first column never receives a second input, so it gets no `up`.
        if first_col:
            return

        if level in (0, 1, 2):
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level])
        else:
            self.up = nn.Identity()

    def forward(self, *args):
        """Fuse ``(c_down, c_up)``.

        Returns ``down(c_down)`` alone for the first column and for level 3
        (where ``c_up`` is ignored); otherwise ``up(c_up) + down(c_down)``.
        """
        c_down, c_up = args

        down_only = self.first_col or self.level == 3
        if down_only:
            return self.down(c_down)

        raised = self.up(c_up)
        lowered = self.down(c_down)
        return raised + lowered
|
|
class Level(nn.Module):
    """One level: a ``Fusion`` of the two inputs followed by ConvNext blocks.

    Args:
        level: index of this level; selects ``channels[level]`` and
            ``layers[level]``.
        channels: per-level channel counts.
        layers: per-level number of ``ConvNextBlock`` modules.
        kernel_size: forwarded to every ``ConvNextBlock``.
        first_col: forwarded to ``Fusion``.
        dp_rate: drop-path rate(s). Either a sequence holding one rate per
            block across all levels (this level reads the entries starting at
            ``sum(layers[:level])``), or a single number applied to every
            block of this level. The scalar form makes the ``0.0`` default
            usable; previously it failed because a float cannot be indexed.
    """

    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=0.0) -> None:
        super().__init__()
        # Offset of this level's first block inside a flat per-block rate list.
        first_block = sum(layers[:level])
        expansion = 4
        width = channels[level]
        uniform_rate = isinstance(dp_rate, (int, float))

        self.fusion = Fusion(level, channels, first_col)
        blocks = [
            ConvNextBlock(
                width,
                expansion * width,
                width,
                kernel_size=kernel_size,
                layer_scale_init_value=1e-6,
                drop_path=dp_rate if uniform_rate else dp_rate[first_block + i],
            )
            for i in range(layers[level])
        ]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, *args):
        """Fuse the inputs, then run the result through the block stack."""
        fused = self.fusion(*args)
        return self.blocks(fused)
|
|
class SubNet(nn.Module):
    """One column made of four ``Level`` modules with learnable shortcut scales.

    Each level ``k`` owns a parameter ``alpha{k}`` of shape
    ``(1, channels[k], 1, 1)`` initialised to 0.5. In the non-reversible path
    the new state of level ``k`` is ``alpha_k * c_k + level_k(...)``.

    Args:
        channels: per-level channel counts (indices 0..3 are used).
        layers: per-level block counts, forwarded to ``Level``.
        kernel_size: forwarded to ``Level``.
        first_col: forwarded to ``Level``.
        dp_rates: drop-path rates, forwarded to ``Level`` as ``dp_rate``.
        save_memory: when true ``forward`` goes through
            ``ReverseFunction.apply``; otherwise the levels are evaluated
            directly.
    """

    def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None:
        super().__init__()
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory

        def make_alpha(width):
            # Per-channel shortcut scale, broadcastable over (N, C, H, W).
            return nn.Parameter(
                shortcut_scale_init_value * torch.ones((1, width, 1, 1)),
                requires_grad=True,
            )

        # Attribute names and registration order are kept so existing
        # checkpoints keep loading.
        self.alpha0 = make_alpha(channels[0])
        self.alpha1 = make_alpha(channels[1])
        self.alpha2 = make_alpha(channels[2])
        self.alpha3 = make_alpha(channels[3])

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)
        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)

    def _forward_nonreverse(self, *args):
        """Evaluate the four levels in order and return ``(c0, c1, c2, c3)``.

        ``args`` is ``(x, c0, c1, c2, c3)``. Each level sees the freshly
        updated state of the level below it and the previous state of the
        level above it; level 3 has no level above and receives ``None``.
        """
        x, c0, c1, c2, c3 = args

        c0 = self.alpha0 * c0 + self.level0(x, c1)
        c1 = self.alpha1 * c1 + self.level1(c0, c2)
        c2 = self.alpha2 * c2 + self.level2(c1, c3)
        c3 = self.alpha3 * c3 + self.level3(c2, None)
        return c0, c1, c2, c3

    def _forward_reverse(self, *args):
        """Run the column through ``ReverseFunction.apply``.

        The first element returned by ``ReverseFunction.apply`` is discarded;
        the remaining four are returned as ``(c0, c1, c2, c3)``.
        """
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(local_funs, alpha, *args)
        return c0, c1, c2, c3

    def forward(self, *args):
        """Clamp every alpha away from zero, then dispatch on ``save_memory``."""
        for alpha in (self.alpha0, self.alpha1, self.alpha2, self.alpha3):
            self._clamp_abs(alpha.data, 1e-3)

        if self.save_memory:
            return self._forward_reverse(*args)
        return self._forward_nonreverse(*args)

    def _clamp_abs(self, data, value):
        """Force ``|data| >= value`` element-wise, in place, keeping the sign.

        Entries that are exactly zero have no sign; they are treated as
        positive so they end up at ``+value``. (Multiplying by ``sign()``
        alone would map them back to zero and defeat the clamp.)
        """
        with torch.no_grad():
            sign = data.sign()
            sign[sign == 0] = 1
            data.abs_().clamp_(value)
            data *= sign
|
|
|
|
class Classifier(nn.Module):
    """Classification head: global average pool, LayerNorm, then Linear.

    Args:
        in_channels: channel count of the incoming feature map; also the
            width of the LayerNorm and the input size of the Linear layer.
        num_classes: output size of the Linear layer.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        norm = nn.LayerNorm(in_channels, eps=1e-6)
        head = nn.Linear(in_channels, num_classes)
        self.classifier = nn.Sequential(norm, head)

    def forward(self, x):
        """Pool ``x`` to 1x1, flatten per sample, and return the logits."""
        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)
|
|
class FullNet(nn.Module):
    """Stem, a chain of ``SubNet`` columns, and classification/decoder heads.

    The stem is a kernel-4 / stride-4 convolution from 3 input channels to
    ``channels[0]`` followed by a channels-first ``LayerNorm``. Its output is
    fed to every ``SubNet`` together with the four states produced by the
    previous one (all four start as the integer 0).

    Args:
        channels: per-level channel counts (four entries are used). The
            sequence is copied; the caller's object is never modified.
        layers: per-level block counts. Copied as well.
        num_subnet: number of ``SubNet`` columns, registered as
            ``subnet0`` .. ``subnet{num_subnet - 1}``. Must be a positive
            multiple of 4 to run a forward pass with ``inter_supv=True``.
        kernel_size: forwarded to every ``SubNet``.
        num_classes: output size of the classifier head(s). With
            ``inter_supv`` it also selects the decoder type: ``Decoder`` when
            ``num_classes <= 1000``, ``SimDecoder`` otherwise.
        drop_path: upper end of the linearly spaced per-block drop-path rates
            (from 0 to ``drop_path`` over ``sum(layers)`` blocks).
        save_memory: forwarded to every ``SubNet``.
        inter_supv: when true, build four classifier heads (``cls_blocks``)
            and three decoders (``decoder_blocks``) and use the
            intermediate-supervision forward pass; when false, build a single
            head ``cls``.
        head_init_scale: if truthy, the weight and bias of the Linear layer of
            every classifier head are multiplied by this value after
            initialisation.
    """

    def __init__(self, channels=(32, 64, 96, 128), layers=(2, 3, 6, 3), num_subnet=5, kernel_size=3,
                 num_classes=1000, drop_path=0.0, save_memory=True, inter_supv=True,
                 head_init_scale=None) -> None:
        super().__init__()
        # Private copies: nothing below may leak changes back to the caller
        # (or to the shared default arguments).
        channels = list(channels)
        layers = list(layers)

        self.num_subnet = num_subnet
        self.inter_supv = inter_supv
        self.channels = channels
        self.layers = layers

        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first"),
        )

        # One drop-path rate per block, shared by all columns.
        dp_rate = torch.linspace(0, drop_path, sum(layers)).tolist()
        for i in range(num_subnet):
            self.add_module(
                f'subnet{i}',
                SubNet(channels, layers, kernel_size, i == 0, dp_rates=dp_rate, save_memory=save_memory),
            )

        if not inter_supv:
            self.cls = Classifier(in_channels=channels[-1], num_classes=num_classes)
        else:
            self.cls_blocks = nn.ModuleList(
                [Classifier(in_channels=channels[-1], num_classes=num_classes) for _ in range(4)]
            )
            if num_classes <= 1000:
                # The decoder takes the widths deepest-first; pass a reversed
                # copy instead of reversing `channels` in place.
                self.decoder_blocks = nn.ModuleList(
                    [
                        Decoder(depth=[1, 1, 1, 1], dim=channels[::-1], block_type=ConvNextBlock, kernel_size=3)
                        for _ in range(3)
                    ]
                )
            else:
                self.decoder_blocks = nn.ModuleList(
                    [SimDecoder(in_channel=channels[-1], encoder_stride=32) for _ in range(3)]
                )

        self.apply(self._init_weights)

        if head_init_scale:
            print(f'Head_init_scale: {head_init_scale}')
            # `self.cls` only exists without intermediate supervision; with
            # it, every head in `cls_blocks` is rescaled the same way.
            heads = self.cls_blocks if inter_supv else [self.cls]
            for head in heads:
                linear = head.classifier[1]
                linear.weight.data.mul_(head_init_scale)
                linear.bias.data.mul_(head_init_scale)

    def forward(self, x):
        """Run the network.

        Returns:
            With ``inter_supv``: the result of
            ``_forward_intermediate_supervision``. Otherwise a tuple
            ``([logits], None)`` where ``logits`` comes from ``self.cls``
            applied to the last column's level-3 state.
        """
        if self.inter_supv:
            return self._forward_intermediate_supervision(x)

        c0, c1, c2, c3 = 0, 0, 0, 0
        x = self.stem(x)
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
        return [self.cls(c3)], None

    def _forward_intermediate_supervision(self, x):
        """Forward pass that taps the level-3 state every ``num_subnet // 4`` columns.

        Returns:
            ``(x_cls_out, x_img_out)``: four classifier outputs and three
            decoder outputs (the last column gets a classifier output only).

        Raises:
            ValueError: if ``num_subnet`` is not a positive multiple of 4.
                There are exactly four classifier heads and three decoders, so
                any other count would index past them (or divide by zero).
        """
        if self.num_subnet < 4 or self.num_subnet % 4 != 0:
            raise ValueError(
                "intermediate supervision requires num_subnet to be a positive "
                f"multiple of 4, got {self.num_subnet}"
            )

        x_cls_out = []
        x_img_out = []
        c0, c1, c2, c3 = 0, 0, 0, 0
        interval = self.num_subnet // 4

        x = self.stem(x)
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
            if (i + 1) % interval == 0:
                tap = i // interval
                x_cls_out.append(self.cls_blocks[tap](c3))
                if i != self.num_subnet - 1:
                    x_img_out.append(self.decoder_blocks[tap](c3))

        return x_cls_out, x_img_out

    def _init_weights(self, module):
        """Initialise Conv2d/Linear weights (truncated normal, std 0.02) and zero their bias.

        Layers created with ``bias=False`` have ``bias is None`` and are left
        without a bias instead of raising.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
| |
| |
|
|
def revcol_tiny(save_memory, inter_supv=True, drop_path=0.1, num_classes=1000, kernel_size=3, head_init_scale=None):
    """Build a ``FullNet`` with channels [64, 128, 256, 512], layers [2, 2, 4, 2] and 4 sub-networks.

    ``head_init_scale`` (default ``None``) is forwarded to ``FullNet`` so this
    factory has the same signature as ``revcol_base`` / ``revcol_large`` /
    ``revcol_xlarge``. All other arguments are forwarded unchanged.
    """
    return FullNet(
        channels=[64, 128, 256, 512],
        layers=[2, 2, 4, 2],
        num_subnet=4,
        num_classes=num_classes,
        drop_path=drop_path,
        save_memory=save_memory,
        inter_supv=inter_supv,
        head_init_scale=head_init_scale,
        kernel_size=kernel_size,
    )
|
|
| |
|
|
def revcol_small(save_memory, inter_supv=True, drop_path=0.3, num_classes=1000, kernel_size=3, head_init_scale=None):
    """Build a ``FullNet`` with channels [64, 128, 256, 512], layers [2, 2, 4, 2] and 8 sub-networks.

    ``head_init_scale`` (default ``None``) is forwarded to ``FullNet`` so this
    factory has the same signature as ``revcol_base`` / ``revcol_large`` /
    ``revcol_xlarge``. All other arguments are forwarded unchanged.
    """
    return FullNet(
        channels=[64, 128, 256, 512],
        layers=[2, 2, 4, 2],
        num_subnet=8,
        num_classes=num_classes,
        drop_path=drop_path,
        save_memory=save_memory,
        inter_supv=inter_supv,
        head_init_scale=head_init_scale,
        kernel_size=kernel_size,
    )
|
|
| |
|
|
def revcol_base(save_memory, inter_supv=True, drop_path=0.4, num_classes=1000, kernel_size=3, head_init_scale=None):
    """Build a ``FullNet`` with channels [72, 144, 288, 576], layers [1, 1, 3, 2] and 16 sub-networks.

    All arguments are forwarded to ``FullNet`` unchanged.
    """
    config = {
        "channels": [72, 144, 288, 576],
        "layers": [1, 1, 3, 2],
        "num_subnet": 16,
    }
    return FullNet(
        num_classes=num_classes,
        drop_path=drop_path,
        save_memory=save_memory,
        inter_supv=inter_supv,
        head_init_scale=head_init_scale,
        kernel_size=kernel_size,
        **config,
    )
|
|
|
|
| |
|
|
def revcol_large(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    """Build a ``FullNet`` with channels [128, 256, 512, 1024], layers [1, 2, 6, 2] and 8 sub-networks.

    All arguments are forwarded to ``FullNet`` unchanged.
    """
    config = {
        "channels": [128, 256, 512, 1024],
        "layers": [1, 2, 6, 2],
        "num_subnet": 8,
    }
    return FullNet(
        num_classes=num_classes,
        drop_path=drop_path,
        save_memory=save_memory,
        inter_supv=inter_supv,
        head_init_scale=head_init_scale,
        kernel_size=kernel_size,
        **config,
    )
|
|
| |
def revcol_xlarge(save_memory, inter_supv=True, drop_path=0.5, num_classes=1000, kernel_size=3, head_init_scale=None):
    """Build a ``FullNet`` with channels [224, 448, 896, 1792], layers [1, 2, 6, 2] and 8 sub-networks.

    All arguments are forwarded to ``FullNet`` unchanged.
    """
    config = {
        "channels": [224, 448, 896, 1792],
        "layers": [1, 2, 6, 2],
        "num_subnet": 8,
    }
    return FullNet(
        num_classes=num_classes,
        drop_path=drop_path,
        save_memory=save_memory,
        inter_supv=inter_supv,
        head_init_scale=head_init_scale,
        kernel_size=kernel_size,
        **config,
    )
|
|
|
|