"""WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA.

Usage:
    from pipeline import WoundNetB7Pipeline

    pipe = WoundNetB7Pipeline("models/")
    result = pipe.analyze("path/to/dfu_image.png")
    print(result.summary())
"""
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import torch

from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS
from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult
from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES
# NOTE: WoundNetB7Pipeline.render_integrated_report (below) calls this
# module-level function. Method names are not in scope inside method bodies,
# so the bare name resolves to this import, not to the method itself.
from src.integrated_report import render_integrated_report


@dataclass
class AnalysisResult:
    """Complete DFU analysis result from a single image."""

    # Segmentation
    class_distribution: dict = field(default_factory=dict)  # {class_name: percentage}
    classmap: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8
    probs: Optional[np.ndarray] = field(default=None, repr=False)  # (4,H,W) float32
    ulcer_mask: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8, 0/255
    # Fitzpatrick
    fitzpatrick: Optional[FitzpatrickResult] = None
    # PWAT
    pwat: Optional[PWATResult] = None
    # Metadata
    image_size: tuple = (0, 0)  # (height, width), as filled in by analyze()
    device: str = "cpu"

    def summary(self) -> str:
        """Return a human-readable multi-line text report of this result.

        Sections for Fitzpatrick and PWAT are only included when that data is
        present. PWAT items missing from the raw or adjusted score dicts are
        rendered as ``-`` rather than raising.
        """
        lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
        # image_size is (height, width); print it as WIDTHxHEIGHT.
        lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
        lines.append(f"Device: {self.device}")
        lines.append("\n--- Segmentation ---")
        for cls, pct in self.class_distribution.items():
            lines.append(f" {cls:<15s}: {pct:5.1f}%")

        if self.fitzpatrick:
            f = self.fitzpatrick
            lines.append("\n--- Fitzpatrick / ITA ---")
            lines.append(f" Type: {f.fitzpatrick_type} ({f.fitzpatrick_label})")
            lines.append(f" ITA: {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
            lines.append(f" L* mean: {f.l_skin_mean:.1f}")
            lines.append(f" Confidence: {f.confidence:.2f}")
            lines.append(f" Pixels: {f.healthy_pixels:,}")

        if self.pwat and self.pwat.scores_raw:
            p = self.pwat
            adjusted = p.scores_adjusted or {}
            lines.append("\n--- PWAT Scores (Items 3-8) ---")
            lines.append(f" {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
            lines.append(" " + "-" * 33)
            for item in [3, 4, 5, 6, 7, 8]:
                name = ITEM_NAMES.get(item, f"Item {item}")
                raw = p.scores_raw.get(item, "-")
                adj = adjusted.get(item)
                # A "-" placeholder cannot take the float spec ".1f" (it would
                # raise ValueError), so missing adjusted scores are formatted
                # as a right-aligned dash of the same width.
                adj_str = f"{adj:>5.1f}" if adj is not None else f"{'-':>5}"
                lines.append(f" {name:<22s} {raw:>4} {adj_str}")
            lines.append(f" {'TOTAL':<22s} {p.total_raw:>4} {p.total_adjusted:>5.1f}")
            lines.append(f" Fitzpatrick correction applied for type: {p.fitzpatrick_type}")
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return a plain-dict view of the result without the large arrays.

        ``classmap``, ``probs`` and ``ulcer_mask`` are intentionally omitted.
        The Fitzpatrick result is converted with ``dataclasses.asdict``.
        """
        d = {
            "image_size": self.image_size,
            "device": self.device,
            "class_distribution": self.class_distribution,
        }
        if self.fitzpatrick:
            d["fitzpatrick"] = asdict(self.fitzpatrick)
        if self.pwat:
            d["pwat"] = {
                "scores_raw": self.pwat.scores_raw,
                "scores_adjusted": self.pwat.scores_adjusted,
                "total_raw": self.pwat.total_raw,
                "total_adjusted": self.pwat.total_adjusted,
            }
        return d


class WoundNetB7Pipeline:
    """End-to-end DFU analysis pipeline.

    Args:
        models_dir: Path to models/ directory containing:
            - segmentation/WoundNetB7_proposed_best.pt
            - pwat/xgb_pwat{3-8}.pkl
        device: "cuda" or "cpu" (auto-detected if None)
        use_tta: Use test-time augmentation for segmentation (slower but more accurate)
    """

    # Class id in the segmentation classmap that analyze() treats as ulcer.
    ULCER_CLASS_ID = 3

    def __init__(self, models_dir: str = "models", device: Optional[str] = None,
                 use_tta: bool = True):
        self.models_dir = Path(models_dir)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.use_tta = use_tta

        # Load segmentation model
        seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
        self.seg_model = load_segmentation_model(str(seg_path), self.device)
        n_params = sum(p.numel() for p in self.seg_model.parameters())
        print(f"Segmentation model loaded ({n_params / 1e6:.1f}M params)")

        # Load PWAT predictor
        pwat_path = self.models_dir / "pwat"
        self.pwat_predictor = PWATPredictor(str(pwat_path))
        print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")

        print(f"Pipeline ready on {self.device}")

    @staticmethod
    def _ndarray_to_bgr(arr: np.ndarray) -> np.ndarray:
        """Normalise an in-memory image array to 3-channel BGR.

        - 2-D or single-channel arrays are treated as grayscale.
        - 3-channel arrays are assumed to already be BGR (OpenCV order) and are
          returned unchanged. An RGB array cannot be distinguished by shape and
          must be converted by the caller.
        - Any other layout is passed to ``cv2.COLOR_RGB2BGR``. For 4-channel
          input this drops the alpha channel and swaps R/B. Unsupported layouts
          are rejected by OpenCV.
        """
        if arr.ndim == 2 or (arr.ndim == 3 and arr.shape[2] == 1):
            return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
        if arr.ndim == 3 and arr.shape[2] == 3:
            return arr
        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

    def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
        """Analyze a DFU image end-to-end.

        Args:
            image_input: file path (str/Path) or numpy array. 3-channel arrays
                are assumed to be BGR; 2-D/single-channel arrays are treated as
                grayscale; 4-channel arrays are converted with RGB2BGR (alpha
                dropped).
            use_tta: Override TTA setting for this call

        Returns:
            AnalysisResult with segmentation, Fitzpatrick, and PWAT data.

        Raises:
            FileNotFoundError: if a path is given and OpenCV cannot read it.
            TypeError: if ``image_input`` is neither a path nor an ndarray.
        """
        tta = use_tta if use_tta is not None else self.use_tta

        # Load image
        if isinstance(image_input, (str, Path)):
            img_bgr = cv2.imread(str(image_input))
            if img_bgr is None:
                raise FileNotFoundError(f"Cannot read image: {image_input}")
        elif isinstance(image_input, np.ndarray):
            img_bgr = self._ndarray_to_bgr(image_input)
        else:
            raise TypeError(f"Unsupported input type: {type(image_input)}")

        h, w = img_bgr.shape[:2]

        # Step 1: Segmentation -> per-class pixel percentage.
        seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
        classmap = seg["classmap"]
        class_dist = {
            name: round(float(np.mean(classmap == cid) * 100), 1)
            for cid, name in CLASS_NAMES.items()
        }

        # Step 2: Fitzpatrick estimation (from healthy skin)
        fitz = estimate_fitzpatrick(img_bgr, seg["masks"])

        # Step 3: PWAT prediction from a 0/255 ulcer mask, corrected for the
        # estimated Fitzpatrick type.
        ulcer_mask = (classmap == self.ULCER_CLASS_ID).astype(np.uint8) * 255
        pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask,
                                           fitzpatrick_type=fitz.fitzpatrick_type)

        return AnalysisResult(
            class_distribution=class_dist,
            classmap=classmap,
            probs=seg["probs"],
            ulcer_mask=ulcer_mask,
            fitzpatrick=fitz,
            pwat=pwat,
            image_size=(h, w),
            device=str(self.device),
        )

    def visualize_binary(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create binary ulcer segmentation overlay (ulcer only, red mask).

        Ulcer pixels (mask > 127) are blended 40% image / 60% red and outlined
        with a 2-px white contour. Returns an RGB uint8 image.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)

        ulcer_bool = None
        if result.ulcer_mask is not None:
            ulcer_bool = result.ulcer_mask > 127
            if np.any(ulcer_bool):
                red = np.array([255, 0, 0], dtype=np.float32)
                overlay[ulcer_bool] = overlay[ulcer_bool] * 0.4 + red * 0.6

        overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)
        if ulcer_bool is not None:
            # With an empty mask no contours are found and nothing is drawn.
            contours, _ = cv2.findContours(
                ulcer_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            cv2.drawContours(overlay_u8, contours, -1, (255, 255, 255), 2)
        return overlay_u8

    def visualize_multiclass(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create multiclass segmentation overlay using cached classmap.

        Every class except id 0 is blended 50/50 with its CLASS_COLORS entry.
        Returns an RGB uint8 image.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        if result.classmap is not None:
            for cid, color in CLASS_COLORS.items():
                if cid == 0:  # presumably background — left untinted
                    continue
                mask = result.classmap == cid
                overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
        return np.clip(overlay, 0, 255).astype(np.uint8)

    def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create overlay visualization (backward compatible)."""
        return self.visualize_multiclass(img_bgr, result)

    def render_integrated_report(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Render a single-image integrated clinical dashboard (1920x1200 RGB).

        Combines segmentation, class distribution, Fitzpatrick/ITA estimation and
        PWAT scoring (raw + adjusted) into one nurse-facing report.
        """
        original_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        binary_overlay = self.visualize_binary(img_bgr, result)
        multi_overlay = self.visualize_multiclass(img_bgr, result)
        # Resolves to the module-level src.integrated_report function.
        return render_integrated_report(original_rgb, binary_overlay, multi_overlay, result)