Spaces:
Sleeping
Sleeping
| """WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer). | |
| Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder. | |
| Checkpoint: Track B multiclass, ulcer Dice = 0.927 (Bootstrap CI: [0.917, 0.936]). | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import segmentation_models_pytorch as smp | |
| import numpy as np | |
| import cv2 | |
| from pathlib import Path | |
IMG_SIZE = 512  # side length (pixels) that inputs are resized to before inference

# Per-channel normalization statistics (standard ImageNet mean/std values),
# applied after the image has been converted to RGB in preprocess().
MEAN = np.array((0.485, 0.456, 0.406))
STD = np.array((0.229, 0.224, 0.225))

# Class index -> label; index order matches the model's output channels.
CLASS_NAMES = dict(enumerate(("background", "foot", "perilesion", "ulcer")))

# Class index -> display colour triple (presumably RGB — confirm with the
# visualization code that consumes it).
CLASS_COLORS = {
    0: (0, 0, 0),
    1: (0, 255, 0),
    2: (255, 165, 0),
    3: (255, 0, 0),
}
| # --------------------------------------------------------------------------- | |
| # Architecture modules (match checkpoint weights exactly) | |
| # --------------------------------------------------------------------------- | |
class ChannelAttention(nn.Module):
    """Channel attention (CBAM style): gate each channel by a learned weight.

    The spatial mean and the spatial max of every channel are each passed
    through one shared bottleneck MLP. The two outputs are summed and squashed
    with a sigmoid into a per-channel weight in (0, 1) that rescales ``x``.

    Args:
        channels: number of input channels.
        reduction: bottleneck factor; the hidden layer has
            ``channels // reduction`` units.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Attribute name and layer order fix the state-dict keys
        # (mlp.0.weight, mlp.2.weight); keep them for checkpoint compatibility.
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
        )

    def forward(self, x):
        # Pools over dims 2 and 3 (assumes x is (B, C, H, W)).
        pooled_avg = x.mean(dim=[2, 3])
        pooled_max = x.amax(dim=[2, 3])
        gate = torch.sigmoid(self.mlp(pooled_avg) + self.mlp(pooled_max))
        return x * gate[:, :, None, None]
class SpatialAttention(nn.Module):
    """Spatial attention (CBAM style): gate each pixel by a learned weight.

    The channel-wise mean and max maps are stacked into a 2-channel image,
    convolved down to a single channel and passed through a sigmoid. The
    resulting map rescales ``x`` at every spatial location.

    Args:
        kernel_size: convolution kernel size; ``kernel_size // 2`` padding keeps
            the spatial size for odd kernels.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        descriptor = torch.cat(
            [x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)],
            dim=1,
        )
        return x * torch.sigmoid(self.conv(descriptor))
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        channels: number of input channels (for the channel-attention MLP).
        reduction: channel-attention bottleneck factor.
        kernel_size: spatial-attention convolution kernel size.
    """

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        # Submodule names "ca" / "sa" determine the state-dict keys; keep them.
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_refined = self.ca(x)
        return self.sa(channel_refined)
class DifferentiableFractalDimension(nn.Module):
    """Soft box-counting estimate of a fractal dimension, averaged over channels.

    For each box size ``s`` the input is average-pooled with an ``s x s``
    window. Each pooled cell is then softly counted as occupied via
    ``sigmoid(10 * (pooled - 0.1))``. The estimate is the negated
    least-squares slope of ``log(count + 1)`` against ``log(s)``.

    Args:
        scales: box sizes to use (default ``[2, 4, 8, 16, 32]``). A size is
            skipped unless it is strictly smaller than both H and W.

    ``forward`` returns a tensor of shape ``(B, 1, 1, 1)``. It is all ones
    when fewer than two scales fit the input.
    """

    def __init__(self, scales=None):
        super().__init__()
        self.scales = scales or [2, 4, 8, 16, 32]

    def forward(self, x):
        B, C, H, W = x.shape
        used_scales = []
        counts = []
        for s in self.scales:
            if s >= H or s >= W:
                continue
            pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
            counts.append(torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3]))  # (B, C)
            used_scales.append(float(s))
        if len(counts) < 2:
            # A slope needs at least two points. Return the same (B, 1, 1, 1)
            # shape as the normal path so callers can expand it uniformly.
            return torch.ones(B, 1, 1, 1, device=x.device, dtype=x.dtype)
        # Pair each count with the scale that actually produced it; skipped
        # scales are not necessarily at the end of self.scales.
        log_s = torch.log(torch.tensor(used_scales, device=x.device))  # (n,)
        log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)  # (B, C, n)
        n = log_s.shape[0]
        sx, sxx = log_s.sum(), (log_s ** 2).sum()
        sy = log_c.sum(dim=-1)
        sxy = (log_c * log_s).sum(dim=-1)
        # Closed-form least-squares slope; 1e-8 guards a zero denominator
        # (e.g. duplicate scales).
        slope = (n * sxy - sx * sy) / (n * sxx - sx ** 2 + 1e-8)
        return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
class DifferentiableEulerCharacteristic(nn.Module):
    """Soft Euler characteristic (V - E + F) of a binarized map, averaged over channels.

    Pixels are softly binarized with ``sigmoid(10 * (sigmoid(x) - 0.5))``.
    Products of neighbouring soft values stand in for "both pixels set":
    horizontal and vertical pairs count as edges, 2x2 blocks as faces.
    ``forward`` returns a ``(B, 1, 1, 1)`` tensor divided by ``H * W``.
    """

    def forward(self, x):
        _, _, height, width = x.shape
        soft = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))
        vertices = soft.sum(dim=[2, 3])
        horizontal_edges = (soft[..., :, :-1] * soft[..., :, 1:]).sum(dim=[2, 3])
        vertical_edges = (soft[..., :-1, :] * soft[..., 1:, :]).sum(dim=[2, 3])
        faces = (
            soft[..., :-1, :-1] * soft[..., :-1, 1:] * soft[..., 1:, :-1] * soft[..., 1:, 1:]
        ).sum(dim=[2, 3])
        chi = vertices - horizontal_edges - vertical_edges + faces
        return chi.mean(dim=1, keepdim=True)[..., None, None] / (height * width)
class TopologicalAttentionModule(nn.Module):
    """Residual attention gate conditioned on two global topology descriptors.

    A soft fractal-dimension estimate and a soft Euler characteristic are
    computed from ``x``. Each is broadcast to a constant ``(B, 1, H, W)`` map,
    scaled by the learnable ``alpha`` / ``beta``, and concatenated with ``x``.
    A 1x1 conv stack ending in a sigmoid maps this to an attention map, and the
    output is ``x + x * attn``.

    Args:
        in_channels: channel count of ``x``; also the output channel count.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.fractal = DifferentiableFractalDimension()
        self.euler = DifferentiableEulerCharacteristic()
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))
        # Layer order fixes the state-dict indices (conv.0, conv.1, conv.3);
        # keep it unchanged for checkpoint compatibility.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + 2, in_channels, 1),  # +2: fractal and Euler maps
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, _, height, width = x.shape
        map_shape = (batch, 1, height, width)
        fractal_map = self.fractal(x).expand(*map_shape)
        euler_map = self.euler(x).expand(*map_shape)
        stacked = torch.cat([x, self.alpha * fractal_map, self.beta * euler_map], dim=1)
        attn = self.conv(stacked)
        return x + x * attn
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling over a single feature map.

    Parallel branches all see the same input:
      * a 1x1 conv;
      * one dilated 3x3 conv per rate;
      * an image-level branch: global average pool + 1x1 conv, bilinearly
        upsampled back to the input size.

    Their outputs are concatenated along channels and projected to ``out_ch``
    channels (1x1 conv, BatchNorm, ReLU, dropout p=0.5).

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels produced by every branch and by the projection.
        rates: dilation rates of the 3x3 branches; a falsy value means
            ``[6, 12, 18]``.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size, dilation=1):
        """Build Conv2d -> BatchNorm2d -> ReLU; 3x3 kernels get padding=dilation."""
        padding = dilation if kernel_size == 3 else 0
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        )

    def __init__(self, in_ch, out_ch, rates=None):
        super().__init__()
        dilations = rates or [6, 12, 18]
        # Attribute names and Sequential layer order define the state-dict keys;
        # keep them unchanged for checkpoint compatibility.
        self.conv1x1 = self._conv_bn_relu(in_ch, out_ch, 1)
        self.atrous = nn.ModuleList(
            [self._conv_bn_relu(in_ch, out_ch, 3, dilation=d) for d in dilations]
        )
        # The image-level branch has no BatchNorm.
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_ch, out_ch, 1), nn.ReLU(True))
        branch_count = 2 + len(dilations)  # 1x1 + image-level + one per dilation
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * branch_count, out_ch, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        spatial_size = x.shape[2:]
        local_feats = [self.conv1x1(x), *(branch(x) for branch in self.atrous)]
        image_feat = F.interpolate(self.pool(x), size=spatial_size, mode="bilinear", align_corners=False)
        return self.project(torch.cat([*local_feats, image_feat], dim=1))
class WoundNetB7(nn.Module):
    """WoundNetB7 matching the Track B checkpoint structure.

    Pipeline:
      1. the smp UNet's EfficientNet-B7 encoder;
      2. ASPP on the deepest feature map;
      3. the UNet decoder and segmentation head;
      4. CBAM, then TAM.

    ``forward`` returns unnormalized per-class scores with ``num_classes``
    channels; callers apply softmax.

    Args:
        num_classes: number of output channels / classes.
    """

    NUM_CLASSES = 4

    def __init__(self, num_classes=4):
        super().__init__()
        # encoder_weights=None: no pretrained encoder weights are fetched; the
        # weights are expected to come from the checkpoint.
        self.backbone = smp.Unet(encoder_name="efficientnet-b7", encoder_weights=None, in_channels=3, classes=num_classes)
        enc_ch = self.backbone.encoder.out_channels[-1]
        self.aspp = ASPP(enc_ch, enc_ch)
        self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
        self.tam = TopologicalAttentionModule(num_classes)
        # Not used in forward(); presumably kept so checkpoint keys line up — confirm.
        self.diffusion_weight = nn.Parameter(torch.tensor(0.01))

    @staticmethod
    def _decoder_takes_varargs(decoder):
        """Return True if ``decoder.forward`` is declared as ``forward(self, *features)``.

        Older segmentation_models_pytorch decoders take the feature maps as
        separate positional arguments; newer ones take a single list.
        """
        import inspect

        params = inspect.signature(decoder.forward).parameters.values()
        return any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params)

    def forward(self, x):
        features = list(self.backbone.encoder(x))
        features[-1] = self.aspp(features[-1])
        decoder = self.backbone.decoder
        # Choose the calling convention from the signature instead of catching
        # TypeError. A `*features` decoder handed one list can fail with
        # IndexError rather than TypeError, and a blanket `except TypeError`
        # would also hide genuine TypeErrors raised inside the decoder.
        if self._decoder_takes_varargs(decoder):
            dec = decoder(*features)
        else:
            dec = decoder(features)
        seg = self.backbone.segmentation_head(dec)
        seg = self.cbam(seg)
        seg = self.tam(seg)
        return seg
| # --------------------------------------------------------------------------- | |
| # Inference helpers | |
| # --------------------------------------------------------------------------- | |
def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
    """Convert an OpenCV-style image into a normalized network input.

    Args:
        img_bgr: ``(H, W, 3)`` BGR image. Grayscale ``(H, W)`` / ``(H, W, 1)``
            and BGRA ``(H, W, 4)`` images are also accepted and converted to RGB.
            Pixel values are assumed to be in [0, 255].

    Returns:
        float32 tensor of shape ``(1, 3, IMG_SIZE, IMG_SIZE)``. It is in RGB
        order, scaled to [0, 1] and then normalized with ``MEAN`` / ``STD``.

    Raises:
        ValueError: if the array is not one of the shapes above.
    """
    if img_bgr.ndim == 2 or (img_bgr.ndim == 3 and img_bgr.shape[2] == 1):
        img = cv2.cvtColor(img_bgr, cv2.COLOR_GRAY2RGB)
    elif img_bgr.ndim == 3 and img_bgr.shape[2] == 3:
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    elif img_bgr.ndim == 3 and img_bgr.shape[2] == 4:
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGRA2RGB)
    else:
        raise ValueError(
            f"expected an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) image, got shape {img_bgr.shape}"
        )
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    # Normalize in float32: the float64 MEAN/STD arrays would otherwise
    # upcast the whole image to float64 before the final cast.
    mean = MEAN.astype(np.float32)
    std = STD.astype(np.float32)
    img = (img.astype(np.float32) / 255.0 - mean) / std
    chw = np.ascontiguousarray(img.transpose(2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0)
def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Average softmax predictions over 6 test-time augmentations.

    The augmentations are identity, horizontal flip, vertical flip and
    rotations by 90/180/270 degrees. Each prediction is mapped back to the
    original orientation before averaging.

    Args:
        model: segmentation network. Its output may be a logits tensor, a
            tuple/list whose first item is the logits, or a dict with key
            ``"seg"``.
        img_tensor: input batch ``(B, C, H, W)``.
        device: device the input is moved to before calling ``model``.

    Returns:
        Averaged softmax probabilities ``(B, num_classes, H, W)``.
    """
    # Each augmentation is stored with its inverse so the two cannot drift
    # out of sync.
    tta_pairs = [
        (lambda t: t, lambda t: t),
        (lambda t: torch.flip(t, [3]), lambda t: torch.flip(t, [3])),
        (lambda t: torch.flip(t, [2]), lambda t: torch.flip(t, [2])),
        (lambda t: torch.rot90(t, 1, [2, 3]), lambda t: torch.rot90(t, 3, [2, 3])),
        (lambda t: torch.rot90(t, 2, [2, 3]), lambda t: torch.rot90(t, 2, [2, 3])),
        (lambda t: torch.rot90(t, 3, [2, 3]), lambda t: torch.rot90(t, 1, [2, 3])),
    ]
    # Transfer once. Flips and rotations are exact, so augmenting on-device
    # yields the same model inputs as six separate host-to-device copies.
    x = img_tensor.to(device)
    probs_sum = None
    with torch.no_grad():
        for augment, invert in tta_pairs:
            out = model(augment(x))
            if isinstance(out, (tuple, list)):
                out = out[0]
            if isinstance(out, dict):
                out = out["seg"]
            p = invert(F.softmax(out, dim=1))
            probs_sum = p if probs_sum is None else probs_sum + p
    return probs_sum / len(tta_pairs)
def _clean_state_dict(state: dict) -> dict:
    """Reduce a loaded checkpoint to a plain state dict for WoundNetB7.

    The following are handled:
      * ``{"model_state_dict": ...}`` / ``{"state_dict": ...}`` wrappers;
      * a ``module.`` prefix present on every key (as added by e.g.
        ``nn.DataParallel``), which is stripped;
      * ``pwat_head.*`` keys, which are dropped because WoundNetB7 has no
        such module.
    """
    for wrapper_key in ("model_state_dict", "state_dict"):
        inner = state.get(wrapper_key)
        if isinstance(inner, dict):
            state = inner
            break
    prefix = "module."
    if state and all(k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}
    return {k: v for k, v in state.items() if not k.startswith("pwat_head.")}


def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Build a 4-class WoundNetB7 and load weights from ``checkpoint_path``.

    Keys missing from the checkpoint, or present but unknown to the model, are
    tolerated (``strict=False``). They are reported with a ``RuntimeWarning``
    so that a mismatched checkpoint does not silently yield a randomly
    initialized model.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    import warnings

    model = WoundNetB7(num_classes=4)
    # SECURITY: weights_only=False unpickles arbitrary objects, which can run
    # code. Only load checkpoints from trusted sources.
    raw = torch.load(checkpoint_path, map_location=device, weights_only=False)
    result = model.load_state_dict(_clean_state_dict(raw), strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {checkpoint_path} does not fully match WoundNetB7: "
            f"{len(result.missing_keys)} missing keys (e.g. {result.missing_keys[:3]}), "
            f"{len(result.unexpected_keys)} unexpected keys (e.g. {result.unexpected_keys[:3]})",
            RuntimeWarning,
            stacklevel=2,
        )
    model.to(device).eval()
    return model
def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
    """Segment a BGR image and return results at the image's original size.

    Args:
        model: network returning logits, a tuple/list whose first item is the
            logits, or a dict with key ``"seg"``.
        img_bgr: BGR image as loaded by OpenCV (see ``preprocess``).
        device: device to run inference on.
        use_tta: if True, average predictions over 6 flips/rotations via
            ``tta_inference``; otherwise run a single forward pass.

    Returns:
        dict with:
            classmap: ``(H, W)`` uint8 array of per-pixel argmax class indices.
            masks: ``{class_name: (H, W) bool}`` for every class in
                ``CLASS_NAMES`` except background (index 0).
            probs: ``(num_classes, H, W)`` softmax probabilities, bilinearly
                resized to the input resolution.
    """
    h, w = img_bgr.shape[:2]
    tensor = preprocess(img_bgr)
    if use_tta:
        probs = tta_inference(model, tensor, device)
    else:
        with torch.no_grad():
            out = model(tensor.to(device))
        if isinstance(out, (tuple, list)):
            out = out[0]
        if isinstance(out, dict):
            out = out["seg"]
        probs = F.softmax(out, dim=1)
    probs_np = probs[0].cpu().numpy()
    # Resize each class map separately. Iterating over the first axis uses the
    # model's actual class count instead of assuming 4 channels.
    probs_resized = np.stack(
        [cv2.resize(class_map, (w, h), interpolation=cv2.INTER_LINEAR) for class_map in probs_np]
    )
    classmap = probs_resized.argmax(axis=0).astype(np.uint8)
    masks = {name: classmap == cid for cid, name in CLASS_NAMES.items() if cid > 0}
    return {"classmap": classmap, "masks": masks, "probs": probs_resized}