"""PWAT (Photographic Wound Assessment Tool) estimation — Items 3-8.

Uses segmentation masks to extract color, tissue, morphological, and texture
features, then predicts ordinal PWAT scores (0-4) via XGBoost classifiers.
Includes Fitzpatrick-aware debiasing correction factors.
"""

import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import cv2
import joblib
import numpy as np

# PWAT item numbers predicted by this module.
ITEMS = [3, 4, 5, 6, 7, 8]

ITEM_NAMES = {
    3: "Necrotic Type",
    4: "Necrotic Amount",
    5: "Granulation Type",
    6: "Granulation Amount",
    7: "Edges",
    8: "Periulcer Skin",
}

# Additive per-item offsets keyed by Fitzpatrick skin type (Roman numerals
# I-VI). They are added to the raw predicted score before clipping to [0, 4].
# All offsets are <= 0 and never increase from one skin type to the next;
# item 8 carries the largest adjustment.
CORRECTION_FACTORS = {
    "I": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
    "II": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
    "III": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: -0.1},
    "IV": {3: -0.1, 4: -0.1, 5: 0.0, 6: 0.0, 7: 0.0, 8: -0.3},
    "V": {3: -0.2, 4: -0.2, 5: -0.1, 6: 0.0, 7: 0.0, 8: -0.6},
    "VI": {3: -0.3, 4: -0.3, 5: -0.2, 6: -0.1, 7: 0.0, 8: -0.9},
}


@dataclass
class PWATResult:
    """Per-image PWAT prediction output."""

    # item number -> integer class returned by that item's model (0 if no model).
    scores_raw: dict = field(default_factory=dict)
    # item number -> raw score plus Fitzpatrick offset, clipped to [0, 4].
    scores_adjusted: dict = field(default_factory=dict)
    # Sum of scores_raw.
    total_raw: int = 0
    # Sum of scores_adjusted, rounded to one decimal place.
    total_adjusted: float = 0.0
    # Fitzpatrick type exactly as passed to PWATPredictor.predict.
    fitzpatrick_type: str = ""
    # Feature dict from extract_features (empty when the wound was too small).
    features: dict = field(default_factory=dict)


def _binarize_mask(ulcer_mask: np.ndarray) -> np.ndarray:
    """Return a boolean wound mask from a bool, 0/255, or 0/1-valued mask."""
    if ulcer_mask.dtype == bool:
        return ulcer_mask.astype(bool)
    # A mask whose maximum is <= 1 (0/1 integers or [0, 1] probabilities) would
    # never pass the 0/255 threshold below and used to yield no features at all.
    if ulcer_mask.size and ulcer_mask.max() <= 1:
        return ulcer_mask > 0.5
    return ulcer_mask > 127


def _color_features(wound: np.ndarray, rgb: np.ndarray, hsv: np.ndarray,
                    lab: np.ndarray) -> dict:
    """Return mean/std/median/p25/p75 of every RGB, HSV and Lab channel in the wound."""
    out = {}
    spaces = (
        ("rgb", rgb, ("R", "G", "B")),
        ("hsv", hsv, ("H", "S", "V")),
        ("lab", lab, ("L", "a", "b")),
    )
    for space, arr, names in spaces:
        for ci, cn in enumerate(names):
            vals = arr[wound, ci]
            out[f"{space}_{cn}_mean"] = float(np.mean(vals))
            out[f"{space}_{cn}_std"] = float(np.std(vals))
            out[f"{space}_{cn}_median"] = float(np.median(vals))
            out[f"{space}_{cn}_p25"] = float(np.percentile(vals, 25))
            out[f"{space}_{cn}_p75"] = float(np.percentile(vals, 75))
    return out


def _tissue_features(wound: np.ndarray, hsv: np.ndarray, lab: np.ndarray,
                     npx: int) -> dict:
    """Return the percentage of wound pixels classified into each tissue class by color rules."""
    h, s, v = hsv[wound, 0], hsv[wound, 1], hsv[wound, 2]
    # The rescaling below assumes an 8-bit input image: OpenCV then stores
    # Lab L as L*255/100 and offsets a by +128.
    l_ch = lab[wound, 0] * (100 / 255)
    a_ch = lab[wound, 1] - 128

    eschar = ((v < 100) & (s < 60)) | (v < 60)
    slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
    gran = (((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_ch > 5)) & ~eschar
    # NOTE(review): `necro` does not exclude `slough`, so one pixel can count in
    # both and tissue_necro_total can exceed 100. Kept unchanged because the
    # trained models presumably expect this exact definition.
    necro = (s < 45) & (v >= 60) & (v < 160) & (l_ch < 55) & ~eschar & ~gran

    out = {
        "tissue_gran_pct": float(np.sum(gran) / npx * 100),
        "tissue_eschar_pct": float(np.sum(eschar) / npx * 100),
        "tissue_slough_pct": float(np.sum(slough) / npx * 100),
        "tissue_necro_pct": float(np.sum(necro) / npx * 100),
    }
    out["tissue_necro_total"] = (
        out["tissue_eschar_pct"] + out["tissue_slough_pct"] + out["tissue_necro_pct"]
    )
    return out


def _morph_features(mask_u8: np.ndarray) -> dict:
    """Return shape descriptors of the largest external contour (empty dict if none)."""
    # [-2] picks the contour list under both OpenCV 3.x (3 return values)
    # and OpenCV 4.x (2 return values).
    contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return {}
    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    perim = cv2.arcLength(cnt, True)
    circ = 4 * np.pi * area / (perim ** 2) if perim > 0 else 0
    _, _, w, h = cv2.boundingRect(cnt)
    hull = cv2.convexHull(cnt)
    return {
        "morph_area": float(area),
        "morph_perimeter": float(perim),
        "morph_circularity": float(circ),
        "morph_irregularity": float(1 - circ),
        "morph_aspect_ratio": float(w / (h + 1e-8)),
        "morph_extent": float(area / (w * h + 1e-8)),
        "morph_solidity": float(area / (cv2.contourArea(hull) + 1e-8)),
    }


def _texture_features(gray: np.ndarray, wound: np.ndarray) -> dict:
    """Return grayscale intensity statistics and histogram entropy inside the wound."""
    wound_gray = gray[wound]
    # NOTE(review): density=True yields per-unit densities rather than bin
    # probabilities, so this is not a true Shannon entropy (it depends on the
    # wound's gray-level range). Kept as-is for parity with the trained models.
    hist_vals = np.histogram(wound_gray, bins=64, density=True)[0]
    return {
        "texture_mean": float(np.mean(wound_gray)),
        "texture_std": float(np.std(wound_gray)),
        "texture_entropy": float(-np.sum(hist_vals * np.log2(hist_vals + 1e-10))),
    }


def _edge_features(gray: np.ndarray, mask_u8: np.ndarray) -> dict:
    """Return the mean absolute horizontal gradient in a band around the wound border."""
    kernel = np.ones((5, 5), np.uint8)
    dilated = cv2.dilate(mask_u8 * 255, kernel)
    eroded = cv2.erode(mask_u8 * 255, kernel)
    # The dilation contains the erosion, so this uint8 difference does not wrap.
    edge_zone = (dilated - eroded) > 127
    if not np.any(edge_zone):
        return {}
    # Only the x-direction Sobel is used (not the gradient magnitude).
    sobel_x = cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)
    return {"edge_gradient": float(np.mean(np.abs(sobel_x[edge_zone])))}


def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
    """Extract features from the wound region for PWAT prediction.

    Args:
        img_bgr: BGR image as accepted by ``cv2.cvtColor``. The color
            thresholds assume 8-bit channels.
        ulcer_mask: Wound mask with the same height/width as the image. It may
            be bool, 0/255-valued, or 0/1-valued (int or float); 0/1 masks are
            thresholded at 0.5.

    Returns:
        Dict of named float features, or ``None`` when the mask covers fewer
        than 50 pixels. ``morph_*`` and ``edge_gradient`` keys may be absent
        when no contour or border band is found.
    """
    wound = _binarize_mask(ulcer_mask)
    npx = int(np.sum(wound))
    if npx < 50:
        return None

    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    mask_u8 = wound.astype(np.uint8)

    feats = {}
    feats.update(_color_features(wound, rgb, hsv, lab))
    feats.update(_tissue_features(wound, hsv, lab, npx))
    feats.update(_morph_features(mask_u8))
    feats.update(_texture_features(gray, wound))
    feats.update(_edge_features(gray, mask_u8))
    feats["wound_npx"] = float(npx)
    feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
    return feats


class PWATPredictor:
    """Predicts PWAT items 3-8 using trained XGBoost models."""

    def __init__(self, models_dir: str):
        """Load the feature column list and per-item models from *models_dir*.

        Reads ``selected_features.json`` (ordered list of feature names) and
        ``xgb_pwat{item}.pkl`` for each item in ITEMS. Missing model files are
        skipped; :meth:`predict` then scores those items as 0.
        """
        self.models = {}
        models_path = Path(models_dir)

        # Ordered feature columns used at training time (30 after the
        # variance + correlation filter, per the original training notes).
        features_json = models_path / "selected_features.json"
        if features_json.exists():
            with open(features_json, encoding="utf-8") as f:
                self.selected_features = json.load(f)
            print(f"PWAT: Loaded {len(self.selected_features)} selected features from JSON")
        else:
            self.selected_features = None
            print("PWAT: WARNING — selected_features.json not found, using all features")

        for item in ITEMS:
            pkl = models_path / f"xgb_pwat{item}.pkl"
            if pkl.exists():
                # joblib.load unpickles arbitrary objects: only load trusted model files.
                self.models[item] = joblib.load(pkl)

    def predict(
        self,
        img_bgr: np.ndarray,
        ulcer_mask: np.ndarray,
        fitzpatrick_type: str = "III",
    ) -> PWATResult:
        """Predict PWAT scores for a single image.

        Args:
            img_bgr: BGR image (see :func:`extract_features`).
            ulcer_mask: Wound mask (see :func:`extract_features`).
            fitzpatrick_type: Skin type "I".."VI". Matching is case- and
                whitespace-insensitive; unknown values get no correction.

        Returns:
            A PWATResult. All scores are empty/zero when the wound is too
            small for feature extraction.
        """
        feats = extract_features(img_bgr, ulcer_mask)
        if feats is None:
            return PWATResult(fitzpatrick_type=fitzpatrick_type)

        # Column order must match training. Features absent for this image
        # (e.g. edge_gradient) are filled with 0.0.
        # NOTE(review): without selected_features.json the width of X depends
        # on which optional features were produced, which may not match the
        # models' expected input size.
        cols = self.selected_features if self.selected_features else sorted(feats)
        X = np.array([[feats.get(c, 0.0) for c in cols]])

        offsets = CORRECTION_FACTORS.get(str(fitzpatrick_type).strip().upper(), {})

        scores_raw = {}
        scores_adj = {}
        for item in ITEMS:
            model = self.models.get(item)
            if model is None:
                scores_raw[item] = 0
                scores_adj[item] = 0.0
                continue
            pred = int(model.predict(X)[0])
            scores_raw[item] = pred
            scores_adj[item] = float(np.clip(pred + offsets.get(item, 0.0), 0, 4))

        return PWATResult(
            scores_raw=scores_raw,
            scores_adjusted=scores_adj,
            total_raw=sum(scores_raw.values()),
            total_adjusted=round(sum(scores_adj.values()), 1),
            fitzpatrick_type=fitzpatrick_type,
            features=feats,
        )