# BaseChange/models/changeformer.py
# Vedant Jigarbhai Mehta
# Initial scaffolding for military base change detection project
# b25c087
"""ChangeFormer — Transformer-based change detection model.
Implements a hierarchical vision transformer (MiT-B1 style) with shared-weight
Siamese encoder and MLP decoder for change detection. Based on:
"A Transformer-Based Siamese Network for Change Detection" (arXiv:2201.01293).
"""
import warnings
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class OverlapPatchEmbed(nn.Module):
    """Project a feature map into a normalised token sequence.

    A single strided convolution produces the embedding; its padding is
    ``patch_size // 2``, so when the kernel is larger than the stride
    adjacent patches overlap.

    Args:
        in_channels: Channel count of the incoming tensor.
        embed_dim: Channel count of the emitted tokens.
        patch_size: Kernel size of the projection convolution.
        stride: Stride of the projection convolution.
    """

    def __init__(
        self,
        in_channels: int = 3,
        embed_dim: int = 64,
        patch_size: int = 7,
        stride: int = 4,
    ) -> None:
        super().__init__()
        half_kernel = patch_size // 2
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=half_kernel,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Embed ``x`` and flatten its spatial grid into tokens.

        Args:
            x: Input tensor [B, C, H, W].

        Returns:
            Tuple of (layer-normalised tokens [B, N, D], grid height, grid
            width), where the grid is the convolution output and
            ``N = height * width``.
        """
        feature_map = self.proj(x)
        _, _, grid_h, grid_w = feature_map.shape
        # Row-major flattening of the grid; channels become the token dimension.
        tokens = rearrange(feature_map, "b c h w -> b (h w) c")
        return self.norm(tokens), grid_h, grid_w
class EfficientSelfAttention(nn.Module):
    """Multi-head self-attention with optional spatial reduction of keys/values.

    Queries are computed from every input token. When ``sr_ratio > 1`` the
    tokens are reshaped to a 2-D grid, downsampled by a convolution whose
    kernel size and stride both equal ``sr_ratio``, layer-normalised, and only
    then projected to keys and values, which shortens the key/value sequence.

    Args:
        dim: Input (and output) token dimension.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        sr_ratio: Spatial reduction ratio (``1`` disables the reduction).

    Raises:
        ValueError: If ``num_heads`` is not a positive divisor of ``dim`` or
            ``sr_ratio`` is smaller than 1.
    """

    def __init__(self, dim: int, num_heads: int = 1, sr_ratio: int = 8) -> None:
        super().__init__()
        # Fail fast: an uneven split would silently truncate head_dim and only
        # surface later as an opaque reshape error inside forward().
        if num_heads < 1 or dim % num_heads != 0:
            raise ValueError(
                f"num_heads must be a positive divisor of dim, got dim={dim}, num_heads={num_heads}"
            )
        if sr_ratio < 1:
            raise ValueError(f"sr_ratio must be >= 1, got {sr_ratio}")

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)  # keys and values share one projection
        self.proj = nn.Linear(dim, dim)

        # Spatial reduction (only instantiated when it is actually used)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.sr_norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Apply attention over the token sequence.

        Args:
            x: Input tokens [B, N, C] with ``N == h * w``.
            h: Feature map height.
            w: Feature map width.

        Returns:
            Output tokens [B, N, C].
        """
        b, n, c = x.shape
        # Queries always use the full-resolution sequence: [B, heads, N, head_dim].
        q = self.q(x).reshape(b, n, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # Downsample on the 2-D grid, then flatten back to a (shorter) sequence.
            kv_tokens = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            kv_tokens = self.sr(kv_tokens)
            kv_tokens = rearrange(kv_tokens, "b c h w -> b (h w) c")
            kv_tokens = self.sr_norm(kv_tokens)
        else:
            kv_tokens = x

        # [B, M, 2*C] -> [2, B, heads, M, head_dim]; index 0 is keys, 1 is values.
        kv = (
            self.kv(kv_tokens)
            .reshape(b, -1, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Merge the heads back into the channel dimension.
        out = (attn @ v).transpose(1, 2).reshape(b, n, c)
        return self.proj(out)
class MixFFN(nn.Module):
    """Feed-forward block with a 3x3 depthwise convolution between its two linears.

    Args:
        dim: Token dimension on the way in and out.
        mlp_ratio: Multiplier giving the hidden width (``dim * mlp_ratio``).
    """

    def __init__(self, dim: int, mlp_ratio: int = 4) -> None:
        super().__init__()
        expanded = dim * mlp_ratio
        self.fc1 = nn.Linear(dim, expanded)
        # groups == channels makes the convolution depthwise.
        self.dwconv = nn.Conv2d(
            expanded,
            expanded,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=expanded,
        )
        self.fc2 = nn.Linear(expanded, dim)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Expand, mix spatially, activate and project back.

        Args:
            x: Input tokens [B, N, C].
            h: Feature map height.
            w: Feature map width.

        Returns:
            Output tokens [B, N, C].
        """
        widened = self.fc1(x)
        # The depthwise convolution operates on the 2-D grid, not the sequence.
        grid = rearrange(widened, "b (h w) c -> b c h w", h=h, w=w)
        grid = self.dwconv(grid)
        grid = self.act(grid)
        sequence = rearrange(grid, "b c h w -> b (h w) c")
        return self.fc2(sequence)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MixFFN, each with a residual.

    Args:
        dim: Token dimension.
        num_heads: Attention head count.
        mlp_ratio: Hidden-width multiplier of the feed-forward branch.
        sr_ratio: Spatial reduction ratio handed to the attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 1,
        mlp_ratio: int = 4,
        sr_ratio: int = 8,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAttention(dim, num_heads, sr_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = MixFFN(dim, mlp_ratio)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Run both residual branches in sequence.

        Args:
            x: Input tokens [B, N, C].
            h: Feature map height.
            w: Feature map width.

        Returns:
            Output tokens [B, N, C].
        """
        # Each branch normalises its own input; the skip path stays un-normalised.
        attended = self.attn(self.norm1(x), h, w)
        x = x + attended
        mixed = self.ffn(self.norm2(x), h, w)
        return x + mixed
class MiTEncoder(nn.Module):
    """Mix Transformer (MiT) encoder — hierarchical vision transformer.

    Every stage applies an overlapping patch embedding, a stack of transformer
    blocks and a final LayerNorm, then reshapes the tokens back to a feature
    map that is both emitted and fed to the next stage.

    Args:
        embed_dims: Embedding dimension of each stage; its length sets the
            number of stages.
        num_heads: Number of attention heads at each stage.
        mlp_ratios: MLP expansion ratio at each stage.
        depths: Number of transformer blocks at each stage.
        in_channels: Channel count of the input image.
        sr_ratios: Attention spatial reduction ratio at each stage.
        patch_sizes: Patch-embedding kernel size at each stage.
        strides: Patch-embedding stride at each stage.

    Raises:
        ValueError: If ``embed_dims`` is empty or any per-stage sequence has
            fewer entries than ``embed_dims``.

    NOTE(review): the ``(8, 8, 4, 4)`` default for ``mlp_ratios`` appears to
    differ from the ``(4, 4, 4, 4)`` usually quoted for MiT-B1 — confirm
    before loading pretrained MiT weights into this encoder.
    """

    def __init__(
        self,
        embed_dims: Sequence[int] = (64, 128, 320, 512),
        num_heads: Sequence[int] = (1, 2, 5, 8),
        mlp_ratios: Sequence[int] = (8, 8, 4, 4),
        depths: Sequence[int] = (2, 2, 2, 2),
        in_channels: int = 3,
        sr_ratios: Sequence[int] = (8, 4, 2, 1),
        patch_sizes: Sequence[int] = (7, 3, 3, 3),
        strides: Sequence[int] = (4, 2, 2, 2),
    ) -> None:
        super().__init__()
        self.num_stages = len(embed_dims)
        if self.num_stages == 0:
            raise ValueError("embed_dims must contain at least one stage")

        # Validate up front so a short list fails with a clear message instead
        # of an IndexError half-way through module construction. Longer
        # sequences are tolerated; only the first ``num_stages`` entries are read.
        per_stage = {
            "num_heads": num_heads,
            "mlp_ratios": mlp_ratios,
            "depths": depths,
            "sr_ratios": sr_ratios,
            "patch_sizes": patch_sizes,
            "strides": strides,
        }
        for name, values in per_stage.items():
            if len(values) < self.num_stages:
                raise ValueError(
                    f"{name} needs at least {self.num_stages} entries "
                    f"(one per stage), got {len(values)}"
                )

        self.patch_embeds = nn.ModuleList()
        self.blocks = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(self.num_stages):
            # Stage 0 consumes the image; later stages consume the previous stage's map.
            in_ch = in_channels if i == 0 else embed_dims[i - 1]
            self.patch_embeds.append(
                OverlapPatchEmbed(in_ch, embed_dims[i], patch_sizes[i], strides[i])
            )
            self.blocks.append(
                nn.ModuleList([
                    TransformerBlock(embed_dims[i], num_heads[i], mlp_ratios[i], sr_ratios[i])
                    for _ in range(depths[i])
                ])
            )
            self.norms.append(nn.LayerNorm(embed_dims[i]))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Extract hierarchical features.

        Args:
            x: Input image [B, in_channels, H, W].

        Returns:
            List of feature maps, one per stage, each [B, C_i, H_i, W_i].
        """
        features = []
        for embed, stage_blocks, norm in zip(self.patch_embeds, self.blocks, self.norms):
            x, h, w = embed(x)
            for blk in stage_blocks:
                x = blk(x, h, w)
            x = norm(x)
            # Back to a spatial map: it is both a stage output and the next stage's input.
            x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            features.append(x)
        return features
class MLPDecoder(nn.Module):
    """MLP-based decoder that fuses multi-scale difference features.

    Each scale is projected to a common width (``embed_dims[0]``) with a 1x1
    convolution, bilinearly resized to the target size, concatenated along the
    channel axis, fused by a 1x1 convolution + BatchNorm + ReLU, and finally
    mapped to ``out_channels`` logits.

    Args:
        embed_dims: Embedding dimensions from each encoder stage.
        out_channels: Number of output channels (1 for binary change mask).

    Raises:
        ValueError: If ``embed_dims`` is empty.
    """

    def __init__(
        self,
        embed_dims: Sequence[int] = (64, 128, 320, 512),
        out_channels: int = 1,
    ) -> None:
        super().__init__()
        if len(embed_dims) == 0:
            raise ValueError("embed_dims must contain at least one stage")
        unified_dim = embed_dims[0]
        self.linear_projections = nn.ModuleList([
            nn.Conv2d(dim, unified_dim, kernel_size=1)
            for dim in embed_dims
        ])
        self.fuse = nn.Sequential(
            nn.Conv2d(unified_dim * len(embed_dims), unified_dim, kernel_size=1),
            nn.BatchNorm2d(unified_dim),
            nn.ReLU(inplace=True),
        )
        self.head = nn.Conv2d(unified_dim, out_channels, kernel_size=1)

    def forward(self, features: List[torch.Tensor], target_size: Tuple[int, int]) -> torch.Tensor:
        """Fuse the per-scale features into full-resolution logits.

        Args:
            features: List of difference feature maps, one per encoder stage,
                in the same order as ``embed_dims``.
            target_size: (H, W) of the desired output.

        Returns:
            Logits [B, out_channels, H, W].

        Raises:
            ValueError: If the number of feature maps does not match the
                number of projection layers.
        """
        # zip() would silently drop surplus inputs, so check the count explicitly.
        if len(features) != len(self.linear_projections):
            raise ValueError(
                f"expected {len(self.linear_projections)} feature maps, got {len(features)}"
            )
        projected = [
            F.interpolate(proj(feat), size=target_size, mode="bilinear", align_corners=False)
            for feat, proj in zip(features, self.linear_projections)
        ]
        fused = self.fuse(torch.cat(projected, dim=1))
        return self.head(fused)
class ChangeFormer(nn.Module):
    """ChangeFormer: Transformer-based Siamese network for change detection.

    Both images pass through the same encoder instance (shared weights); the
    absolute difference of their features at every scale is decoded into a
    single-channel change-logit map.

    Args:
        embed_dims: Embedding dims at each hierarchical stage.
        num_heads: Attention heads at each stage.
        mlp_ratios: MLP expansion ratios at each stage.
        depths: Transformer block counts at each stage.
        pretrained_backbone: Whether to load pretrained MiT weights. Loading
            is not implemented yet; passing ``True`` emits a warning and the
            encoder keeps its random initialisation.
    """

    def __init__(
        self,
        embed_dims: Sequence[int] = (64, 128, 320, 512),
        num_heads: Sequence[int] = (1, 2, 5, 8),
        mlp_ratios: Sequence[int] = (8, 8, 4, 4),
        depths: Sequence[int] = (2, 2, 2, 2),
        pretrained_backbone: bool = True,
    ) -> None:
        super().__init__()
        # Shared Siamese encoder
        self.encoder = MiTEncoder(embed_dims, num_heads, mlp_ratios, depths)
        # MLP decoder
        self.decoder = MLPDecoder(embed_dims, out_channels=1)
        if pretrained_backbone:
            # TODO: Load pretrained MiT-B1 weights if pretrained_backbone is True
            # Until then, make the gap visible instead of silently ignoring the flag.
            warnings.warn(
                "pretrained_backbone=True was requested, but loading pretrained MiT "
                "weights is not implemented yet; the encoder is randomly initialised.",
                stacklevel=2,
            )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x1: Before image [B, 3, H, W].
            x2: After image, same shape as ``x1``.

        Returns:
            Raw logits [B, 1, H, W], at the spatial size of ``x1``.

        Raises:
            ValueError: If ``x1`` and ``x2`` do not have identical shapes.
        """
        # Mismatched inputs would otherwise broadcast silently or fail deep
        # inside the feature subtraction with an unhelpful message.
        if x1.shape != x2.shape:
            raise ValueError(
                f"x1 and x2 must have the same shape, got {tuple(x1.shape)} and {tuple(x2.shape)}"
            )
        # Extract hierarchical features (same encoder => shared weights)
        feats_1 = self.encoder(x1)
        feats_2 = self.encoder(x2)
        # Compute difference at each scale
        diff_feats = [torch.abs(f1 - f2) for f1, f2 in zip(feats_1, feats_2)]
        # Decode to change mask at the input resolution
        target_size = (x1.shape[2], x1.shape[3])
        return self.decoder(diff_feats, target_size)
if __name__ == "__main__":
    # Quick sanity check: one forward pass on random inputs, verifying only the
    # output shape and reporting the parameter count.
    model = ChangeFormer(pretrained_backbone=False)
    # Inference mode: do not touch BatchNorm running statistics during a smoke test.
    model.eval()
    x1 = torch.randn(1, 3, 256, 256)
    x2 = torch.randn(1, 3, 256, 256)
    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        out = model(x1, x2)
    print(f"Input: {x1.shape}, Output: {out.shape}")
    assert out.shape == (1, 1, 256, 256), f"Unexpected shape: {out.shape}"
    print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")