# WoundNetB7-DFU-Analysis/src/pwat_estimator.py
# Last commit fee7e0c (mmarquezsa): "Fix: load 30 selected features from JSON instead of guessing from model"
"""PWAT (Photographic Wound Assessment Tool) estimation — Items 3-8.
Uses segmentation masks to extract color, tissue, morphological, and texture
features, then predicts ordinal PWAT scores (0-4) via XGBoost classifiers.
Includes Fitzpatrick-aware debiasing correction factors.
"""
import numpy as np
import cv2
import joblib
import json
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
# PWAT item numbers estimated by this module (items 3 through 8).
ITEMS = list(range(3, 9))

# Human-readable name of each estimated PWAT item, keyed by item number.
ITEM_NAMES = dict(
    zip(
        ITEMS,
        (
            "Necrotic Type",
            "Necrotic Amount",
            "Granulation Type",
            "Granulation Amount",
            "Edges",
            "Periulcer Skin",
        ),
    )
)

# Fitzpatrick-aware debiasing offsets: for each skin type, an offset per PWAT item
# that is added to the item's raw predicted score (then clipped to 0-4) by the
# predictor. Each row lists offsets in ITEMS order: items 3, 4, 5, 6, 7, 8.
CORRECTION_FACTORS = {
    fitz_type: dict(zip(ITEMS, offsets))
    for fitz_type, offsets in (
        ("I", (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)),
        ("II", (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)),
        ("III", (0.0, 0.0, 0.0, 0.0, 0.0, -0.1)),
        ("IV", (-0.1, -0.1, 0.0, 0.0, 0.0, -0.3)),
        ("V", (-0.2, -0.2, -0.1, 0.0, 0.0, -0.6)),
        ("VI", (-0.3, -0.3, -0.2, -0.1, 0.0, -0.9)),
    )
}
@dataclass
class PWATResult:
    """Outcome of a PWAT prediction for a single image.

    Every field has a default, so ``PWATResult(fitzpatrick_type=...)`` describes an
    empty result with no scores and no features.
    """

    # PWAT item number -> integer class predicted by that item's model.
    scores_raw: dict = field(default_factory=dict)
    # PWAT item number -> raw score plus the Fitzpatrick offset, clipped to [0, 4].
    scores_adjusted: dict = field(default_factory=dict)
    # Sum of the raw item scores.
    total_raw: int = field(default=0)
    # Sum of the adjusted item scores (rounded to one decimal by the predictor).
    total_adjusted: float = field(default=0.0)
    # Fitzpatrick skin type string the prediction was made for.
    fitzpatrick_type: str = field(default="")
    # Feature name -> value extracted from the wound region.
    features: dict = field(default_factory=dict)
def _binarize_mask(ulcer_mask: np.ndarray) -> np.ndarray:
    """Return a boolean wound mask from a bool, 0/1 (or 0.0-1.0), or 0/255 mask."""
    if ulcer_mask.dtype == bool:
        return ulcer_mask > 0
    # A mask whose maximum is <= 1 is on a unit scale; thresholding it at 127 would
    # silently discard every wound pixel, so threshold it at 0.5 instead.
    # 0/255 masks keep the > 127 rule.
    if ulcer_mask.size and float(ulcer_mask.max()) <= 1.0:
        return ulcer_mask > 0.5
    return ulcer_mask > 127


def _color_stats(rgb: np.ndarray, hsv: np.ndarray, lab: np.ndarray, b: np.ndarray) -> dict:
    """Mean, std, median and quartiles of each RGB/HSV/Lab channel inside the wound."""
    stats = {}
    spaces = (
        ("rgb", rgb, ("R", "G", "B")),
        ("hsv", hsv, ("H", "S", "V")),
        ("lab", lab, ("L", "a", "b")),
    )
    for space, arr, channels in spaces:
        for ci, cn in enumerate(channels):
            vals = arr[b, ci]
            prefix = f"{space}_{cn}"
            stats[f"{prefix}_mean"] = float(np.mean(vals))
            stats[f"{prefix}_std"] = float(np.std(vals))
            stats[f"{prefix}_median"] = float(np.median(vals))
            stats[f"{prefix}_p25"] = float(np.percentile(vals, 25))
            stats[f"{prefix}_p75"] = float(np.percentile(vals, 75))
    return stats


def _tissue_features(hsv: np.ndarray, lab: np.ndarray, b: np.ndarray, npx: int) -> dict:
    """Percent of wound pixels classified as granulation/eschar/slough/necrotic by colour rules."""
    h, s, v = hsv[b, 0], hsv[b, 1], hsv[b, 2]
    # Undo OpenCV's 8-bit Lab encoding: L is stored as L*255/100 and a is offset by 128.
    l_ch = lab[b, 0] * (100 / 255)
    a_ch = lab[b, 1] - 128
    eschar = ((v < 100) & (s < 60)) | (v < 60)
    slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
    gran = (((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_ch > 5)) & ~eschar
    necro = (s < 45) & (v >= 60) & (v < 160) & (l_ch < 55) & ~eschar & ~gran
    out = {
        "tissue_gran_pct": float(np.sum(gran) / npx * 100),
        "tissue_eschar_pct": float(np.sum(eschar) / npx * 100),
        "tissue_slough_pct": float(np.sum(slough) / npx * 100),
        "tissue_necro_pct": float(np.sum(necro) / npx * 100),
    }
    out["tissue_necro_total"] = (
        out["tissue_eschar_pct"] + out["tissue_slough_pct"] + out["tissue_necro_pct"]
    )
    return out


def _morph_features(mask_u8: np.ndarray) -> dict:
    """Shape descriptors of the largest external contour; empty dict if none is found."""
    # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
    # OpenCV 4 (2-tuple) findContours return conventions.
    cnts = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not cnts:
        return {}
    cnt = max(cnts, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    perim = cv2.arcLength(cnt, True)
    circ = 4 * np.pi * area / (perim ** 2) if perim > 0 else 0
    _, _, w2, h2 = cv2.boundingRect(cnt)
    hull = cv2.convexHull(cnt)
    return {
        "morph_area": float(area),
        "morph_perimeter": float(perim),
        "morph_circularity": float(circ),
        "morph_irregularity": float(1 - circ),
        "morph_aspect_ratio": float(w2 / (h2 + 1e-8)),
        "morph_extent": float(area / (w2 * h2 + 1e-8)),
        "morph_solidity": float(area / (cv2.contourArea(hull) + 1e-8)),
    }


def _texture_features(gray: np.ndarray, b: np.ndarray, mask_u8: np.ndarray) -> dict:
    """Grey-level statistics inside the wound plus mean gradient on its boundary band."""
    wound_gray = gray[b]
    out = {
        "texture_mean": float(np.mean(wound_gray)),
        "texture_std": float(np.std(wound_gray)),
    }
    # NOTE: density=True yields per-bin densities, not probabilities, so this is not
    # a strict Shannon entropy. Presumably it must stay this way to match the
    # features the models were trained on — confirm before changing.
    hist_vals = np.histogram(wound_gray, bins=64, density=True)[0]
    out["texture_entropy"] = float(-np.sum(hist_vals * np.log2(hist_vals + 1e-10)))
    # Boundary band = dilation minus erosion of the mask (5x5 kernel). Dilation is a
    # superset of erosion, so the uint8 subtraction cannot underflow.
    kernel = np.ones((5, 5), np.uint8)
    dilated = cv2.dilate(mask_u8 * 255, kernel)
    eroded = cv2.erode(mask_u8 * 255, kernel)
    edge_zone = (dilated - eroded) > 127
    if np.any(edge_zone):
        # Horizontal derivative only (dx=1, dy=0).
        sobel_x = cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)
        out["edge_gradient"] = float(np.mean(np.abs(sobel_x[edge_zone])))
    return out


def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
    """Extract features from the wound region for PWAT prediction.

    Args:
        img_bgr: Colour image in OpenCV BGR channel order. Assumed to be 8-bit: the
            tissue heuristics undo OpenCV's 8-bit Lab encoding.
        ulcer_mask: Wound mask indexable against the image's height/width. It may be
            bool, 0/1 (or 0.0-1.0), or 0/255.

    Returns:
        Dict mapping feature name to float, or None when the mask covers fewer than
        50 pixels. ``morph_*`` keys are absent if no contour is found, and
        ``edge_gradient`` is absent if the boundary band is empty.
    """
    b = _binarize_mask(ulcer_mask)
    npx = int(np.count_nonzero(b))
    if npx < 50:
        return None
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    mask_u8 = b.astype(np.uint8)

    feats = {}
    feats.update(_color_stats(rgb, hsv, lab, b))
    feats.update(_tissue_features(hsv, lab, b, npx))
    feats.update(_morph_features(mask_u8))
    feats.update(_texture_features(gray, b, mask_u8))
    feats["wound_npx"] = float(npx)
    feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
    return feats
class PWATPredictor:
    """Predicts PWAT items 3-8 using trained XGBoost models.

    One classifier per item is loaded from ``xgb_pwat{item}.pkl``. Each image is
    turned into a feature vector whose column order comes from
    ``selected_features.json``, and the predicted scores are adjusted with the
    Fitzpatrick correction table.
    """

    def __init__(self, models_dir: str):
        """Load the selected feature list and the per-item models.

        Args:
            models_dir: Directory that may contain ``selected_features.json`` (the
                ordered feature-name list used at training time) and
                ``xgb_pwat{item}.pkl`` for each PWAT item. Missing files are
                tolerated: a warning is printed, and the missing item scores 0 in
                ``predict``.
        """
        self.models = {}
        models_path = Path(models_dir)
        # Load the selected feature columns (30 features after variance+correlation filter)
        features_json = models_path / "selected_features.json"
        if features_json.exists():
            with open(features_json, encoding="utf-8") as f:
                self.selected_features = json.load(f)
            print(f"PWAT: Loaded {len(self.selected_features)} selected features from JSON")
        else:
            self.selected_features = None
            print("PWAT: WARNING — selected_features.json not found, using all features")
        for item in ITEMS:
            pkl = models_path / f"xgb_pwat{item}.pkl"
            if pkl.exists():
                # NOTE(review): joblib.load unpickles arbitrary objects — only point
                # models_dir at trusted model files.
                self.models[item] = joblib.load(pkl)
            else:
                print(f"PWAT: WARNING — {pkl.name} not found, item {item} will score 0")

    def predict(
        self,
        img_bgr: np.ndarray,
        ulcer_mask: np.ndarray,
        fitzpatrick_type: str = "III",
    ) -> PWATResult:
        """Predict PWAT scores for a single image.

        Args:
            img_bgr: BGR image passed to ``extract_features``.
            ulcer_mask: Wound mask passed to ``extract_features``.
            fitzpatrick_type: Fitzpatrick skin type as a Roman numeral ("I".."VI").
                Lookup ignores case and surrounding whitespace. Unknown values get
                no correction.

        Returns:
            A PWATResult. If the mask is too small for feature extraction, the
            result is empty apart from ``fitzpatrick_type``. Items without a loaded
            model score 0 (raw) and 0.0 (adjusted).
        """
        feats = extract_features(img_bgr, ulcer_mask)
        if feats is None:
            return PWATResult(fitzpatrick_type=fitzpatrick_type)
        # Use the exact feature columns from training (order matters). Without the
        # JSON list, fall back to this image's features in sorted order.
        cols = self.selected_features if self.selected_features else sorted(feats)
        # Features missing for this image (e.g. edge_gradient) are filled with 0.0.
        X = np.array([[feats.get(c, 0.0) for c in cols]])
        # Normalise only for the table lookup; the result keeps the caller's string.
        factors = CORRECTION_FACTORS.get(str(fitzpatrick_type).strip().upper(), {})
        scores_raw = {}
        scores_adj = {}
        for item in ITEMS:
            if item not in self.models:
                scores_raw[item] = 0
                scores_adj[item] = 0.0
                continue
            pred = int(self.models[item].predict(X)[0])
            scores_raw[item] = pred
            scores_adj[item] = float(np.clip(pred + factors.get(item, 0.0), 0, 4))
        return PWATResult(
            scores_raw=scores_raw,
            scores_adjusted=scores_adj,
            total_raw=sum(scores_raw.values()),
            total_adjusted=round(sum(scores_adj.values()), 1),
            fitzpatrick_type=fitzpatrick_type,
            features=feats,
        )