mmarquezsa's picture
Wire integrated dashboard renderer (pipeline.py)
7b081d1 verified
"""WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA.
Usage:
from pipeline import WoundNetB7Pipeline
pipe = WoundNetB7Pipeline("models/")
result = pipe.analyze("path/to/dfu_image.png")
print(result)
"""
import torch
import numpy as np
import cv2
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional
from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS
from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult
from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES
from src.integrated_report import render_integrated_report
@dataclass
class AnalysisResult:
    """Complete DFU analysis result from a single image.

    Instances are built by ``WoundNetB7Pipeline.analyze``. The array fields
    (``classmap``, ``probs``, ``ulcer_mask``) are excluded from ``repr`` and
    from ``to_dict``.
    """

    # --- Segmentation ---
    class_distribution: dict = field(default_factory=dict)  # {class_name: percentage of pixels}
    classmap: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8
    probs: Optional[np.ndarray] = field(default=None, repr=False)  # (4,H,W) float32
    ulcer_mask: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8; analyze() fills 0/255
    # --- Fitzpatrick ---
    fitzpatrick: Optional[FitzpatrickResult] = None
    # --- PWAT ---
    pwat: Optional[PWATResult] = None
    # --- Metadata ---
    image_size: tuple = (0, 0)  # (height, width) as set by analyze()
    device: str = "cpu"

    def summary(self) -> str:
        """Return a multi-line, human-readable text report of this result.

        Sections for Fitzpatrick and PWAT are only emitted when that data is
        present. PWAT items missing from the score dicts (or stored as None)
        are shown as ``-`` instead of raising a formatting error.
        """
        lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
        # image_size is (height, width); print it as WIDTHxHEIGHT.
        lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
        lines.append(f"Device: {self.device}")

        lines.append("\n--- Segmentation ---")
        for cls, pct in self.class_distribution.items():
            lines.append(f" {cls:<15s}: {pct:5.1f}%")

        if self.fitzpatrick:
            f = self.fitzpatrick
            lines.append("\n--- Fitzpatrick / ITA ---")
            lines.append(f" Type: {f.fitzpatrick_type} ({f.fitzpatrick_label})")
            lines.append(f" ITA: {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
            lines.append(f" L* mean: {f.l_skin_mean:.1f}")
            lines.append(f" Confidence: {f.confidence:.2f}")
            lines.append(f" Pixels: {f.healthy_pixels:,}")

        if self.pwat and self.pwat.scores_raw:
            p = self.pwat
            lines.append("\n--- PWAT Scores (Items 3-8) ---")
            lines.append(f" {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
            lines.append(" " + "-" * 33)
            for item in (3, 4, 5, 6, 7, 8):
                name = ITEM_NAMES.get(item, f"Item {item}")
                raw = p.scores_raw.get(item)
                adj = p.scores_adjusted.get(item)
                # Pre-format each cell so a missing value can fall back to "-".
                # Formatting the "-" placeholder with `.1f` raised ValueError.
                # For numeric values, f"{x:.1f}" right-aligned to 5 is identical
                # to f"{x:>5.1f}", so present scores render exactly as before.
                raw_str = "-" if raw is None else f"{raw}"
                adj_str = "-" if adj is None else f"{adj:.1f}"
                lines.append(f" {name:<22s} {raw_str:>4} {adj_str:>5}")
            lines.append(f" {'TOTAL':<22s} {p.total_raw:>4} {p.total_adjusted:>5.1f}")
            lines.append(f" Fitzpatrick correction applied for type: {p.fitzpatrick_type}")

        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return a plain-dict view of the scalar results (no numpy arrays).

        The nested dicts are shallow copies, so mutating the returned dict does
        not alter this result object.
        """
        d = {
            "image_size": self.image_size,
            "device": self.device,
            "class_distribution": dict(self.class_distribution),
        }
        if self.fitzpatrick:
            # NOTE(review): asdict() requires FitzpatrickResult to be a dataclass.
            d["fitzpatrick"] = asdict(self.fitzpatrick)
        if self.pwat:
            d["pwat"] = {
                "scores_raw": dict(self.pwat.scores_raw),
                "scores_adjusted": dict(self.pwat.scores_adjusted),
                "total_raw": self.pwat.total_raw,
                "total_adjusted": self.pwat.total_adjusted,
            }
        return d
class WoundNetB7Pipeline:
    """End-to-end DFU analysis pipeline.

    Args:
        models_dir: Path to models/ directory containing:
            - segmentation/WoundNetB7_proposed_best.pt
            - pwat/xgb_pwat{3-8}.pkl
        device: "cuda" or "cpu" (auto-detected if None)
        use_tta: Use test-time augmentation for segmentation (slower but more accurate)
    """

    # Class id in the segmentation classmap whose pixels form the ulcer mask.
    ULCER_CLASS_ID = 3

    def __init__(self, models_dir: str = "models", device: Optional[str] = None, use_tta: bool = True):
        self.models_dir = Path(models_dir)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.use_tta = use_tta

        # Load segmentation model
        seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
        self.seg_model = load_segmentation_model(str(seg_path), self.device)
        print(f"Segmentation model loaded ({sum(p.numel() for p in self.seg_model.parameters()) / 1e6:.1f}M params)")

        # Load PWAT predictor
        pwat_path = self.models_dir / "pwat"
        self.pwat_predictor = PWATPredictor(str(pwat_path))
        print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")
        print(f"Pipeline ready on {self.device}")

    @staticmethod
    def _to_bgr(image_input) -> np.ndarray:
        """Normalize a file path or numpy array into a 3-channel BGR image.

        Raises:
            FileNotFoundError: the path could not be read by OpenCV.
            TypeError: the input is neither a path nor a numpy array.
            ValueError: the array shape is not grayscale, 3- or 4-channel.
        """
        if isinstance(image_input, (str, Path)):
            img_bgr = cv2.imread(str(image_input))  # OpenCV loads as BGR
            if img_bgr is None:
                raise FileNotFoundError(f"Cannot read image: {image_input}")
            return img_bgr
        if not isinstance(image_input, np.ndarray):
            raise TypeError(f"Unsupported input type: {type(image_input)}")

        # Grayscale input (2-D or single channel) used to crash on shape[2].
        if image_input.ndim == 2 or (image_input.ndim == 3 and image_input.shape[2] == 1):
            return cv2.cvtColor(image_input, cv2.COLOR_GRAY2BGR)
        if image_input.ndim != 3:
            raise ValueError(f"Unsupported image shape: {image_input.shape}")

        channels = image_input.shape[2]
        if channels == 3:
            # 3-channel arrays cannot be told apart; they are taken as BGR.
            return image_input
        if channels == 4:
            # 4-channel arrays are treated as RGBA and reduced to BGR.
            return cv2.cvtColor(image_input, cv2.COLOR_RGBA2BGR)
        raise ValueError(f"Unsupported image shape: {image_input.shape}")

    def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
        """Analyze a DFU image end-to-end.

        Args:
            image_input: file path (str/Path) or numpy array. 3-channel arrays
                are assumed to be BGR (as returned by cv2.imread); 4-channel
                arrays are treated as RGBA; 2-D or single-channel arrays are
                treated as grayscale.
            use_tta: Override TTA setting for this call

        Returns:
            AnalysisResult with segmentation, Fitzpatrick, and PWAT data.

        Raises:
            FileNotFoundError: an image path could not be read.
            TypeError: unsupported input type.
            ValueError: unsupported array shape.
        """
        tta = self.use_tta if use_tta is None else use_tta
        img_bgr = self._to_bgr(image_input)
        h, w = img_bgr.shape[:2]

        # Step 1: Segmentation
        seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
        classmap = seg["classmap"]
        # Percentage of image pixels assigned to each class, 1 decimal place.
        class_dist = {
            name: round(float(np.mean(classmap == cid) * 100), 1)
            for cid, name in CLASS_NAMES.items()
        }

        # Step 2: Fitzpatrick estimation from the segmentation masks (healthy skin)
        fitz = estimate_fitzpatrick(img_bgr, seg["masks"])

        # Step 3: PWAT prediction on the ulcer region. The estimated Fitzpatrick
        # type is forwarded so the predictor can produce adjusted scores.
        ulcer_mask = (classmap == self.ULCER_CLASS_ID).astype(np.uint8) * 255
        pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask, fitzpatrick_type=fitz.fitzpatrick_type)

        return AnalysisResult(
            class_distribution=class_dist,
            classmap=classmap,
            probs=seg["probs"],
            ulcer_mask=ulcer_mask,
            fitzpatrick=fitz,
            pwat=pwat,
            image_size=(h, w),
            device=str(self.device),
        )

    def visualize_binary(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create binary ulcer segmentation overlay (ulcer only, red mask).

        Returns an RGB uint8 image: ulcer pixels are blended 60% toward pure
        red and outlined in white. Without an ulcer mask (or with an empty
        one) the plain RGB image is returned.
        """
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        overlay = img_rgb.copy()
        if result.ulcer_mask is None:
            return np.clip(overlay, 0, 255).astype(np.uint8)

        ulcer_bool = result.ulcer_mask > 127
        if not np.any(ulcer_bool):
            return np.clip(overlay, 0, 255).astype(np.uint8)

        red = np.array([255, 0, 0], dtype=np.float32)
        overlay[ulcer_bool] = overlay[ulcer_bool] * 0.4 + red * 0.6
        overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)
        contours, _ = cv2.findContours(
            ulcer_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(overlay_u8, contours, -1, (255, 255, 255), 2)
        return overlay_u8

    def visualize_multiclass(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create multiclass segmentation overlay using cached classmap.

        Each non-background class (id != 0) is blended 50/50 with its entry in
        CLASS_COLORS. Returns an RGB uint8 image.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        if result.classmap is not None:
            for cid, color in CLASS_COLORS.items():
                if cid == 0:
                    continue  # leave background untouched
                mask = result.classmap == cid
                overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
        # Clip before casting, consistent with visualize_binary.
        return np.clip(overlay, 0, 255).astype(np.uint8)

    def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create overlay visualization (backward compatible)."""
        return self.visualize_multiclass(img_bgr, result)

    def render_integrated_report(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Render a single-image integrated clinical dashboard (1920x1200 RGB).

        Combines segmentation, class distribution, Fitzpatrick/ITA estimation
        and PWAT scoring (raw + adjusted) into one nurse-facing report.
        """
        original_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        binary_overlay = self.visualize_binary(img_bgr, result)
        multi_overlay = self.visualize_multiclass(img_bgr, result)
        # The bare name below resolves to the module-level function imported
        # from src.integrated_report, not to this method (class scope is not
        # searched for unqualified names inside method bodies).
        return render_integrated_report(original_rgb, binary_overlay, multi_overlay, result)