| """WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA. |
| |
| Usage: |
| from pipeline import WoundNetB7Pipeline |
| pipe = WoundNetB7Pipeline("models/") |
| result = pipe.analyze("path/to/dfu_image.png") |
| print(result) |
| """ |
| import torch |
| import numpy as np |
| import cv2 |
| from pathlib import Path |
| from dataclasses import dataclass, field, asdict |
| from typing import Optional |
|
|
| from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS |
| from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult |
| from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES |
| from src.integrated_report import render_integrated_report |
|
|
|
|
@dataclass
class AnalysisResult:
    """Complete DFU analysis result from a single image.

    Instances are built by ``WoundNetB7Pipeline.analyze``. Array-valued
    fields are excluded from ``repr`` so printing a result stays readable.
    """

    # Class name -> percentage of pixels (as filled by WoundNetB7Pipeline.analyze,
    # rounded to one decimal).
    class_distribution: dict = field(default_factory=dict)
    # Per-pixel class ids returned by the segmentation step.
    classmap: Optional[np.ndarray] = field(default=None, repr=False)
    # Segmentation "probs" output; layout defined by src.segmentation (not checked here).
    probs: Optional[np.ndarray] = field(default=None, repr=False)
    # Binary ulcer mask as produced by the pipeline: uint8 with values 0 / 255.
    ulcer_mask: Optional[np.ndarray] = field(default=None, repr=False)

    fitzpatrick: Optional[FitzpatrickResult] = None

    pwat: Optional[PWATResult] = None

    # (height, width) of the analyzed image.
    image_size: tuple = (0, 0)
    # String form of the torch device the analysis ran on.
    device: str = "cpu"

    @staticmethod
    def _fmt_score(value, width: int, spec: str = "") -> str:
        """Right-align ``value`` in ``width`` columns using format ``spec``.

        Falls back to a right-aligned ``-`` when the value is missing (``None``)
        or cannot be rendered with ``spec`` (e.g. a placeholder string under
        ``.1f``), so one absent PWAT item never breaks the whole summary.
        """
        try:
            return format(value, f">{width}{spec}")
        except (TypeError, ValueError):
            return format("-", f">{width}")

    def summary(self) -> str:
        """Return a human-readable, multi-line text report of this result.

        Sections for Fitzpatrick and PWAT are only emitted when the
        corresponding results are present.
        """
        lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
        # image_size is (height, width); shown as WIDTHxHEIGHT.
        lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
        lines.append(f"Device: {self.device}")

        lines.append("\n--- Segmentation ---")
        for cls, pct in self.class_distribution.items():
            lines.append(f" {cls:<15s}: {pct:5.1f}%")

        if self.fitzpatrick is not None:
            f = self.fitzpatrick
            lines.append("\n--- Fitzpatrick / ITA ---")
            lines.append(f" Type: {f.fitzpatrick_type} ({f.fitzpatrick_label})")
            lines.append(f" ITA: {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
            lines.append(f" L* mean: {f.l_skin_mean:.1f}")
            lines.append(f" Confidence: {f.confidence:.2f}")
            lines.append(f" Pixels: {f.healthy_pixels:,}")

        if self.pwat is not None and self.pwat.scores_raw:
            p = self.pwat
            adjusted = p.scores_adjusted or {}
            lines.append("\n--- PWAT Scores (Items 3-8) ---")
            lines.append(f" {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
            lines.append(" " + "-" * 33)
            for item in (3, 4, 5, 6, 7, 8):
                name = ITEM_NAMES.get(item, f"Item {item}")
                # Items may be absent (e.g. a model not loaded); render "-" instead of raising.
                raw = self._fmt_score(p.scores_raw.get(item), 4)
                adj = self._fmt_score(adjusted.get(item), 5, ".1f")
                lines.append(f" {name:<22s} {raw} {adj}")
            total_raw = self._fmt_score(p.total_raw, 4)
            total_adj = self._fmt_score(p.total_adjusted, 5, ".1f")
            lines.append(f" {'TOTAL':<22s} {total_raw} {total_adj}")
            lines.append(f" Fitzpatrick correction applied for type: {p.fitzpatrick_type}")

        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return the scalar parts of this result as a plain dict.

        Array fields (``classmap``, ``probs``, ``ulcer_mask``) are omitted.
        ``class_distribution`` is copied so callers can mutate the returned
        dict without changing this result.
        """
        d = {
            "image_size": self.image_size,
            "device": self.device,
            "class_distribution": dict(self.class_distribution),
        }
        if self.fitzpatrick is not None:
            d["fitzpatrick"] = asdict(self.fitzpatrick)
        if self.pwat is not None:
            d["pwat"] = {
                "scores_raw": self.pwat.scores_raw,
                "scores_adjusted": self.pwat.scores_adjusted,
                "total_raw": self.pwat.total_raw,
                "total_adjusted": self.pwat.total_adjusted,
            }
        return d
|
|
|
|
class WoundNetB7Pipeline:
    """End-to-end DFU analysis pipeline.

    Loads the segmentation model and the PWAT predictors once at
    construction, then ``analyze`` runs segmentation -> Fitzpatrick/ITA
    estimation -> PWAT scoring for each image.

    Args:
        models_dir: Path to models/ directory containing:
            - segmentation/WoundNetB7_proposed_best.pt
            - pwat/xgb_pwat{3-8}.pkl
        device: "cuda" or "cpu" (auto-detected if None)
        use_tta: Use test-time augmentation for segmentation (slower but more accurate)
    """

    # Class id in the segmentation classmap used to build the binary ulcer
    # mask for PWAT scoring. NOTE(review): must stay in sync with CLASS_NAMES
    # in src.segmentation — confirm id 3 is the ulcer class.
    _ULCER_CLASS_ID = 3

    def __init__(self, models_dir: str = "models", device: Optional[str] = None, use_tta: bool = True):
        self.models_dir = Path(models_dir)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.use_tta = use_tta

        seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
        self.seg_model = load_segmentation_model(str(seg_path), self.device)
        n_params_m = sum(p.numel() for p in self.seg_model.parameters()) / 1e6
        print(f"Segmentation model loaded ({n_params_m:.1f}M params)")

        pwat_path = self.models_dir / "pwat"
        self.pwat_predictor = PWATPredictor(str(pwat_path))
        print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")

        print(f"Pipeline ready on {self.device}")

    @staticmethod
    def _load_bgr(image_input) -> np.ndarray:
        """Normalise ``image_input`` to a 3-channel BGR numpy array.

        Raises:
            FileNotFoundError: a path was given but OpenCV could not read it.
            TypeError: the input is neither a path nor a numpy array.
            ValueError: the array shape is not a supported image layout.
        """
        if isinstance(image_input, (str, Path)):
            # imread's default flag yields a 3-channel BGR image.
            img_bgr = cv2.imread(str(image_input))
            if img_bgr is None:
                raise FileNotFoundError(f"Cannot read image: {image_input}")
            return img_bgr
        if not isinstance(image_input, np.ndarray):
            raise TypeError(f"Unsupported input type: {type(image_input)}")

        if image_input.ndim == 2 or (image_input.ndim == 3 and image_input.shape[2] == 1):
            # Grayscale: replicate into three channels.
            return cv2.cvtColor(image_input, cv2.COLOR_GRAY2BGR)
        if image_input.ndim != 3 or image_input.shape[2] not in (3, 4):
            raise ValueError(f"Unsupported image shape: {image_input.shape}")
        if image_input.shape[2] == 3:
            # 3-channel arrays cannot be told apart by content, so they are
            # used as-is, i.e. assumed to already be BGR (OpenCV order).
            return image_input
        # 4-channel arrays are treated as RGBA: channels swapped, alpha dropped.
        return cv2.cvtColor(image_input, cv2.COLOR_RGBA2BGR)

    def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
        """Analyze a DFU image end-to-end.

        Args:
            image_input: file path (str/Path) or numpy array. 3-channel arrays
                are assumed to be BGR; convert RGB arrays first. 4-channel
                arrays are treated as RGBA, and grayscale (H, W) or (H, W, 1)
                arrays are expanded to BGR.
            use_tta: Override TTA setting for this call

        Returns:
            AnalysisResult with segmentation, Fitzpatrick, and PWAT data.

        Raises:
            FileNotFoundError: a path was given but could not be read.
            TypeError: unsupported input type.
            ValueError: unsupported array shape.
        """
        tta = self.use_tta if use_tta is None else use_tta
        img_bgr = self._load_bgr(image_input)
        h, w = img_bgr.shape[:2]

        seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
        classmap = seg["classmap"]

        # Percentage of pixels assigned to each class, rounded to 0.1.
        class_dist = {
            name: round(float(np.mean(classmap == cid) * 100), 1)
            for cid, name in CLASS_NAMES.items()
        }

        fitz = estimate_fitzpatrick(img_bgr, seg["masks"])

        # PWAT scoring works on a 0/255 uint8 ulcer mask and is corrected
        # for the estimated Fitzpatrick type.
        ulcer_mask = (classmap == self._ULCER_CLASS_ID).astype(np.uint8) * 255
        pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask, fitzpatrick_type=fitz.fitzpatrick_type)

        return AnalysisResult(
            class_distribution=class_dist,
            classmap=classmap,
            probs=seg["probs"],
            ulcer_mask=ulcer_mask,
            fitzpatrick=fitz,
            pwat=pwat,
            image_size=(h, w),
            device=str(self.device),
        )

    def visualize_binary(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create binary ulcer segmentation overlay (ulcer only, red mask).

        Returns an RGB uint8 image: ulcer pixels blended 40/60 with red and
        outlined in white. The plain image is returned when there is no mask
        or no ulcer pixels.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        if result.ulcer_mask is None:
            return np.clip(overlay, 0, 255).astype(np.uint8)

        ulcer_bool = result.ulcer_mask > 127
        if not np.any(ulcer_bool):
            return np.clip(overlay, 0, 255).astype(np.uint8)

        red = np.array([255, 0, 0], dtype=np.float32)
        overlay[ulcer_bool] = overlay[ulcer_bool] * 0.4 + red * 0.6
        overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)
        # [-2] selects the contours on both OpenCV 3.x (3 return values)
        # and 4.x (2 return values).
        contours = cv2.findContours(
            ulcer_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        cv2.drawContours(overlay_u8, contours, -1, (255, 255, 255), 2)
        return overlay_u8

    def visualize_multiclass(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create multiclass segmentation overlay using cached classmap.

        Every non-background class (id != 0) is blended 50/50 with its
        CLASS_COLORS entry. Returns an RGB uint8 image.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        if result.classmap is not None:
            for cid, color in CLASS_COLORS.items():
                if cid == 0:
                    continue
                mask = result.classmap == cid
                overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
        return np.clip(overlay, 0, 255).astype(np.uint8)

    def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create overlay visualization (backward compatible alias of visualize_multiclass)."""
        return self.visualize_multiclass(img_bgr, result)

    def render_integrated_report(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Render a single-image integrated clinical dashboard (1920x1200 RGB).

        Combines segmentation, class distribution, Fitzpatrick/ITA estimation
        and PWAT scoring (raw + adjusted) into one nurse-facing report.
        """
        original_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        binary_overlay = self.visualize_binary(img_bgr, result)
        multi_overlay = self.visualize_multiclass(img_bgr, result)
        # Method names are not in scope inside method bodies, so this resolves
        # to the module-level function imported from src.integrated_report,
        # not to this method.
        return render_integrated_report(original_rgb, binary_overlay, multi_overlay, result)
|
|