import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from dataclasses import dataclass, asdict
from einops import rearrange


class UpSample(nn.Module):
    """Upsampling block: a 1x1 conv doubles the channels, then
    PixelShuffle(2) trades channels for space.

    Shape: (B, C, H, W) -> (B, C/2, 2H, 2W).
    """

    def __init__(self, filters=64):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters * 2, kernel_size=1,
                              stride=1, padding=0, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        return x


class DownSample(nn.Module):
    """Downsampling block: a 1x1 conv halves the channels, then
    PixelUnshuffle(2) trades space for channels.

    Shape: (B, C, H, W) -> (B, 2C, H/2, W/2).
    NOTE(review): the original docstring claimed (B, C/4, H/2, W/2); the
    actual output has 2C channels (C/2 from the conv, then x4 from
    PixelUnshuffle).
    """

    def __init__(self, filters=64):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters // 2, kernel_size=1,
                              stride=1, padding=0, bias=True)
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_unshuffle(x)
        return x


# Custom LayerNormalization
class BiasFree_LayerNorm(nn.Module):
    """Bias-free layer normalization over the last dimension.

    Divides by the (biased) standard deviation only -- no mean
    subtraction, no additive bias -- then applies a learnable scale.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        x = x.contiguous()
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    """Standard layer normalization over the last dimension with a
    learnable scale and shift."""

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        x = x.contiguous()
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    """Layer normalization supporting two types: BiasFree and WithBias.

    With out_4d=True, a (B, C, H, W) input is flattened to (B, H*W, C),
    normalized over the channel dimension, and reshaped back to
    (B, C, H, W).
    """

    def __init__(self, dim, LayerNorm_type, out_4d=True):
        super().__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)
        self.out_4d = out_4d

    def to_3d(self, x):
        # Convert (B, C, H, W) to (B, H*W, C); 3D inputs pass through.
        if len(x.shape) == 3:
            return x
        elif len(x.shape) == 4:
            return rearrange(x, 'b c h w -> b (h w) c')
        else:
            raise ValueError("Input must be a 3D or 4D tensor")

    def to_4d(self, x, h, w):
        # Convert (B, H*W, C) to (B, C, H, W); 4D inputs pass through.
        if len(x.shape) == 4:
            return x
        elif len(x.shape) == 3:
            return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        else:
            raise ValueError("Input must be a 3D or 4D tensor")

    def forward(self, x):
        if self.out_4d:
            h, w = x.shape[-2:]
            return self.to_4d(self.body(self.to_3d(x)), h, w)
        else:
            # Normalizes over the last dimension of x as-is.
            return self.body(x)


class RepConv3(nn.Module):
    """Re-parameterizable 3x3 convolution.

    Training graph sums five parallel branches (3x3, 1x1, 1x3, 3x1 and a
    bias-free sequential 1x1 -> 3x3 branch); fuse() folds them all into
    the single equivalent `reparam` 3x3 conv for deployment.
    """

    def __init__(self, in_channels, out_channels, groups, deploy=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.deploy = deploy
        # The fused conv used after re-parameterization.
        self.reparam = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                                 padding=1, groups=groups)
        if not deploy:
            self.conv_3x3 = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=3, padding=1, groups=groups)
            self.conv_1x1 = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=1, groups=groups)
            self.conv_1x3 = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=(1, 3), padding=(0, 1),
                                      groups=groups)
            self.conv_3x1 = nn.Conv2d(in_channels, out_channels,
                                      kernel_size=(3, 1), padding=(1, 0),
                                      groups=groups)
            # Sequential branch: bias-free 1x1 followed by bias-free 3x3.
            self.conv_1x1_branch = nn.Conv2d(in_channels, in_channels,
                                             kernel_size=1, groups=groups,
                                             bias=False)
            self.conv_3x3_branch = nn.Conv2d(in_channels, out_channels,
                                             kernel_size=3, padding=1,
                                             groups=groups, bias=False)
        else:
            self._delete_branches()

    def _delete_branches(self):
        # Drop the training-time branches once fused into `reparam`.
        for name in ['conv_3x3', 'conv_1x1', 'conv_1x3', 'conv_3x1',
                     'conv_1x1_branch', 'conv_3x3_branch']:
            if hasattr(self, name):
                delattr(self, name)

    def fuse(self, delete_branches=True):
        """Fold all training branches into the single `reparam` 3x3 conv."""
        if self.deploy:
            return
        # Extract weights and biases
        conv_3x3_w, conv_3x3_b = self.conv_3x3.weight, self.conv_3x3.bias
        conv_1x1_w, conv_1x1_b = self.conv_1x1.weight, self.conv_1x1.bias
        conv_1x3_w, conv_1x3_b = self.conv_1x3.weight, self.conv_1x3.bias
        conv_3x1_w, conv_3x1_b = self.conv_3x1.weight, self.conv_3x1.bias
        conv_1x1_branch_w, conv_3x3_branch_w = (self.conv_1x1_branch.weight,
                                                self.conv_3x3_branch.weight)

        # Pad the smaller kernels to 3x3 (F.pad order: [w_l, w_r, h_t, h_b]).
        conv_1x1_w_pad = F.pad(conv_1x1_w, [1, 1, 1, 1])
        conv_1x3_w_pad = F.pad(conv_1x3_w, [0, 0, 1, 1])
        conv_3x1_w_pad = F.pad(conv_3x1_w, [1, 1, 0, 0])

        # Merge the sequential 1x1 -> 3x3 branch into one 3x3 kernel by
        # convolving the 3x3 weights with the transposed 1x1 weights.
        if self.groups == 1:
            conv_1x1_3x3_w_pad = F.conv2d(
                conv_3x3_branch_w, conv_1x1_branch_w.permute(1, 0, 2, 3))
        else:
            # Grouped case: fuse each group separately, then concatenate
            # along the output-channel dimension.
            w_slices = []
            conv_1x1_branch_w_T = conv_1x1_branch_w.permute(1, 0, 2, 3)
            in_channels_per_group = self.in_channels // self.groups
            out_channels_per_group = self.out_channels // self.groups
            for g in range(self.groups):
                # Slice the transposed 1x1 weights for this group's channels
                conv_1x1_branch_w_T_slice = conv_1x1_branch_w_T[
                    :, g * in_channels_per_group:(g + 1) * in_channels_per_group, :, :]
                # Slice the 3x3 weights for this group's output channels
                conv_3x3_branch_w_slice = conv_3x3_branch_w[
                    g * out_channels_per_group:(g + 1) * out_channels_per_group, :, :, :]
                w_slices.append(F.conv2d(conv_3x3_branch_w_slice,
                                         conv_1x1_branch_w_T_slice))
            conv_1x1_3x3_w_pad = torch.cat(w_slices, dim=0)

        # Fuse weights and biases (the sequential branch is bias-free).
        conv_w = (conv_3x3_w + conv_1x1_w_pad + conv_1x3_w_pad
                  + conv_3x1_w_pad + conv_1x1_3x3_w_pad)
        if conv_3x3_b is None:
            conv_3x3_b = torch.zeros(self.out_channels, device=conv_w.device)
        conv_b = conv_3x3_b + conv_1x1_b + conv_1x3_b + conv_3x1_b

        self.reparam.weight.data.copy_(conv_w)
        self.reparam.bias.data.copy_(conv_b)

        # Delete the original branches
        if delete_branches:
            self._delete_branches()
        # Set deploy flag
        self.deploy = True

    def forward(self, x):
        if self.deploy:
            return self.reparam(x)
        else:
            return (self.conv_3x3(x) + self.conv_1x1(x) + self.conv_1x3(x)
                    + self.conv_3x1(x)
                    + self.conv_3x3_branch(self.conv_1x1_branch(x)))


# NOTE(review): project-local import kept at its original mid-file position
# to avoid changing import-time side-effect ordering.
from monarch_attn import MonarchAttention


@dataclass
class RepAttnConfig:
    """Configuration for RepAttn (expanded via asdict into RepAttn.__init__)."""
    dim: int
    num_heads: int = 8
    block_size: int = 16
    num_steps: int = 2
    pad_type: str = "pre"
    impl: str = "torch"
    deploy: bool = False


class RepAttn(nn.Module):
    """Re-parameterizable Attention Block using MonarchAttention as the
    core attention mechanism.

    Trains with dense softmax attention (common_attn) and switches to the
    MonarchAttention module after fuse() or when constructed with
    deploy=True.
    """

    def __init__(self, dim, num_heads=8, block_size=14, num_steps=1,
                 pad_type="pre", impl="torch", deploy=False):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.monarch_attn = MonarchAttention(
            block_size=block_size,
            num_steps=num_steps,
            pad_type=pad_type,
            impl=impl
        )
        if deploy:
            self.attn_fn = self.monarch_attn
        else:
            self.attn_fn = self.common_attn
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.deploy = deploy

    def common_attn(self, q, k, v):
        """Scaled Dot-Product Attention."""
        scale = (q.shape[-1]) ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        return out

    @torch.no_grad()
    def fuse(self):
        # Swap in the MonarchAttention path for deployment.
        if not self.deploy:
            self.attn_fn = self.monarch_attn
            self.deploy = True

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=1)
        # Heads carry channel groups; attention operates with channels as
        # the token axis and (h w) as the feature axis (transposed
        # channel attention).
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        attn_out = self.attn_fn(q, k, v)
        attn_out = rearrange(attn_out, 'b head c (h w) -> b (head c) h w',
                             head=self.num_heads, h=H, w=W)
        out = self.proj(attn_out)
        return out


@dataclass
class FFNConfig:
    """Configuration for RepFFN (expanded via asdict into RepFFN.__init__)."""
    dim: int
    expansion_factor: int = 1
    deploy: bool = False


class RepFFN(nn.Module):
    """Gated feed-forward network built from re-parameterizable convs.

    project_in -> depthwise RepConv3 emitting two chunks -> GELU gating ->
    1x1 projection back to `dim` channels.
    """

    def __init__(self, dim, expansion_factor=1, deploy=False):
        super().__init__()
        hidden_features = int(dim * expansion_factor)
        self.project_in = RepConv3(dim, hidden_features, groups=1,
                                   deploy=deploy)
        # Depthwise: groups == hidden_features divides both in and out
        # channel counts.
        self.dwconv = RepConv3(hidden_features, hidden_features * 2,
                               groups=hidden_features, deploy=deploy)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1)

    @torch.no_grad()
    def fuse(self):
        self.project_in.fuse()
        self.dwconv.fuse()

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


class SkipConnection(nn.Module):
    """Concatenate two feature maps and project back to `dim` channels."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim * 2, dim, kernel_size=1)

    def forward(self, x1, x2):
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


class RepTransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(norm(x)), then x + ffn(norm(x))."""

    def __init__(self, rep_attn_cfg: RepAttnConfig, ffn_cfg: FFNConfig,
                 norm_type='WithBias'):
        super().__init__()
        self.rep_attn = RepAttn(**asdict(rep_attn_cfg))
        self.rep_ffn = RepFFN(**asdict(ffn_cfg))
        self.norm1 = LayerNorm(rep_attn_cfg.dim, norm_type)
        self.norm2 = LayerNorm(rep_attn_cfg.dim, norm_type)

    @torch.no_grad()
    def fuse(self):
        self.rep_attn.fuse()
        self.rep_ffn.fuse()

    def forward(self, x):
        x = x + self.rep_attn(self.norm1(x))
        x = x + self.rep_ffn(self.norm2(x))
        return x


class Block(nn.Module):
    """A stack of `num_block` RepTransformerBlocks applied sequentially."""

    def __init__(self, num_block, rep_attn_cfg: RepAttnConfig,
                 ffn_cfg: FFNConfig, norm_type='WithBias'):
        super().__init__()
        self.num_block = num_block
        self.blocks = nn.ModuleList([
            RepTransformerBlock(rep_attn_cfg, ffn_cfg, norm_type)
            for _ in range(num_block)
        ])

    @torch.no_grad()
    def fuse(self):
        for block in self.blocks:
            block.fuse()

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class ColorComicNet(nn.Module):
    """Main model implementation: a U-Net of RepTransformer blocks.

    NOTE(review): the Up/DownSample channel arithmetic assumes
    dims[i+1] == 2 * dims[i] (holds for the defaults); the final residual
    `head(x) + res` assumes output_channels == input_shape[0]. Confirm
    both for custom configurations.
    """

    def __init__(self, input_shape=(3, 1024, 1024), output_channels=3,
                 deploy=False, dims=[48, 96, 192, 384],
                 num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 2, 4],
                 bias=True, last_act=None):
        super().__init__()
        assert len(dims) == len(num_blocks) == len(num_heads), \
            "Length of dims, num_blocks and num_heads must be the same"
        self.input_shape = input_shape
        self.output_channels = output_channels
        self.deploy = deploy
        self.dims = dims
        self.num_blocks = num_blocks
        self.bias = bias
        self.num_heads = num_heads

        # Extractor: stride-4 stem, undone by the x4 interpolation in forward.
        self.stem = nn.Conv2d(input_shape[0], dims[0], kernel_size=7,
                              stride=4, padding=3, bias=bias)

        # Encoder
        layers = []
        down_convs = []
        for idx in range(len(dims)):
            attn_cfg, ffn_cfg = self.build_cfg(dims[idx], num_heads[idx])
            block = Block(num_blocks[idx], attn_cfg, ffn_cfg,
                          norm_type='WithBias')
            if idx < len(dims) - 1:
                down_convs.append(DownSample(dims[idx]))
            layers.append(block)
        self.bottleneck = layers[-1]  # Last encoder layer as bottleneck
        self.encoder = nn.ModuleList(layers[:-1])
        self.downsample = nn.ModuleList(down_convs)

        # Decoder (deepest-first: mirrors the encoder order)
        layers = []
        up_convs = []
        skip_connections = []
        for idx in range(len(dims) - 2, -1, -1):
            attn_cfg, ffn_cfg = self.build_cfg(dims[idx], num_heads[idx])
            up_conv = UpSample(dims[idx + 1])
            block = Block(num_blocks[idx], attn_cfg, ffn_cfg,
                          norm_type='WithBias')
            layers.append(block)
            up_convs.append(up_conv)
            skip_connections.append(SkipConnection(dims[idx]))
        self.decoder = nn.ModuleList(layers)
        self.up_sample = nn.ModuleList(up_convs)
        self.skip = nn.ModuleList(skip_connections)

        # Head
        self.head = nn.Sequential(
            RepConv3(dims[0], dims[0] // 2, 1, deploy=deploy),
            nn.GELU(),
            nn.Conv2d(dims[0] // 2, output_channels, kernel_size=1, bias=bias),
        )
        self.last_act = last_act if last_act is not None else nn.Identity()

    @torch.no_grad()
    def fuse(self):
        """Re-parameterize every fusable submodule for deployment."""
        for block in self.encoder:
            block.fuse()
        self.bottleneck.fuse()
        for block in self.decoder:
            block.fuse()
        for conv in self.head:
            if isinstance(conv, RepConv3):
                conv.fuse()

    def build_cfg(self, dim, head):
        """Build the attention and FFN configs for one stage."""
        attn_cfg = RepAttnConfig(
            dim=dim,
            num_heads=head,
            block_size=12,
            num_steps=2,
            pad_type="pre",
            impl="torch",
            deploy=self.deploy
        )
        ffn_cfg = FFNConfig(
            dim=dim,
            expansion_factor=1,
            # FIX: deploy was previously left at its default (False), so in
            # deploy mode the FFN's RepConv3s still built training branches,
            # inconsistent with RepAttnConfig above.
            deploy=self.deploy,
        )
        return attn_cfg, ffn_cfg

    def forward(self, x):
        """x: (B, C, H, W) -> (B, output_channels, H, W)."""
        res = x
        x = self.stem(x)
        feats = []
        for blk, down in zip(self.encoder, self.downsample):
            x = blk(x)
            feats.append(x)
            x = down(x)
        x = self.bottleneck(x)
        for blk, up, skip in zip(self.decoder, self.up_sample, self.skip):
            x = up(x)
            cur_feat = feats.pop()  # matching encoder stage, deepest first
            x = skip(x, cur_feat)
            x = blk(x)
        # Undo the stride-4 stem, then global residual over the input.
        x = F.interpolate(x, scale_factor=4, mode='bilinear')
        x = self.head(x) + res
        x = self.last_act(x)
        return x


# Example model configuration
MODEL_CFG = {
    'input_shape': (3, 512, 512),
    'dims': [24, 48, 96, 192],
    'num_blocks': [1, 2, 2, 4],
    'num_heads': [1, 2, 4, 8],
    'bias': True,
    'last_act': nn.Tanh(),
    'deploy': False
}