# WoundNetB7-DFU-Analysis / src/segmentation.py
# (Hugging Face file-page header retained from scrape: uploaded by mmarquezsa,
#  commit "Add pipeline code, PWAT models, and Gradio app", 21ccfaf verified,
#  raw / history / blame, 9.84 kB)
"""WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer).
Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder.
Checkpoint: Track B multiclass, ulcer Dice = 0.927 (Bootstrap CI: [0.917, 0.936]).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import numpy as np
import cv2
from pathlib import Path
# Square input resolution the model was trained at.
IMG_SIZE = 512
# Standard ImageNet channel statistics (RGB order) used for normalization.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])
# Segmentation class index -> human-readable name.
CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}
# Per-class overlay colors; tuples look like (R, G, B) — confirm channel
# order against the visualization code that consumes them (cv2 uses BGR).
CLASS_COLORS = {
    0: (0, 0, 0),      # background: black
    1: (0, 255, 0),    # foot: green
    2: (255, 165, 0),  # perilesion: orange
    3: (255, 0, 0),    # ulcer: red
}
# ---------------------------------------------------------------------------
# Architecture modules (match checkpoint weights exactly)
# ---------------------------------------------------------------------------
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate: a shared bottleneck MLP scores the
    globally avg-pooled and max-pooled channel descriptors, and the summed
    scores (after sigmoid) rescale each channel of the input."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        # One MLP shared by both pooling branches (standard CBAM design).
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
        )

    def forward(self, x):
        # Squeeze the spatial dims two ways, score each with the shared MLP.
        avg_score = self.mlp(x.mean(dim=[2, 3]))
        max_score = self.mlp(x.amax(dim=[2, 3]))
        gate = torch.sigmoid(avg_score + max_score)
        # Broadcast the (B, C) gate back over H and W.
        return x * gate[:, :, None, None]
class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate: a single conv maps the channel-wise
    (avg, max) descriptor pair to a per-pixel sigmoid gate."""

    def __init__(self, kernel_size=7):
        super().__init__()
        # 2 input channels (avg map, max map) -> 1 attention map;
        # padding keeps the spatial size unchanged.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        descriptor = torch.cat(
            [x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)],
            dim=1,
        )
        gate = torch.sigmoid(self.conv(descriptor))
        return x * gate
class CBAM(nn.Module):
    """Convolutional Block Attention Module: apply channel attention first,
    then spatial attention, each gating the feature map in place."""

    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        refined = self.ca(x)
        return self.sa(refined)
class DifferentiableFractalDimension(nn.Module):
    """Soft box-counting fractal-dimension estimate, differentiable in x.

    For each scale s the map is average-pooled with box size s and a sigmoid
    soft-threshold counts "occupied" boxes; a closed-form least-squares fit of
    log(count) against log(scale) gives the slope, whose negation approximates
    the fractal dimension. Returns a (B, 1, 1, 1) tensor (channel-averaged).
    """

    def __init__(self, scales=None):
        super().__init__()
        # Box sizes used for the regression; must contain at least two
        # scales smaller than the input's spatial dims to fit a line.
        self.scales = scales or [2, 4, 8, 16, 32]

    def forward(self, x):
        B, C, H, W = x.shape
        counts = []
        for s in self.scales:
            # Skip box sizes that do not fit inside the feature map.
            if s >= H or s >= W:
                continue
            pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
            # Soft indicator of "box occupied": steep sigmoid step at 0.1.
            n_boxes = torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3])
            counts.append(n_boxes)
        if len(counts) < 2:
            # Not enough scales for a regression — return a neutral value
            # with the SAME (B, 1, 1, 1) shape as the regular path.
            # (BUG FIX: the original returned (B, C), which breaks the
            # caller's .expand(B, 1, H, W) in TopologicalAttentionModule.)
            return torch.ones(B, 1, 1, 1, device=x.device, dtype=x.dtype)
        # self.scales is increasing, so only trailing (too-large) scales get
        # skipped; the first len(counts) entries are exactly the ones used.
        log_s = torch.log(torch.tensor([float(s) for s in self.scales[: len(counts)]], device=x.device))
        log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)  # (B, C, n)
        # Closed-form simple linear regression: slope of log_c vs log_s.
        n = log_s.shape[0]
        sx, sxx = log_s.sum(), (log_s**2).sum()
        sy = log_c.sum(dim=-1)
        sxy = (log_c * log_s.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
        slope = (n * sxy - sx * sy) / (n * sxx - sx**2 + 1e-8)
        # Box counts shrink as scale grows, so the slope is negative;
        # negate for a positive dimension estimate, averaged over channels.
        return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
class DifferentiableEulerCharacteristic(nn.Module):
    """Soft Euler characteristic (V - E + F) of a thresholded occupancy map,
    averaged over channels and normalized by the pixel count. Output shape
    is (B, 1, 1, 1)."""

    def forward(self, x):
        B, C, H, W = x.shape
        # Double sigmoid: squash to (0, 1), then sharpen around 0.5 to get a
        # soft binary occupancy map.
        occ = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))
        # Vertices: every occupied pixel.
        vertices = occ.sum(dim=[2, 3])
        # Edges: horizontally / vertically adjacent occupied pairs.
        edges_h = (occ[..., :-1] * occ[..., 1:]).sum(dim=[2, 3])
        edges_v = (occ[:, :, :-1, :] * occ[:, :, 1:, :]).sum(dim=[2, 3])
        # Faces: fully occupied 2x2 blocks.
        faces = (
            occ[:, :, :-1, :-1]
            * occ[:, :, :-1, 1:]
            * occ[:, :, 1:, :-1]
            * occ[:, :, 1:, 1:]
        ).sum(dim=[2, 3])
        chi = vertices - edges_h - edges_v + faces
        return chi.mean(dim=1, keepdim=True)[..., None, None] / (H * W)
class TopologicalAttentionModule(nn.Module):
    """Attention gate conditioned on global topology descriptors.

    Two scalar descriptors (soft fractal dimension and soft Euler
    characteristic) are broadcast over the spatial grid, concatenated to the
    input, and mapped through a small 1x1-conv network to a sigmoid gate.
    Output is the residually gated input: x * gate + x.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.fractal = DifferentiableFractalDimension()
        self.euler = DifferentiableEulerCharacteristic()
        # Learnable scaling of each descriptor map before the conv.
        self.alpha = nn.Parameter(torch.tensor(1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))
        # in_channels + 2 descriptor channels -> sigmoid gate per pixel.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels + 2, in_channels, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        # Both descriptors come back as (B, 1, 1, 1); tile over the grid.
        frac_map = self.fractal(x).expand(B, 1, H, W)
        euler_map = self.euler(x).expand(B, 1, H, W)
        stacked = torch.cat([x, self.alpha * frac_map, self.beta * euler_map], dim=1)
        gate = self.conv(stacked)
        # Residual gating keeps the identity path open.
        return x * gate + x
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs the input through parallel branches — a 1x1 conv, one dilated 3x3
    conv per rate, and a global-average-pool image-context branch — then
    concatenates and projects back to out_ch (with dropout).
    """

    def __init__(self, in_ch, out_ch, rates=None):
        super().__init__()
        rates = rates or [6, 12, 18]
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        )
        # padding == dilation keeps the spatial size for 3x3 kernels.
        self.atrous = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            )
            for r in rates
        )
        # Image-level context: pool to 1x1, project, upsample in forward().
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_ch, out_ch, 1),
            nn.ReLU(True),
        )
        # Fuse all branches: 1x1 + len(rates) atrous + pooled.
        self.project = nn.Sequential(
            nn.Conv2d(out_ch * (2 + len(rates)), out_ch, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        spatial = x.shape[2:]
        branches = [self.conv1x1(x)]
        branches.extend(branch(x) for branch in self.atrous)
        pooled = F.interpolate(self.pool(x), size=spatial, mode="bilinear", align_corners=False)
        branches.append(pooled)
        return self.project(torch.cat(branches, dim=1))
class WoundNetB7(nn.Module):
    """WoundNetB7 matching the Track B checkpoint structure.

    EfficientNet-B7 UNet backbone (from segmentation_models_pytorch) whose
    deepest encoder feature is refined by ASPP, with CBAM and topological
    attention applied to the 4-channel segmentation logits.

    NOTE: attribute names (backbone, aspp, cbam, tam, diffusion_weight) must
    match the checkpoint's state_dict keys exactly — do not rename them.
    """
    # Fixed number of output classes (bg, foot, perilesion, ulcer).
    NUM_CLASSES = 4
    def __init__(self, num_classes=4):
        super().__init__()
        # encoder_weights=None: all weights come from the project checkpoint,
        # so no ImageNet-pretrained encoder is downloaded here.
        self.backbone = smp.Unet(encoder_name="efficientnet-b7", encoder_weights=None, in_channels=3, classes=num_classes)
        # Channel count of the deepest encoder stage; ASPP keeps it unchanged.
        enc_ch = self.backbone.encoder.out_channels[-1]
        self.aspp = ASPP(enc_ch, enc_ch)
        # Attention modules operate on the num_classes-channel logits, hence
        # the small reduction (max(1, 2) for 4 classes).
        self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
        self.tam = TopologicalAttentionModule(num_classes)
        # Learnable scalar not used in this forward(); presumably consumed by
        # the training pipeline — kept so checkpoint loading stays compatible.
        self.diffusion_weight = nn.Parameter(torch.tensor(0.01))
    def forward(self, x):
        # Encode, then swap the deepest feature map for its ASPP-refined copy.
        features = list(self.backbone.encoder(x))
        features[-1] = self.aspp(features[-1])
        # The decoder call signature differs across smp versions: some accept
        # the feature list directly, older ones expect unpacked arguments.
        try:
            dec = self.backbone.decoder(features)
        except TypeError:
            dec = self.backbone.decoder(*features)
        seg = self.backbone.segmentation_head(dec)
        # Refine raw logits with channel/spatial then topological attention.
        seg = self.cbam(seg)
        seg = self.tam(seg)
        return seg
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
    """Convert a BGR image to a normalized CHW tensor of shape (1, 3, 512, 512).

    Steps: BGR->RGB, bilinear resize to IMG_SIZE, scale to [0, 1], apply
    MEAN/STD normalization, then HWC->CHW with a leading batch dimension.
    """
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
    normalized = (resized.astype(np.float32) / 255.0 - MEAN) / STD
    chw = np.transpose(normalized, (2, 0, 1))
    return torch.from_numpy(chw)[None].float()
def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    """6-fold test-time augmentation -> averaged softmax probabilities (1, C, H, W).

    Each augmentation is paired with its exact inverse so the per-pixel
    probabilities are aligned before averaging: identity, horizontal flip,
    vertical flip, and 90/180/270-degree rotations.
    """
    tta_pairs = [
        (lambda x: x, lambda x: x),
        (lambda x: torch.flip(x, [3]), lambda x: torch.flip(x, [3])),
        (lambda x: torch.flip(x, [2]), lambda x: torch.flip(x, [2])),
        (lambda x: torch.rot90(x, 1, [2, 3]), lambda x: torch.rot90(x, 3, [2, 3])),
        (lambda x: torch.rot90(x, 2, [2, 3]), lambda x: torch.rot90(x, 2, [2, 3])),
        (lambda x: torch.rot90(x, 3, [2, 3]), lambda x: torch.rot90(x, 1, [2, 3])),
    ]
    total = None
    with torch.no_grad():
        for forward_tfm, inverse_tfm in tta_pairs:
            out = model(forward_tfm(img_tensor).to(device))
            # Some model variants return (seg, ...) tuples or {"seg": ...} dicts.
            if isinstance(out, (tuple, list)):
                out = out[0]
            if isinstance(out, dict):
                out = out["seg"]
            aligned = inverse_tfm(F.softmax(out, dim=1))
            total = aligned if total is None else total + aligned
    return total / len(tta_pairs)
def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Load a WoundNetB7 from *checkpoint_path* onto *device*, in eval mode.

    SECURITY NOTE: weights_only=False unpickles arbitrary objects — only load
    checkpoints from a trusted source.
    """
    state = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # The checkpoint also carries PWAT-head weights; drop those keys since
    # this model only performs segmentation (strict=False tolerates the rest).
    seg_state = {key: value for key, value in state.items() if not key.startswith("pwat_head.")}
    model = WoundNetB7(num_classes=4)
    model.load_state_dict(seg_state, strict=False)
    model.to(device).eval()
    return model
def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
    """Run multiclass segmentation on a BGR image.

    Args:
        model: loaded WoundNetB7.
        img_bgr: input image, BGR channel order, any resolution.
        device: torch device for inference.
        use_tta: average 6-fold test-time augmentation when True.

    Returns dict with:
        classmap: (H, W) uint8 with class indices 0-3
        masks: dict of per-class binary masks {cls_name: (H, W) bool}
        probs: (4, H, W) float32 softmax probabilities
    """
    orig_h, orig_w = img_bgr.shape[:2]
    tensor = preprocess(img_bgr)
    if use_tta:
        probs = tta_inference(model, tensor, device)
    else:
        with torch.no_grad():
            raw = model(tensor.to(device))
            # Unwrap tuple/dict outputs the same way tta_inference does.
            if isinstance(raw, (tuple, list)):
                raw = raw[0]
            if isinstance(raw, dict):
                raw = raw["seg"]
            probs = F.softmax(raw, dim=1)
    # Resize each class probability plane back to the original resolution,
    # then take the per-pixel argmax for the final class map.
    probs_np = probs[0].cpu().numpy()
    probs_resized = np.stack(
        [cv2.resize(probs_np[c], (orig_w, orig_h), interpolation=cv2.INTER_LINEAR) for c in range(4)]
    )
    classmap = probs_resized.argmax(axis=0).astype(np.uint8)
    # Background (class 0) is deliberately left out of the per-class masks.
    masks = {name: classmap == cid for cid, name in CLASS_NAMES.items() if cid > 0}
    return {"classmap": classmap, "masks": masks, "probs": probs_resized}