"""WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer).

Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder.
Checkpoint: Track B multiclass, ulcer Dice = 0.927 (Bootstrap CI: [0.917, 0.936]).
"""

import warnings
from pathlib import Path

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F

IMG_SIZE = 512
# Per-channel RGB normalisation applied after scaling pixels to [0, 1]
# (the standard ImageNet statistics).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])
CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}
# Display colour per class id; presumably RGB order — confirm with the renderer.
CLASS_COLORS = {
    0: (0, 0, 0),
    1: (0, 255, 0),
    2: (255, 165, 0),
    3: (255, 0, 0),
}


# ---------------------------------------------------------------------------
# Architecture modules (match checkpoint weights exactly)
# ---------------------------------------------------------------------------
# NOTE: attribute names and nn.Sequential layouts below determine state_dict
# keys; renaming or reordering them silently breaks checkpoint loading.


class ChannelAttention(nn.Module):
    """Channel gate: a shared MLP over global average- and max-pooled descriptors."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Clamp so a reduction larger than ``channels`` cannot build a zero-width layer.
        hidden = max(1, channels // reduction)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
        )

    def forward(self, x):
        avg_out = self.mlp(x.mean(dim=[2, 3]))
        max_out = self.mlp(x.amax(dim=[2, 3]))
        attn = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
        return x * attn


class SpatialAttention(nn.Module):
    """Spatial gate: a conv over the channel-wise mean and max maps."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg_out = x.mean(dim=1, keepdim=True)
        max_out = x.amax(dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * attn


class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.sa(self.ca(x))


class DifferentiableFractalDimension(nn.Module):
    """Soft box-counting fractal dimension, averaged over channels.

    For every box size ``s`` strictly smaller than both spatial dims, the input is
    average-pooled with an ``s x s`` window. Each box is counted as "occupied" with
    weight ``sigmoid(10 * (mean - 0.1))``. The dimension is minus the
    least-squares slope of ``log(count + 1)`` against ``log(s)``.

    Returns a ``(B, 1, 1, 1)`` tensor. When fewer than two scales fit the input,
    a tensor of ones with the same shape is returned.
    """

    def __init__(self, scales=None):
        super().__init__()
        self.scales = scales or [2, 4, 8, 16, 32]

    def forward(self, x):
        B, C, H, W = x.shape
        used_scales = []
        counts = []
        for s in self.scales:
            if s >= H or s >= W:
                continue  # box does not fit the feature map at this scale
            pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
            counts.append(torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3]))
            used_scales.append(float(s))
        if len(counts) < 2:
            # A slope needs at least two points. Keep the same (B, 1, 1, 1) shape as
            # the regular path so callers can ``.expand`` it to (B, 1, H, W).
            return torch.ones(B, 1, 1, 1, device=x.device, dtype=x.dtype)
        # Regress against the scales actually used, not a prefix of self.scales,
        # so skipped scales can never misalign x- and y-values.
        log_s = torch.log(torch.tensor(used_scales, device=x.device))
        log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)  # (B, C, n)
        n = log_s.shape[0]
        sx, sxx = log_s.sum(), (log_s**2).sum()
        sy = log_c.sum(dim=-1)
        sxy = (log_c * log_s).sum(dim=-1)
        slope = (n * sxy - sx * sy) / (n * sxx - sx**2 + 1e-8)
        return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)


class DifferentiableEulerCharacteristic(nn.Module):
    """Soft Euler characteristic V - E + F of the binarised input.

    The value is averaged over channels and divided by H*W. It is returned as
    ``(B, 1, 1, 1)``.
    """

    def forward(self, x):
        B, C, H, W = x.shape
        # Sharpened soft binarisation of sigmoid(x) around 0.5.
        b = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))
        V = b.sum(dim=[2, 3])
        E_h = (b[:, :, :, :-1] * b[:, :, :, 1:]).sum(dim=[2, 3])
        E_v = (b[:, :, :-1, :] * b[:, :, 1:, :]).sum(dim=[2, 3])
        F_val = (b[:, :, :-1, :-1] * b[:, :, :-1, 1:] * b[:, :, 1:, :-1] * b[:, :, 1:, 1:]).sum(dim=[2, 3])
        euler = V - E_h - E_v + F_val
        return euler.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) / (H * W)


class TopologicalAttentionModule(nn.Module):
    """Residual attention gate conditioned on soft fractal and Euler descriptors.

    The two scalar descriptors are broadcast to full-resolution maps and scaled by
    the learnable ``alpha`` / ``beta``. They are concatenated with ``x`` and passed
    through a 1x1-conv gate. The output is ``x * attn + x``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.fractal = DifferentiableFractalDimension()
        self.euler = DifferentiableEulerCharacteristic()
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        fm = self.fractal(x).expand(B, 1, H, W)
        em = self.euler(x).expand(B, 1, H, W)
        attn = self.conv(torch.cat([x, self.alpha * fm, self.beta * em], dim=1))
        return x * attn + x


class ASPP(nn.Module):
    """Atrous spatial pyramid pooling.

    Branches: a 1x1 conv, one dilated 3x3 conv per rate, and an image-level pooling
    branch. Their outputs are concatenated and projected back to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch, rates=None):
        super().__init__()
        rates = rates or [6, 12, 18]
        self.conv1x1 = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True))
        self.atrous = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(True),
                )
                for r in rates
            ]
        )
        # No BatchNorm in the pooling branch: the layout must match the checkpoint.
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_ch, out_ch, 1), nn.ReLU(True))
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * (2 + len(rates)), out_ch, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        size = x.shape[2:]
        feats = [self.conv1x1(x)] + [a(x) for a in self.atrous]
        feats.append(F.interpolate(self.pool(x), size=size, mode="bilinear", align_corners=False))
        return self.project(torch.cat(feats, dim=1))


class WoundNetB7(nn.Module):
    """WoundNetB7 matching the Track B checkpoint structure.

    Pipeline: smp UNet encoder -> ASPP on the deepest feature map -> UNet decoder ->
    segmentation head -> CBAM -> TAM. ``forward`` returns per-class logits
    (softmax is applied by the inference helpers).
    """

    NUM_CLASSES = 4

    def __init__(self, num_classes=4):
        super().__init__()
        self.backbone = smp.Unet(
            encoder_name="efficientnet-b7",
            encoder_weights=None,
            in_channels=3,
            classes=num_classes,
        )
        enc_ch = self.backbone.encoder.out_channels[-1]
        self.aspp = ASPP(enc_ch, enc_ch)
        self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
        self.tam = TopologicalAttentionModule(num_classes)
        # Not used in forward(); kept so the checkpoint's key is present.
        self.diffusion_weight = nn.Parameter(torch.tensor(0.01))

    def forward(self, x):
        features = list(self.backbone.encoder(x))
        features[-1] = self.aspp(features[-1])
        # Newer smp decoders take the feature list; older ones take *features.
        try:
            dec = self.backbone.decoder(features)
        except TypeError:
            dec = self.backbone.decoder(*features)
        seg = self.backbone.segmentation_head(dec)
        seg = self.cbam(seg)
        seg = self.tam(seg)
        return seg


# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------


def _ensure_bgr(img: np.ndarray) -> np.ndarray:
    """Return a 3-channel BGR image, converting grayscale or BGRA input (3-channel input is returned as-is)."""
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.ndim == 3 and img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


def _seg_logits(out):
    """Extract segmentation logits from a model output that may be a tensor, tuple/list, or dict."""
    if isinstance(out, (tuple, list)):
        out = out[0]
    if isinstance(out, dict):
        out = out["seg"]
    return out


def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
    """BGR image -> normalized CHW tensor (1, 3, 512, 512).

    Grayscale and BGRA inputs are converted to BGR first. Pixel values are divided
    by 255, so 8-bit input is assumed.
    """
    img = cv2.cvtColor(_ensure_bgr(img_bgr), cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    img = (img.astype(np.float32) / 255.0 - MEAN) / STD
    return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()


def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """6-fold TTA -> averaged softmax probabilities (1, C, H, W).

    Augmentations: identity, horizontal flip, vertical flip, and 90/180/270-degree
    rotations. Each softmax is mapped back to the original orientation before
    averaging: a flip undoes itself, and a k-quarter-turn is undone by a
    (4-k)-quarter-turn. The result lives on ``device``.
    """
    augmentations = (
        (lambda x: x, lambda x: x),
        (lambda x: torch.flip(x, [3]), lambda x: torch.flip(x, [3])),
        (lambda x: torch.flip(x, [2]), lambda x: torch.flip(x, [2])),
        (lambda x: torch.rot90(x, 1, [2, 3]), lambda x: torch.rot90(x, 3, [2, 3])),
        (lambda x: torch.rot90(x, 2, [2, 3]), lambda x: torch.rot90(x, 2, [2, 3])),
        (lambda x: torch.rot90(x, 3, [2, 3]), lambda x: torch.rot90(x, 1, [2, 3])),
    )
    probs_sum = None
    with torch.no_grad():
        for forward_tfm, inverse_tfm in augmentations:
            logits = _seg_logits(model(forward_tfm(img_tensor).to(device)))
            p = inverse_tfm(F.softmax(logits, dim=1))
            probs_sum = p if probs_sum is None else probs_sum + p
    return probs_sum / len(augmentations)


def _extract_state_dict(checkpoint):
    """Return a flat parameter dict from a raw state dict, a common wrapper dict, or a pickled module."""
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        # Training scripts often save {"state_dict": ...} / {"model_state_dict": ...}.
        # Values of a raw state dict are tensors, so these lookups cannot misfire on one.
        for key in ("state_dict", "model_state_dict", "model"):
            inner = checkpoint.get(key)
            if isinstance(inner, nn.Module):
                return inner.state_dict()
            if isinstance(inner, dict):
                return inner
    return checkpoint


def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Load WoundNetB7 from checkpoint.

    Accepts a raw state dict or a dict wrapping one under ``state_dict`` /
    ``model_state_dict`` / ``model``. A DataParallel ``module.`` prefix is stripped
    and ``pwat_head.*`` keys are dropped.

    Raises:
        TypeError: the checkpoint does not contain a state dict.
        RuntimeError: no checkpoint key matches the model, which would otherwise
            silently return a randomly initialised network.

    Missing keys emit a ``RuntimeWarning``; those parameters keep their initial values.
    """
    model = WoundNetB7(num_classes=WoundNetB7.NUM_CLASSES)
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can execute
    # code embedded in the file. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state = _extract_state_dict(checkpoint)
    if not isinstance(state, dict):
        raise TypeError(f"Checkpoint {checkpoint_path!r} does not contain a state dict (got {type(state).__name__}).")

    prefix = "module."
    if state and all(k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}
    # Remove PWAT head keys if present
    state = {k: v for k, v in state.items() if not k.startswith("pwat_head.")}

    if not set(model.state_dict()) & set(state):
        raise RuntimeError(
            f"No parameters in {checkpoint_path!r} match WoundNetB7; refusing to return an untrained model."
        )
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"{len(result.missing_keys)} WoundNetB7 keys missing from checkpoint "
            f"(e.g. {result.missing_keys[:3]}); they keep their initial values.",
            RuntimeWarning,
            stacklevel=2,
        )
    model.to(device).eval()
    return model


def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
    """Run segmentation on a BGR image.

    Returns dict with:
        classmap: (H, W) uint8 with class indices 0-3
        masks: dict of per-class binary masks {cls_name: (H, W) bool} for classes 1-3
        probs: (C, H, W) float32 softmax probabilities resized to the input size
               (C = 4 for the shipped model)
    """
    h, w = img_bgr.shape[:2]
    tensor = preprocess(img_bgr)
    if use_tta:
        probs = tta_inference(model, tensor, device)
    else:
        with torch.no_grad():
            probs = F.softmax(_seg_logits(model(tensor.to(device))), dim=1)

    # Cast to float32 so cv2.resize gets a supported dtype even from a half-precision model.
    probs_np = probs[0].float().cpu().numpy()
    num_classes = probs_np.shape[0]
    probs_resized = np.stack(
        [cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(num_classes)]
    )
    classmap = probs_resized.argmax(axis=0).astype(np.uint8)
    masks = {name: classmap == cid for cid, name in CLASS_NAMES.items() if cid > 0}
    return {"classmap": classmap, "masks": masks, "probs": probs_resized}


def postprocess_segmentation(
    classmap: np.ndarray,
    img_bgr: np.ndarray,
    min_foot_ratio: float = 0.01,
    dark_l_threshold: float = 15.0,
) -> np.ndarray:
    """Post-process segmentation with necrotic tissue recovery.

    Steps:
        1. Find connected components of the foreground (class > 0). The largest is
           the "main component". Any other component smaller than
           ``min_foot_ratio`` of the image area is set to background. Larger
           secondary components are kept.
        2. Set to background the dark pixels (L* < ``dark_l_threshold``) that are
           foreground but outside the main component.
        3. RECOVER necrotic tissue: dark regions adjacent to the detected foot
           that the model missed are reclassified as ulcer (class 3).
        4. Light morphological closing per class. This only fills background
           pixels; there is no opening, which preserves thin structures like toes.

    ``img_bgr`` must have the same height/width as ``classmap``. Grayscale and
    BGRA images are converted to BGR. The L* rescale (x 100/255) assumes 8-bit input.
    """
    h, w = classmap.shape
    cleaned = classmap.copy()
    img_bgr = _ensure_bgr(img_bgr)

    # Step 1: keep the largest foreground component and drop small secondary ones.
    foreground = (cleaned > 0).astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(foreground, connectivity=8)
    if num_labels > 1:
        areas = stats[:, cv2.CC_STAT_AREA]
        largest_label = int(np.argmax(areas[1:])) + 1  # label 0 is background
        main_component_mask = labels == largest_label
        drop = areas < h * w * min_foot_ratio
        drop[0] = False
        drop[largest_label] = False
        cleaned[drop[labels]] = 0
    else:
        main_component_mask = np.zeros((h, w), dtype=bool)

    # Step 2: dark pixel exclusion — ONLY for disconnected blobs.
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    l_channel = lab[:, :, 0] * (100.0 / 255.0)  # OpenCV stores 8-bit L* as L * 255 / 100
    a_channel = lab[:, :, 1] - 128.0  # OpenCV stores 8-bit a* offset by 128
    s_channel = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)[:, :, 1]
    dark_isolated = (l_channel < dark_l_threshold) & (cleaned > 0) & ~main_component_mask
    cleaned[dark_isolated] = 0

    # Step 3: necrotic tissue recovery.
    # Dark skin-like regions adjacent to detected foot → reclassify as ulcer.
    cleaned = recover_necrotic_tissue(cleaned, img_bgr, l_channel, a_channel, s_channel)

    # Step 4: light morphological closing (fills small gaps, does NOT erode thin structures).
    kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    for cid in (1, 2, 3):
        class_mask = (cleaned == cid).astype(np.uint8)
        if np.count_nonzero(class_mask) < 50:
            continue
        closed = cv2.morphologyEx(class_mask, cv2.MORPH_CLOSE, kernel_close)
        # Only ADD pixels that are currently background; never overwrite another class.
        new_pixels = (closed > 0) & (class_mask == 0) & (cleaned == 0)
        cleaned[new_pixels] = cid
    return cleaned


def recover_necrotic_tissue(
    classmap: np.ndarray,
    img_bgr: np.ndarray,
    l_channel: np.ndarray,
    a_channel: np.ndarray,
    s_channel: np.ndarray,
    necrotic_l_max: float = 45.0,
    necrotic_s_max: float = 120.0,
    min_region_px: int = 100,
    dilation_step: int = 30,
    num_rounds: int = 3,
    touch_kernel_size: int = 7,
) -> np.ndarray:
    """Recover dark necrotic tissue regions adjacent to detected foreground.

    Necrotic tissue (eschar, dry/wet gangrene on toes) is very dark, and the model
    often misclassifies it as background. This function grows the foreground
    iteratively and relabels adjacent dark candidate pixels as ulcer (class 3).

    Candidate pixels (fixed once, before the first round):
        - L* < ``necrotic_l_max`` (dark tissue)
        - HSV saturation < ``necrotic_s_max`` (rules out vivid green/blue backgrounds)
        - currently background (class 0)

    Each of up to ``num_rounds`` rounds:
        1. Dilate the current foreground with a ``dilation_step`` x
           ``dilation_step`` elliptical kernel. This is a diameter, so the
           reach is about ``dilation_step / 2`` px (~15 px by default).
        2. Take the candidates inside that band and label their 8-connected
           components.
        3. Relabel as ulcer every component with >= ``min_region_px`` pixels that
           comes within the ``touch_kernel_size`` elliptical neighbourhood
           (~3 px by default) of the current foreground.
        4. Recovered pixels join the foreground for the next round.

    Only a background gap of a few pixels can therefore be bridged. The total
    reach is about ``num_rounds * dilation_step / 2`` px (~45 px by default).
    The loop stops early when a round recovers nothing. ``img_bgr`` and
    ``a_channel`` are accepted for interface compatibility but are not used.
    ``touch_kernel_size`` is expected to be odd.
    """
    recovered = classmap.copy()

    # Candidate necrotic pixels: dark, not vivid, currently background.
    necrotic_candidates = (recovered == 0) & (l_channel < necrotic_l_max) & (s_channel < necrotic_s_max)
    if not np.any(necrotic_candidates):
        return recovered

    grow_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_step, dilation_step))
    touch_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (touch_kernel_size, touch_kernel_size))
    current_foreground = (recovered > 0).astype(np.uint8)

    for _ in range(num_rounds):
        within_reach = cv2.dilate(current_foreground, grow_kernel).astype(bool)
        adjacent = necrotic_candidates & within_reach & (recovered == 0)
        if not np.any(adjacent):
            break

        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
            adjacent.astype(np.uint8), connectivity=8
        )
        # A component "touches" the foreground if dilating it by the touch kernel
        # overlaps the foreground. For a symmetric kernel this is the same as the
        # component overlapping the foreground dilated by that kernel, so one
        # dilation per round replaces one full-image dilation per component.
        near_foreground = cv2.dilate(current_foreground, touch_kernel) > 0
        touching = np.zeros(num_labels, dtype=bool)
        touching[np.unique(labels[adjacent & near_foreground])] = True
        large_enough = stats[:, cv2.CC_STAT_AREA] >= min_region_px
        keep = touching & large_enough
        keep[0] = False  # label 0 is the non-candidate background
        if not np.any(keep):
            break

        recovered[keep[labels]] = 3  # Ulcer (necrotic)
        # Update foreground for next round (include newly recovered tissue).
        current_foreground = (recovered > 0).astype(np.uint8)

    return recovered