"""ChangeFormer — Transformer-based change detection model. Implements a hierarchical vision transformer (MiT-B1 style) with shared-weight Siamese encoder and MLP decoder for change detection. Based on: "A Transformer-Based Siamese Network for Change Detection" (arXiv:2201.01293). """ from typing import List, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange class OverlapPatchEmbed(nn.Module): """Overlapping patch embedding for hierarchical feature extraction. Args: in_channels: Number of input channels. embed_dim: Embedding dimension. patch_size: Patch size for convolution. stride: Stride for convolution. """ def __init__( self, in_channels: int = 3, embed_dim: int = 64, patch_size: int = 7, stride: int = 4, ) -> None: super().__init__() self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=patch_size // 2, ) self.norm = nn.LayerNorm(embed_dim) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: """Forward pass. Args: x: Input tensor [B, C, H, W]. Returns: Tuple of (tokens [B, N, D], height, width). """ x = self.proj(x) _, _, h, w = x.shape x = rearrange(x, "b c h w -> b (h w) c") x = self.norm(x) return x, h, w class EfficientSelfAttention(nn.Module): """Efficient self-attention with spatial reduction. Args: dim: Input dimension. num_heads: Number of attention heads. sr_ratio: Spatial reduction ratio. """ def __init__(self, dim: int, num_heads: int = 1, sr_ratio: int = 8) -> None: super().__init__() self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.q = nn.Linear(dim, dim) self.kv = nn.Linear(dim, dim * 2) self.proj = nn.Linear(dim, dim) # Spatial reduction self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.sr_norm = nn.LayerNorm(dim) def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: """Forward pass. Args: x: Input tokens [B, N, C]. h: Feature map height. w: Feature map width. Returns: Output tokens [B, N, C]. """ b, n, c = x.shape q = self.q(x).reshape(b, n, self.num_heads, self.head_dim).permute(0, 2, 1, 3) if self.sr_ratio > 1: x_ = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) x_ = self.sr(x_) x_ = rearrange(x_, "b c h w -> b (h w) c") x_ = self.sr_norm(x_) else: x_ = x kv = self.kv(x_).reshape(b, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) out = (attn @ v).transpose(1, 2).reshape(b, n, c) out = self.proj(out) return out class MixFFN(nn.Module): """Mix Feed-Forward Network with depthwise convolution. Args: dim: Input/output dimension. mlp_ratio: Expansion ratio for hidden dimension. """ def __init__(self, dim: int, mlp_ratio: int = 4) -> None: super().__init__() hidden = dim * mlp_ratio self.fc1 = nn.Linear(dim, hidden) self.dwconv = nn.Conv2d(hidden, hidden, 3, 1, 1, groups=hidden) self.fc2 = nn.Linear(hidden, dim) self.act = nn.GELU() def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: """Forward pass. Args: x: Input tokens [B, N, C]. h: Feature map height. w: Feature map width. Returns: Output tokens [B, N, C]. """ x = self.fc1(x) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) x = self.act(self.dwconv(x)) x = rearrange(x, "b c h w -> b (h w) c") x = self.fc2(x) return x class TransformerBlock(nn.Module): """Single transformer block with efficient attention and MixFFN. Args: dim: Feature dimension. num_heads: Number of attention heads. mlp_ratio: MLP expansion ratio. sr_ratio: Spatial reduction ratio for attention. """ def __init__( self, dim: int, num_heads: int = 1, mlp_ratio: int = 4, sr_ratio: int = 8, ) -> None: super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = EfficientSelfAttention(dim, num_heads, sr_ratio) self.norm2 = nn.LayerNorm(dim) self.ffn = MixFFN(dim, mlp_ratio) def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: """Forward pass. Args: x: Input tokens [B, N, C]. h: Feature map height. w: Feature map width. Returns: Output tokens [B, N, C]. """ x = x + self.attn(self.norm1(x), h, w) x = x + self.ffn(self.norm2(x), h, w) return x class MiTEncoder(nn.Module): """Mix Transformer (MiT) encoder — hierarchical vision transformer. Args: embed_dims: Embedding dimensions at each stage. num_heads: Number of attention heads at each stage. mlp_ratios: MLP expansion ratios at each stage. depths: Number of transformer blocks at each stage. """ def __init__( self, embed_dims: List[int] = [64, 128, 320, 512], num_heads: List[int] = [1, 2, 5, 8], mlp_ratios: List[int] = [8, 8, 4, 4], depths: List[int] = [2, 2, 2, 2], ) -> None: super().__init__() self.num_stages = len(embed_dims) sr_ratios = [8, 4, 2, 1] patch_sizes = [7, 3, 3, 3] strides = [4, 2, 2, 2] self.patch_embeds = nn.ModuleList() self.blocks = nn.ModuleList() self.norms = nn.ModuleList() for i in range(self.num_stages): in_ch = 3 if i == 0 else embed_dims[i - 1] self.patch_embeds.append( OverlapPatchEmbed(in_ch, embed_dims[i], patch_sizes[i], strides[i]) ) self.blocks.append( nn.ModuleList([ TransformerBlock(embed_dims[i], num_heads[i], mlp_ratios[i], sr_ratios[i]) for _ in range(depths[i]) ]) ) self.norms.append(nn.LayerNorm(embed_dims[i])) def forward(self, x: torch.Tensor) -> List[torch.Tensor]: """Extract hierarchical features. Args: x: Input image [B, 3, H, W]. Returns: List of feature maps at each stage [B, C_i, H_i, W_i]. """ features = [] for i in range(self.num_stages): x, h, w = self.patch_embeds[i](x) for blk in self.blocks[i]: x = blk(x, h, w) x = self.norms[i](x) x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) features.append(x) return features class MLPDecoder(nn.Module): """MLP-based decoder that fuses multi-scale difference features. Args: embed_dims: Embedding dimensions from each encoder stage. out_channels: Number of output channels (1 for binary change mask). """ def __init__( self, embed_dims: List[int] = [64, 128, 320, 512], out_channels: int = 1, ) -> None: super().__init__() unified_dim = embed_dims[0] self.linear_projections = nn.ModuleList([ nn.Conv2d(dim, unified_dim, kernel_size=1) for dim in embed_dims ]) self.fuse = nn.Sequential( nn.Conv2d(unified_dim * len(embed_dims), unified_dim, kernel_size=1), nn.BatchNorm2d(unified_dim), nn.ReLU(inplace=True), ) self.head = nn.Conv2d(unified_dim, out_channels, kernel_size=1) def forward(self, features: List[torch.Tensor], target_size: Tuple[int, int]) -> torch.Tensor: """Forward pass. Args: features: List of difference feature maps. target_size: (H, W) of the desired output. Returns: Logits [B, 1, H, W]. """ projected = [] for i, (feat, proj) in enumerate(zip(features, self.linear_projections)): p = proj(feat) p = F.interpolate(p, size=target_size, mode="bilinear", align_corners=False) projected.append(p) fused = self.fuse(torch.cat(projected, dim=1)) out = self.head(fused) return out class ChangeFormer(nn.Module): """ChangeFormer: Transformer-based Siamese network for change detection. Args: embed_dims: Embedding dims at each hierarchical stage. num_heads: Attention heads at each stage. mlp_ratios: MLP expansion ratios at each stage. depths: Transformer block counts at each stage. pretrained_backbone: Whether to load pretrained MiT weights. """ def __init__( self, embed_dims: List[int] = [64, 128, 320, 512], num_heads: List[int] = [1, 2, 5, 8], mlp_ratios: List[int] = [8, 8, 4, 4], depths: List[int] = [2, 2, 2, 2], pretrained_backbone: bool = True, ) -> None: super().__init__() # Shared Siamese encoder self.encoder = MiTEncoder(embed_dims, num_heads, mlp_ratios, depths) # MLP decoder self.decoder = MLPDecoder(embed_dims, out_channels=1) # TODO: Load pretrained MiT-B1 weights if pretrained_backbone is True def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: """Forward pass. Args: x1: Before image [B, 3, 256, 256]. x2: After image [B, 3, 256, 256]. Returns: Raw logits [B, 1, 256, 256]. """ # Extract hierarchical features feats_1 = self.encoder(x1) feats_2 = self.encoder(x2) # Compute difference at each scale diff_feats = [torch.abs(f1 - f2) for f1, f2 in zip(feats_1, feats_2)] # Decode to change mask target_size = (x1.shape[2], x1.shape[3]) out = self.decoder(diff_feats, target_size) return out if __name__ == "__main__": # Quick sanity check model = ChangeFormer(pretrained_backbone=False) x1 = torch.randn(1, 3, 256, 256) x2 = torch.randn(1, 3, 256, 256) out = model(x1, x2) print(f"Input: {x1.shape}, Output: {out.shape}") assert out.shape == (1, 1, 256, 256), f"Unexpected shape: {out.shape}" print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")