| import torch |
| import torch.nn as nn |
|
|
| from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer |
| from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt |
| from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP |
| from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine |
|
|
|
|
class DDColor(nn.Module):
    """DDColor colorization network: ConvNeXt image encoder + dual decoder + 1x1 refine head.

    The encoder is hooked at ``norm0``..``norm3``. A single dummy forward pass at
    construction time fills those hooks so the decoder can size its layers from
    the captured feature shapes.

    Args:
        encoder_name: ConvNeXt variant passed to ``ImageEncoder``.
        decoder_name: decoder variant passed to ``DuelDecoder``.
        num_input_channels: channels of the dummy probe input and of ``x`` in
            ``forward`` (the input is concatenated into the refine head).
        input_size: spatial size ``(H, W)`` of the dummy probe input.
        nf: pixel-decoder width passed to ``DuelDecoder``.
        num_output_channels: output channels of the refine head.
        last_norm: ``NormType`` member name used by the decoder's last upsampler.
        do_normalize: if true, de-normalize the output with ``mean``/``std``.
        num_queries: number of color queries; also the channel count of the
            decoder output fed to the refine head.
        num_scales: number of feature scales used by the color decoder.
        dec_layers: number of transformer decoder layers.
    """

    def __init__(
        self,
        encoder_name='convnext-l',
        decoder_name='MultiScaleColorDecoder',
        num_input_channels=3,
        input_size=(256, 256),
        nf=512,
        num_output_channels=3,
        last_norm='Weight',
        do_normalize=False,
        num_queries=256,
        num_scales=3,
        dec_layers=9,
    ):
        super().__init__()

        self.encoder = ImageEncoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'])
        self.encoder.eval()
        # Probe pass: only the hooked feature shapes are needed to build the
        # decoder, so skip building an autograd graph for it.
        test_input = torch.randn(1, num_input_channels, *input_size)
        with torch.no_grad():
            self.encoder(test_input)

        self.decoder = DuelDecoder(
            self.encoder.hooks,
            nf=nf,
            last_norm=last_norm,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
            decoder_name=decoder_name,
        )

        # forward() concatenates the decoder's num_queries maps with the input x,
        # so the refine head sees num_queries + num_input_channels channels.
        self.refine_net = nn.Sequential(
            custom_conv_layer(
                num_queries + num_input_channels,
                num_output_channels,
                ks=1,
                use_activ=False,
                norm_type=NormType.Spectral,
            )
        )

        self.do_normalize = do_normalize
        # Commonly used ImageNet per-channel mean/std, shaped for (1, 3, 1, 1) broadcasting.
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, img):
        """Return ``(img - mean) / std`` using the registered buffers."""
        return (img - self.mean) / self.std

    def denormalize(self, img):
        """Invert ``normalize``: return ``img * std + mean``."""
        return img * self.std + self.mean

    def forward(self, x):
        """Colorize ``x``.

        Inputs with exactly 3 channels are always normalized first, regardless of
        ``do_normalize``. The (possibly normalized) input is concatenated with the
        decoder output before the refine head. The output is de-normalized only
        when ``do_normalize`` is set.

        Args:
            x: input batch, presumably (B, C, H, W) -- TODO confirm.

        Returns:
            The refine-head output (presumably ``num_output_channels`` channels).
        """
        if x.shape[1] == 3:
            x = self.normalize(x)

        # The encoder call fills the hooks that the decoder reads from.
        self.encoder(x)
        out_feat = self.decoder()
        coarse_input = torch.cat([out_feat, x], dim=1)
        out = self.refine_net(coarse_input)

        if self.do_normalize:
            out = self.denormalize(out)
        return out
|
|
|
|
class ImageEncoder(nn.Module):
    """ConvNeXt backbone whose selected child modules are tapped with forward hooks.

    Args:
        encoder_name: ``'convnext-t'`` or ``'convnext-l'``.
        hook_names: names of direct children of the ConvNeXt module to hook
            (looked up in ``self.arch._modules``); hooks are kept in this order.

    Raises:
        NotImplementedError: if ``encoder_name`` is not a supported variant.
    """

    # (stage depths, stage channel widths) for each supported ConvNeXt variant.
    _ARCH_CONFIGS = {
        'convnext-t': ((3, 3, 9, 3), (96, 192, 384, 768)),
        'convnext-l': ((3, 3, 27, 3), (192, 384, 768, 1536)),
    }

    def __init__(self, encoder_name, hook_names):
        super().__init__()

        # Explicit check rather than `assert`, so validation survives `python -O`.
        if encoder_name not in self._ARCH_CONFIGS:
            raise NotImplementedError(
                f'Unsupported encoder_name {encoder_name!r}; '
                f'expected one of {sorted(self._ARCH_CONFIGS)}'
            )
        depths, dims = self._ARCH_CONFIGS[encoder_name]
        # Fresh lists per instance so the shared config table is never aliased.
        self.arch = ConvNeXt(depths=list(depths), dims=list(dims))

        self.encoder_name = encoder_name
        self.hook_names = hook_names
        self.hooks = self.setup_hooks()

    def setup_hooks(self):
        """Attach a ``Hook`` to each named child of ``self.arch``, in ``hook_names`` order."""
        return [Hook(self.arch._modules[name]) for name in self.hook_names]

    def forward(self, x):
        """Run the backbone and return its output.

        The hooks from ``setup_hooks`` presumably capture the hooked modules'
        outputs as a side effect (consumers read ``hook.feature``).
        """
        return self.arch(x)
|
|
|
|
class DuelDecoder(nn.Module):
    """Dual decoder: a UNet-style pixel decoder feeding a query-based color decoder.

    Starting from the last hook's feature, one ``UnetBlockWide`` is applied per
    remaining hook. ``last_shuf`` is a pixel-shuffle layer with ``scale=4``. The
    per-block outputs are the multi-scale inputs of ``MultiScaleColorDecoder``,
    and the ``last_shuf`` output is its image feature map.

    Args:
        hooks: encoder hooks exposing ``.feature``. Their features must already be
            populated at construction time, because ``make_layers`` reads their shapes.
        nf: pixel-decoder width. The last block and ``last_shuf`` use ``nf // 2``.
        blur: forwarded to the pixel-shuffle upsamplers.
        last_norm: name of the ``NormType`` member used by ``last_shuf``.
        num_queries: forwarded to ``MultiScaleColorDecoder``.
        num_scales: forwarded to ``MultiScaleColorDecoder``. It must match the
            number of blocks built by ``make_layers``.
        dec_layers: forwarded to ``MultiScaleColorDecoder``.
        decoder_name: only ``'MultiScaleColorDecoder'`` is supported.

    Raises:
        NotImplementedError: if ``decoder_name`` is not supported.
    """

    def __init__(
        self,
        hooks,
        nf=512,
        blur=True,
        last_norm='Weight',
        num_queries=256,
        num_scales=3,
        dec_layers=9,
        decoder_name='MultiScaleColorDecoder',
    ):
        super().__init__()
        # Validate before building any layers; explicit raise survives `python -O`.
        if decoder_name != 'MultiScaleColorDecoder':
            raise NotImplementedError(f'Unsupported decoder_name {decoder_name!r}')

        self.hooks = hooks
        self.nf = nf
        self.blur = blur
        self.last_norm = getattr(NormType, last_norm)
        self.decoder_name = decoder_name

        self.layers = self.make_layers()
        embed_dim = nf // 2
        self.last_shuf = CustomPixelShuffle_ICNR(
            embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)

        # make_layers emits nf channels from every block except the last (nf // 2);
        # with the default nf=512 and three blocks this is [512, 512, 256].
        in_channels = [nf] * (len(self.layers) - 1) + [nf // 2]
        self.color_decoder = MultiScaleColorDecoder(
            in_channels=in_channels,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
            # The color embeddings are dotted with last_shuf's embed_dim channels.
            color_embed_dim=embed_dim,
        )

    def make_layers(self):
        """Build one ``UnetBlockWide`` per hook except the last, visiting hooks in reverse.

        Each block outputs ``self.nf`` channels, except the final block, which
        outputs ``self.nf // 2``.
        """
        blocks = []
        in_c = self.hooks[-1].feature.shape[1]
        out_c = self.nf

        skip_hooks = self.hooks[-2::-1]  # all hooks but the last, in reverse order
        for idx, hook in enumerate(skip_hooks):
            feature_c = hook.feature.shape[1]
            if idx == len(skip_hooks) - 1:
                out_c = out_c // 2
            blocks.append(
                UnetBlockWide(
                    in_c, feature_c, out_c, hook,
                    blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
            in_c = out_c

        return nn.Sequential(*blocks)

    def forward(self):
        """Decode the features currently held by the hooks.

        Returns:
            ``MultiScaleColorDecoder`` output for the per-block features and the
            ``last_shuf`` image feature.
        """
        feat = self.hooks[-1].feature
        multi_scale = []
        for layer in self.layers:
            feat = layer(feat)
            multi_scale.append(feat)
        img_feat = self.last_shuf(feat)

        return self.color_decoder(multi_scale, img_feat)
|
|
|
|
class MultiScaleColorDecoder(nn.Module):
    """Transformer decoder that turns multi-scale features into per-query maps.

    A set of learned queries is refined over ``dec_layers`` steps. Each step
    cross-attends to one feature scale (cycling through the scales), then
    self-attends, then applies an FFN. The normalized queries go through an MLP,
    and the result is dotted channel-wise with ``img_features`` (einsum
    ``bqc,bchw->bqhw``), giving one map per query.

    Args:
        in_channels: channel count of each input scale, one entry per scale.
        hidden_dim: transformer width; every scale is projected to this size.
        num_queries: number of learned queries.
        nheads: attention heads per attention layer.
        dim_feedforward: hidden size of each FFN layer.
        dec_layers: number of (cross-attn, self-attn, FFN) steps.
        pre_norm: forwarded as ``normalize_before`` to every transformer layer.
        color_embed_dim: MLP output width. It must equal the channel count of
            ``img_features`` passed to ``forward``.
        enforce_input_project: always use a 1x1 conv projection, even when
            ``in_ch == hidden_dim``.
        num_scales: number of feature scales ``forward`` expects.
    """

    def __init__(
        self,
        in_channels,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dim_feedforward=2048,
        dec_layers=9,
        pre_norm=False,
        color_embed_dim=256,
        enforce_input_project=True,
        num_scales=3,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_queries = num_queries
        self.num_layers = dec_layers
        self.num_feature_levels = num_scales

        # Sine positional encoding (hidden_dim // 2 is presumably the per-axis
        # feature count, making the encoding hidden_dim wide -- confirm).
        self.pe_layer = PositionEmbeddingSine(hidden_dim // 2, normalize=True)

        # Learned query content and learned query positional embeddings.
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # One learned vector per scale, added to that scale's projected features.
        self.level_embed = nn.Embedding(num_scales, hidden_dim)

        projections = [
            self._make_input_proj(channels, hidden_dim, enforce_input_project)
            for channels in in_channels
        ]
        self.input_proj = nn.ModuleList(projections)

        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        # Built interleaved (self-attn, cross-attn, FFN per step); this order
        # determines the order in which parameters are created and initialised.
        attn_kwargs = dict(
            d_model=hidden_dim,
            nhead=nheads,
            dropout=0.0,
            normalize_before=pre_norm,
        )
        for _ in range(dec_layers):
            self.transformer_self_attention_layers.append(SelfAttentionLayer(**attn_kwargs))
            self.transformer_cross_attention_layers.append(CrossAttentionLayer(**attn_kwargs))
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)

    def forward(self, x, img_features):
        """Decode ``x`` (one feature map per scale) against ``img_features``.

        Args:
            x: sequence of ``num_scales`` feature maps, presumably (B, C_i, H_i, W_i).
            img_features: tensor indexed ``bchw`` by the final einsum, whose
                ``c`` must equal ``color_embed_dim``.

        Returns:
            ``(B, num_queries, H, W)`` tensor from einsum ``bqc,bchw->bqhw``.
        """
        assert len(x) == self.num_feature_levels

        src, pos = self._get_src_and_pos(x)
        batch_size = src[0].shape[1]  # src entries are laid out (HW, B, C)

        # Broadcast the learned (num_queries, hidden_dim) tables over the batch.
        query_pos = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
        queries = self.query_feat.weight.unsqueeze(1).repeat(1, batch_size, 1)

        for step in range(self.num_layers):
            level = step % self.num_feature_levels
            cross_attn = self.transformer_cross_attention_layers[step]
            self_attn = self.transformer_self_attention_layers[step]
            ffn = self.transformer_ffn_layers[step]

            queries = cross_attn(
                queries,
                src[level],
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=pos[level],
                query_pos=query_pos,
            )
            queries = self_attn(
                queries,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_pos,
            )
            queries = ffn(queries)

        # (Q, B, C) -> (B, Q, C) before the MLP, matching the einsum's "bqc".
        normed = self.decoder_norm(queries).transpose(0, 1)
        color_embed = self.color_embed(normed)
        return torch.einsum("bqc,bchw->bqhw", color_embed, img_features)

    def _make_input_proj(self, in_ch, hidden_dim, enforce):
        """Return a 1x1 conv projecting ``in_ch`` to ``hidden_dim``, or an empty ``nn.Sequential``.

        The conv is used when ``enforce`` is set or the widths differ. Its weight
        gets Kaiming-uniform init (``a=1``) and its bias is zeroed.
        """
        if not (enforce or in_ch != hidden_dim):
            return nn.Sequential()

        conv = nn.Conv2d(in_ch, hidden_dim, kernel_size=1)
        nn.init.kaiming_uniform_(conv.weight, a=1)
        if conv.bias is not None:
            nn.init.constant_(conv.bias, 0)
        return conv

    def _get_src_and_pos(self, x):
        """Flatten each scale into sequence form.

        Returns:
            ``(src, pos)`` lists with one ``(HW, B, C)`` tensor per scale.
            ``src`` holds the projected features plus the scale's level embedding,
            and ``pos`` holds the positional encoding.
        """
        src, pos = [], []
        for level, feat in enumerate(x):
            pos.append(self.pe_layer(feat).flatten(2).permute(2, 0, 1))
            level_bias = self.level_embed.weight[level][None, :, None]
            projected = self.input_proj[level](feat).flatten(2) + level_bias
            src.append(projected.permute(2, 0, 1))
        return src, pos
|
|