"""WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA.
Usage:
from pipeline import WoundNetB7Pipeline
pipe = WoundNetB7Pipeline("models/")
result = pipe.analyze("path/to/dfu_image.png")
print(result)
"""
import torch
import numpy as np
import cv2
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional
from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS
from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult
from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES
from src.integrated_report import render_integrated_report
@dataclass
class AnalysisResult:
    """Complete DFU analysis result from a single image.

    Aggregates the three pipeline stages: segmentation class map,
    Fitzpatrick/ITA skin-tone estimation, and PWAT wound scoring.
    Heavy arrays are excluded from ``repr`` and from ``to_dict``.
    """

    # --- Segmentation ---
    class_distribution: dict = field(default_factory=dict)  # {class_name: percentage}
    classmap: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8
    probs: Optional[np.ndarray] = field(default=None, repr=False)  # (4,H,W) float32
    ulcer_mask: Optional[np.ndarray] = field(default=None, repr=False)  # (H,W) uint8
    # --- Fitzpatrick (None when estimation was not performed) ---
    fitzpatrick: Optional["FitzpatrickResult"] = None
    # --- PWAT (None when prediction was not performed) ---
    pwat: Optional["PWATResult"] = None
    # --- Metadata ---
    image_size: tuple = (0, 0)  # (height, width)
    device: str = "cpu"

    def summary(self) -> str:
        """Return a human-readable multi-line report of all available results.

        Sections for Fitzpatrick and PWAT are included only when the
        corresponding result objects are present.
        """
        lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
        lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
        lines.append(f"Device: {self.device}")
        lines.append("\n--- Segmentation ---")
        for cls, pct in self.class_distribution.items():
            lines.append(f" {cls:<15s}: {pct:5.1f}%")
        if self.fitzpatrick:
            f = self.fitzpatrick
            lines.append("\n--- Fitzpatrick / ITA ---")
            lines.append(f" Type: {f.fitzpatrick_type} ({f.fitzpatrick_label})")
            lines.append(f" ITA: {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
            lines.append(f" L* mean: {f.l_skin_mean:.1f}")
            lines.append(f" Confidence: {f.confidence:.2f}")
            lines.append(f" Pixels: {f.healthy_pixels:,}")
        if self.pwat and self.pwat.scores_raw:
            p = self.pwat
            lines.append("\n--- PWAT Scores (Items 3-8) ---")
            lines.append(f" {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
            lines.append(" " + "-" * 33)
            for item in [3, 4, 5, 6, 7, 8]:
                name = ITEM_NAMES.get(item, f"Item {item}")
                raw = p.scores_raw.get(item, "-")
                adj = p.scores_adjusted.get(item, "-")
                # BUG FIX: the "-" placeholder is a str, and the previous
                # "{adj:>5.1f}" raised ValueError ("Unknown format code 'f'")
                # whenever an item score was missing. Only apply the float
                # format spec to numeric values.
                adj_str = f"{adj:5.1f}" if isinstance(adj, (int, float)) else f"{adj:>5}"
                lines.append(f" {name:<22s} {raw:>4} {adj_str}")
            lines.append(f" {'TOTAL':<22s} {p.total_raw:>4} {p.total_adjusted:>5.1f}")
            lines.append(f" Fitzpatrick correction applied for type: {p.fitzpatrick_type}")
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of the result (arrays omitted)."""
        d = {"image_size": self.image_size, "device": self.device, "class_distribution": self.class_distribution}
        if self.fitzpatrick:
            # FitzpatrickResult is a dataclass, so asdict flattens it cleanly.
            d["fitzpatrick"] = asdict(self.fitzpatrick)
        if self.pwat:
            d["pwat"] = {
                "scores_raw": self.pwat.scores_raw,
                "scores_adjusted": self.pwat.scores_adjusted,
                "total_raw": self.pwat.total_raw,
                "total_adjusted": self.pwat.total_adjusted,
            }
        return d
class WoundNetB7Pipeline:
    """End-to-end DFU analysis pipeline.

    Args:
        models_dir: Path to models/ directory containing:
            - segmentation/WoundNetB7_proposed_best.pt
            - pwat/xgb_pwat{3-8}.pkl
        device: "cuda" or "cpu" (auto-detected if None)
        use_tta: Use test-time augmentation for segmentation (slower but more accurate)
    """

    def __init__(self, models_dir: str = "models", device: Optional[str] = None, use_tta: bool = True):
        self.models_dir = Path(models_dir)
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.use_tta = use_tta
        # Load segmentation model
        seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
        self.seg_model = load_segmentation_model(str(seg_path), self.device)
        print(f"Segmentation model loaded ({sum(p.numel() for p in self.seg_model.parameters()) / 1e6:.1f}M params)")
        # Load per-item PWAT regressors
        pwat_path = self.models_dir / "pwat"
        self.pwat_predictor = PWATPredictor(str(pwat_path))
        print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")
        print(f"Pipeline ready on {self.device}")

    def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
        """Analyze a DFU image end-to-end.

        Args:
            image_input: file path (str/Path), or an HxWx3 BGR numpy array
                (matching cv2.imread output)
            use_tta: Override TTA setting for this call

        Returns:
            AnalysisResult with segmentation, Fitzpatrick, and PWAT data.

        Raises:
            FileNotFoundError: if a path is given but the image cannot be read.
            TypeError: if image_input is neither a path nor a numpy array.
            ValueError: if a numpy array is not HxWx3.
        """
        tta = use_tta if use_tta is not None else self.use_tta
        # Load image
        if isinstance(image_input, (str, Path)):
            img_bgr = cv2.imread(str(image_input))
            if img_bgr is None:
                raise FileNotFoundError(f"Cannot read image: {image_input}")
        elif isinstance(image_input, np.ndarray):
            # BUG FIX: the old check indexed shape[2] directly (IndexError on
            # 2-D grayscale arrays) and applied cv2.COLOR_RGB2BGR to any
            # non-3-channel array, which is wrong for RGBA and itself fails
            # in cv2. Validate the layout explicitly; 3-channel arrays are
            # passed through unchanged (assumed BGR), exactly as before.
            if image_input.ndim != 3 or image_input.shape[2] != 3:
                raise ValueError(f"Expected HxWx3 image array, got shape {image_input.shape}")
            img_bgr = image_input
        else:
            raise TypeError(f"Unsupported input type: {type(image_input)}")
        h, w = img_bgr.shape[:2]
        # Step 1: Segmentation -> per-pixel class map + probabilities
        seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
        classmap = seg["classmap"]
        # Percentage of pixels per class, rounded to one decimal
        class_dist = {}
        for cid, name in CLASS_NAMES.items():
            class_dist[name] = round(float(np.mean(classmap == cid) * 100), 1)
        # Step 2: Fitzpatrick estimation (from healthy skin)
        fitz = estimate_fitzpatrick(img_bgr, seg["masks"])
        # Step 3: PWAT prediction (from ulcer mask; class id 3 == ulcer)
        ulcer_mask = (classmap == 3).astype(np.uint8) * 255
        pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask, fitzpatrick_type=fitz.fitzpatrick_type)
        return AnalysisResult(
            class_distribution=class_dist,
            classmap=classmap,
            probs=seg["probs"],
            ulcer_mask=ulcer_mask,
            fitzpatrick=fitz,
            pwat=pwat,
            image_size=(h, w),
            device=str(self.device),
        )

    def visualize_binary(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create binary ulcer segmentation overlay (ulcer only, red mask).

        Returns an RGB uint8 image; ulcer pixels are tinted red and outlined
        with a white contour. If no ulcer is present the image is returned
        untinted.
        """
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        overlay = img_rgb.copy()
        if result.ulcer_mask is not None:
            ulcer_bool = result.ulcer_mask > 127
            if np.any(ulcer_bool):
                # 40% original + 60% pure red on ulcer pixels
                overlay[ulcer_bool] = overlay[ulcer_bool] * 0.4 + np.array([255, 0, 0], dtype=np.float32) * 0.6
                overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)
                contours, _ = cv2.findContours(
                    ulcer_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )
                cv2.drawContours(overlay_u8, contours, -1, (255, 255, 255), 2)
                return overlay_u8
        return np.clip(overlay, 0, 255).astype(np.uint8)

    def visualize_multiclass(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create multiclass segmentation overlay using the cached classmap.

        Each non-background class is alpha-blended (50/50) with its color
        from CLASS_COLORS. Returns an RGB uint8 image.
        """
        overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        if result.classmap is not None:
            for cid, color in CLASS_COLORS.items():
                if cid == 0:
                    continue  # skip background
                mask = result.classmap == cid
                overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
        return overlay.astype(np.uint8)

    def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Create overlay visualization (backward compatible alias)."""
        return self.visualize_multiclass(img_bgr, result)

    def render_integrated_report(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
        """Render a single-image integrated clinical dashboard (1920x1200 RGB).

        Combines segmentation, class distribution, Fitzpatrick/ITA estimation
        and PWAT scoring (raw + adjusted) into one nurse-facing report.
        """
        original_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        binary_overlay = self.visualize_binary(img_bgr, result)
        multi_overlay = self.visualize_multiclass(img_bgr, result)
        return render_integrated_report(original_rgb, binary_overlay, multi_overlay, result)