"""WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer).

Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder.
Checkpoint: Track B multiclass, ulcer Dice = 0.927 (Bootstrap CI: [0.917, 0.936]).
"""

import warnings
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import numpy as np
import cv2

IMG_SIZE = 512  # square input side length fed to the network
# Per-channel RGB normalisation statistics (the standard ImageNet values).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}
# Colour tuple per class id; not used in this module (presumably for overlays
# elsewhere — channel order RGB vs BGR should be confirmed by the caller).
CLASS_COLORS = {
    0: (0, 0, 0),
    1: (0, 255, 0),
    2: (255, 165, 0),
    3: (255, 0, 0),
}

# ---------------------------------------------------------------------------
# Architecture modules (match checkpoint weights exactly)
# ---------------------------------------------------------------------------


class ChannelAttention(nn.Module):
    """CBAM channel attention.

    A shared MLP scores the spatially average-pooled and max-pooled channel
    descriptors; their summed sigmoid rescales each channel of the input.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
        )

    def forward(self, x):
        avg_out = self.mlp(x.mean(dim=[2, 3]))
        max_out = self.mlp(x.amax(dim=[2, 3]))
        # (B, C) -> (B, C, 1, 1) so the weights broadcast over H and W.
        attn = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
        return x * attn


class SpatialAttention(nn.Module):
    """CBAM spatial attention.

    Channel-wise mean and max maps are stacked and convolved into a single
    sigmoid mask that rescales every spatial position.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # "Same" padding keeps the mask at the input's spatial size (odd kernel sizes).
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg_out = x.mean(dim=1, keepdim=True)
        max_out = x.amax(dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * attn


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.sa(self.ca(x))


class DifferentiableFractalDimension(nn.Module):
    """Soft box-counting estimate of fractal dimension.

    For each box size ``s`` the map is average-pooled with an ``s x s`` window
    and a sigmoid soft-threshold counts "occupied" boxes. The fractal
    dimension is the negated least-squares slope of ``log(count + 1)`` against
    ``log(s)``, averaged over channels.

    Returns:
        Tensor of shape (B, 1, 1, 1).
    """

    def __init__(self, scales=None):
        super().__init__()
        self.scales = scales or [2, 4, 8, 16, 32]

    def forward(self, x):
        B, C, H, W = x.shape
        used_scales = []
        counts = []
        for s in self.scales:
            # A box at least as large as the map gives no usable count.
            if s >= H or s >= W:
                continue
            pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
            # Soft "box is occupied" indicator (mean activation above ~0.1), summed over boxes.
            n_boxes = torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3])
            counts.append(n_boxes)
            # Track the scales actually used so log_s stays aligned with counts
            # even when custom scales are not sorted ascending.
            used_scales.append(float(s))
        if len(counts) < 2:
            # Too few scales to fit a slope: return a neutral value with the same
            # (B, 1, 1, 1) shape as the normal path so callers can .expand() it.
            return torch.ones(B, 1, 1, 1, device=x.device)
        log_s = torch.log(torch.tensor(used_scales, device=x.device))
        log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)  # (B, C, n)
        # Closed-form ordinary least-squares slope of log_c vs log_s, per (B, C).
        n = log_s.shape[0]
        sx, sxx = log_s.sum(), (log_s**2).sum()
        sy = log_c.sum(dim=-1)
        sxy = (log_c * log_s.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
        slope = (n * sxy - sx * sy) / (n * sxx - sx**2 + 1e-8)
        # Counts fall as boxes grow, so the dimension is the negated slope.
        return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)


class DifferentiableEulerCharacteristic(nn.Module):
    """Soft Euler characteristic (V - E + F) of a sigmoid-binarised map.

    Vertices are soft foreground pixels, edges are horizontally/vertically
    adjacent foreground pairs, and faces are fully-foreground 2x2 squares.
    The result is averaged over channels and normalised by ``H * W``.

    Returns:
        Tensor of shape (B, 1, 1, 1).
    """

    def forward(self, x):
        _, _, H, W = x.shape
        # Sharpened sigmoid: a differentiable stand-in for thresholding sigmoid(x) at 0.5.
        b = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))
        V = b.sum(dim=[2, 3])
        E_h = (b[:, :, :, :-1] * b[:, :, :, 1:]).sum(dim=[2, 3])
        E_v = (b[:, :, :-1, :] * b[:, :, 1:, :]).sum(dim=[2, 3])
        F_val = (b[:, :, :-1, :-1] * b[:, :, :-1, 1:] * b[:, :, 1:, :-1] * b[:, :, 1:, 1:]).sum(dim=[2, 3])
        euler = V - E_h - E_v + F_val
        return euler.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) / (H * W)


class TopologicalAttentionModule(nn.Module):
    """Attention gated by global topological descriptors.

    The fractal dimension and Euler characteristic are broadcast as two extra
    constant channels (scaled by learnable ``alpha`` / ``beta``). A 1x1 conv
    stack then produces a sigmoid gate, applied residually: ``x * (1 + attn)``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.fractal = DifferentiableFractalDimension()
        self.euler = DifferentiableEulerCharacteristic()
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        fm = self.fractal(x).expand(B, 1, H, W)
        em = self.euler(x).expand(B, 1, H, W)
        attn = self.conv(torch.cat([x, self.alpha * fm, self.beta * em], dim=1))
        return x * attn + x


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Branches are a 1x1 conv, one dilated 3x3 conv per rate, and a global
    average-pool branch. Their outputs are concatenated and projected back to
    ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch, rates=None):
        super().__init__()
        rates = rates or [6, 12, 18]
        self.conv1x1 = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True))
        self.atrous = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(True),
                )
                for r in rates
            ]
        )
        # Image-level branch; it has no BatchNorm — layout must match the checkpoint.
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_ch, out_ch, 1), nn.ReLU(True))
        # Input channels: 1x1 branch + pool branch (= 2) + one per atrous rate.
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * (2 + len(rates)), out_ch, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        size = x.shape[2:]
        feats = [self.conv1x1(x)] + [a(x) for a in self.atrous]
        # Upsample the 1x1 pooled descriptor back to the input's spatial size.
        feats.append(F.interpolate(self.pool(x), size=size, mode="bilinear", align_corners=False))
        return self.project(torch.cat(feats, dim=1))


class WoundNetB7(nn.Module):
    """WoundNetB7 matching the Track B checkpoint structure.

    Pipeline: smp UNet (EfficientNet-B7 encoder) with ASPP applied to the
    deepest encoder feature. The decoder and segmentation head follow, and the
    per-class logits are then refined by CBAM and the topological attention
    module. ``forward`` returns logits of shape (B, num_classes, H, W).
    """

    NUM_CLASSES = 4

    def __init__(self, num_classes=4):
        super().__init__()
        self.backbone = smp.Unet(
            encoder_name="efficientnet-b7",
            encoder_weights=None,
            in_channels=3,
            classes=num_classes,
        )
        enc_ch = self.backbone.encoder.out_channels[-1]
        self.aspp = ASPP(enc_ch, enc_ch)
        self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
        self.tam = TopologicalAttentionModule(num_classes)
        # Not used in forward(); presumably present only so the checkpoint's key
        # has a matching parameter.
        self.diffusion_weight = nn.Parameter(torch.tensor(0.01))

    def forward(self, x):
        features = list(self.backbone.encoder(x))
        features[-1] = self.aspp(features[-1])
        # The decoder is tried with a single list first, then with unpacked
        # features (presumably to cover smp versions with either signature).
        try:
            dec = self.backbone.decoder(features)
        except TypeError:
            dec = self.backbone.decoder(*features)
        seg = self.backbone.segmentation_head(dec)
        seg = self.cbam(seg)
        seg = self.tam(seg)
        return seg


# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------


def _unwrap_output(out):
    """Return the logits tensor from a model output that may be a tensor, a tuple/list
    (first element is taken), or a dict (``"seg"`` entry is taken)."""
    if isinstance(out, (tuple, list)):
        out = out[0]
    if isinstance(out, dict):
        out = out["seg"]
    return out


def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
    """BGR image -> normalized CHW tensor (1, 3, 512, 512).

    The image is converted to RGB and resized to ``IMG_SIZE`` x ``IMG_SIZE``
    without preserving aspect ratio. It is then scaled to [0, 1] and
    normalised with ``MEAN`` / ``STD``.
    """
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    img = (img.astype(np.float32) / 255.0 - MEAN) / STD
    # HWC -> CHW, add batch dim; MEAN/STD are float64, so cast back to float32.
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()


def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """6-fold TTA -> averaged softmax probabilities (1, C, H, W).

    Transforms: identity, horizontal flip, vertical flip, and 90/180/270-degree
    rotations. Each prediction is mapped back with the inverse transform before
    averaging. Rotations require a square input (``preprocess`` produces one).
    The result lives on ``device``.
    """
    transforms = [
        lambda x: x,
        lambda x: torch.flip(x, [3]),
        lambda x: torch.flip(x, [2]),
        lambda x: torch.rot90(x, 1, [2, 3]),
        lambda x: torch.rot90(x, 2, [2, 3]),
        lambda x: torch.rot90(x, 3, [2, 3]),
    ]
    # inverse[i] undoes transforms[i] (flips are self-inverse; rot k undone by rot 4-k).
    inverse = [
        lambda x: x,
        lambda x: torch.flip(x, [3]),
        lambda x: torch.flip(x, [2]),
        lambda x: torch.rot90(x, 3, [2, 3]),
        lambda x: torch.rot90(x, 2, [2, 3]),
        lambda x: torch.rot90(x, 1, [2, 3]),
    ]
    probs_sum = None
    with torch.no_grad():
        for tfm, inv in zip(transforms, inverse):
            out = _unwrap_output(model(tfm(img_tensor).to(device)))
            p = inv(F.softmax(out, dim=1))
            probs_sum = p if probs_sum is None else probs_sum + p
    return probs_sum / len(transforms)


def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Load WoundNetB7 from checkpoint and return it on ``device`` in eval mode.

    The checkpoint may be any of the following:
    - a plain state dict;
    - a dict wrapping one under ``"model_state_dict"``, ``"state_dict"`` or ``"model"``;
    - a pickled ``nn.Module``.

    A ``"module."`` prefix (DataParallel/DDP) is stripped and ``pwat_head.*``
    keys are dropped. Loading is non-strict, but missing or unexpected keys are
    reported via ``warnings.warn``.

    Raises:
        RuntimeError: if no model parameter/buffer was found in the checkpoint
            (the model would otherwise silently run with random weights).
    """
    model = WoundNetB7(num_classes=4)
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code — only load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if isinstance(state, nn.Module):
        state = state.state_dict()
    # Unwrap common training-checkpoint containers.
    for wrapper_key in ("model_state_dict", "state_dict", "model"):
        if isinstance(state, dict) and isinstance(state.get(wrapper_key), dict):
            state = state[wrapper_key]
            break
    prefix = "module."
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
    # Remove PWAT head keys if present
    state = {k: v for k, v in state.items() if not k.startswith("pwat_head.")}

    result = model.load_state_dict(state, strict=False)
    if len(result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(
            f"Checkpoint {checkpoint_path!r} contains none of the WoundNetB7 parameters; "
            "refusing to return a randomly initialised model."
        )
    if result.missing_keys:
        warnings.warn(
            f"{len(result.missing_keys)} model keys missing from checkpoint "
            f"(e.g. {result.missing_keys[:3]}); they keep their initial values.",
            stacklevel=2,
        )
    if result.unexpected_keys:
        warnings.warn(
            f"{len(result.unexpected_keys)} checkpoint keys were ignored "
            f"(e.g. {result.unexpected_keys[:3]}).",
            stacklevel=2,
        )
    model.to(device).eval()
    return model


def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
    """Run segmentation on a BGR image.

    Returns dict with:
        classmap: (H, W) uint8 with class indices 0-3
        masks: dict of per-class binary masks {cls_name: (H, W) bool} for the
            non-background classes (foot, perilesion, ulcer)
        probs: (C, H, W) float32 softmax probabilities (C = 4 for WoundNetB7),
            resized back to the input image size
    """
    h, w = img_bgr.shape[:2]
    tensor = preprocess(img_bgr)
    if use_tta:
        probs = tta_inference(model, tensor, device)
    else:
        with torch.no_grad():
            out = _unwrap_output(model(tensor.to(device)))
            probs = F.softmax(out, dim=1)
    probs_np = probs[0].cpu().numpy()
    # Resize every class channel back to the original (w, h); uses the actual
    # channel count rather than assuming 4.
    probs_resized = np.stack(
        [cv2.resize(channel, (w, h), interpolation=cv2.INTER_LINEAR) for channel in probs_np]
    )
    classmap = probs_resized.argmax(axis=0).astype(np.uint8)
    masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
    return {"classmap": classmap, "masks": masks, "probs": probs_resized}