| """PWAT (Photographic Wound Assessment Tool) estimation — Items 3-8. |
| |
| Uses segmentation masks to extract color, tissue, morphological, and texture |
| features, then predicts ordinal PWAT scores (0-4) via XGBoost classifiers. |
| |
| Includes Fitzpatrick-aware debiasing correction factors. |
| """ |
| import numpy as np |
| import cv2 |
| import joblib |
| import json |
| from dataclasses import dataclass, field |
| from typing import Optional |
| from pathlib import Path |
|
|
# PWAT item numbers handled by this module (items 3-8, per the module docstring).
ITEMS = list(range(3, 9))

# Human-readable label for each handled item, keyed by item number.
ITEM_NAMES = dict(
    zip(
        ITEMS,
        (
            "Necrotic Type",
            "Necrotic Amount",
            "Granulation Type",
            "Granulation Amount",
            "Edges",
            "Periulcer Skin",
        ),
    )
)

# Per-Fitzpatrick-type additive offsets, keyed by item number. They are added
# to each item's predicted score, and the sum is clipped to [0, 4], in
# PWATPredictor.predict. Every offset is zero or negative, and item 8 gets
# the largest adjustments.
# NOTE(review): the provenance of these values is not visible here; document it.
# Tuple columns follow ITEMS order: items 3, 4, 5, 6, 7, 8.
CORRECTION_FACTORS = {
    fitz_type: dict(zip(ITEMS, offsets))
    for fitz_type, offsets in (
        ("I", (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)),
        ("II", (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)),
        ("III", (0.0, 0.0, 0.0, 0.0, 0.0, -0.1)),
        ("IV", (-0.1, -0.1, 0.0, 0.0, 0.0, -0.3)),
        ("V", (-0.2, -0.2, -0.1, 0.0, 0.0, -0.6)),
        ("VI", (-0.3, -0.3, -0.2, -0.1, 0.0, -0.9)),
    )
}
|
|
|
|
@dataclass
class PWATResult:
    """Outcome of a single PWATPredictor.predict call.

    When the wound region is too small for feature extraction, predict
    returns an instance with only ``fitzpatrick_type`` set. Every other
    field keeps its default.
    """

    # Item number -> integer score returned by that item's model.
    scores_raw: dict = field(default_factory=dict)
    # Item number -> raw score plus the Fitzpatrick offset, clipped to [0, 4].
    scores_adjusted: dict = field(default_factory=dict)
    # Sum of scores_raw.
    total_raw: int = field(default=0)
    # Sum of scores_adjusted, rounded to one decimal place.
    total_adjusted: float = field(default=0.0)
    # Fitzpatrick skin type string the caller passed to predict.
    fitzpatrick_type: str = field(default="")
    # Feature dict produced by extract_features (empty when extraction failed).
    features: dict = field(default_factory=dict)
|
|
|
|
def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
    """Extract hand-crafted features from the wound region for PWAT prediction.

    Args:
        img_bgr: Colour image in OpenCV BGR channel order. Assumed to be
            8-bit: the Lab rescaling and the colour thresholds follow
            OpenCV's 8-bit conventions (TODO confirm for other dtypes).
        ulcer_mask: Wound mask with the same height and width as
            ``img_bgr``. A boolean mask is used as-is. Any other dtype is
            binarised with ``> 127``, so a 0/1 mask selects nothing.

    Returns:
        A flat ``{feature_name: float}`` dict containing:

        * colour statistics: ``rgb_*``, ``hsv_*``, ``lab_*``;
        * tissue percentages: ``tissue_*``;
        * largest-contour morphology: ``morph_*``;
        * grey-level texture: ``texture_*``;
        * boundary gradient: ``edge_gradient``;
        * size: ``wound_npx``, ``wound_ratio``.

        ``morph_*`` and ``edge_gradient`` are omitted when they cannot be
        computed. Returns ``None`` when the mask covers fewer than 50
        pixels.
    """
    wound = ulcer_mask > 0 if ulcer_mask.dtype == bool else ulcer_mask > 127
    npx = int(np.sum(wound))
    if npx < 50:
        return None

    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # ``wound`` is always boolean at this point, so a plain cast gives the
    # 0/1 uint8 mask used by the contour and morphology routines.
    mask_u8 = wound.astype(np.uint8)

    # Insertion order matches the historical feature order.
    feats = {}
    feats.update(_color_stats(wound, rgb, hsv, lab))
    feats.update(_tissue_fractions(wound, npx, hsv, lab))
    feats.update(_morphology(mask_u8))
    feats.update(_texture(gray, wound))
    feats.update(_edge_features(gray, mask_u8))
    feats["wound_npx"] = float(npx)
    feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
    return feats


def _color_stats(wound: np.ndarray, rgb: np.ndarray, hsv: np.ndarray, lab: np.ndarray) -> dict:
    """Return mean/std/median/p25/p75 of the wound pixels per RGB, HSV and Lab channel."""
    stats = {}
    for space, arr, channels in (
        ("rgb", rgb, ("R", "G", "B")),
        ("hsv", hsv, ("H", "S", "V")),
        ("lab", lab, ("L", "a", "b")),
    ):
        for ci, channel in enumerate(channels):
            vals = arr[wound, ci]
            prefix = f"{space}_{channel}"
            stats[f"{prefix}_mean"] = float(np.mean(vals))
            stats[f"{prefix}_std"] = float(np.std(vals))
            stats[f"{prefix}_median"] = float(np.median(vals))
            # Kept as two scalar calls: an array-valued q can change the
            # result dtype for float32 input and shift the last bits.
            stats[f"{prefix}_p25"] = float(np.percentile(vals, 25))
            stats[f"{prefix}_p75"] = float(np.percentile(vals, 75))
    return stats


def _tissue_fractions(wound: np.ndarray, npx: int, hsv: np.ndarray, lab: np.ndarray) -> dict:
    """Return the percentage of wound pixels in each heuristic tissue class.

    The thresholds assume OpenCV 8-bit HSV and Lab encodings. Lab L is
    rescaled from 0-255 to 0-100, and the 128 offset is removed from a.
    """
    h, s, v = hsv[wound, 0], hsv[wound, 1], hsv[wound, 2]
    l_star = lab[wound, 0] * (100 / 255)
    a_star = lab[wound, 1] - 128

    # eschar takes precedence over every other class, and necro also
    # excludes gran.
    eschar = ((v < 100) & (s < 60)) | (v < 60)
    slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
    # Hue band wrapping around 0, combined with positive a*.
    gran = ((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_star > 5) & ~eschar
    necro = (s < 45) & (v >= 60) & (v < 160) & (l_star < 55) & ~eschar & ~gran

    def pct(selected: np.ndarray) -> float:
        return float(np.sum(selected) / npx * 100)

    eschar_pct, slough_pct, necro_pct = pct(eschar), pct(slough), pct(necro)
    return {
        "tissue_gran_pct": pct(gran),
        "tissue_eschar_pct": eschar_pct,
        "tissue_slough_pct": slough_pct,
        "tissue_necro_pct": necro_pct,
        # NOTE(review): slough and necro are not disjoint (e.g. 25 < S < 45),
        # so this total can double-count pixels. It is kept unchanged
        # because the trained models presumably rely on it.
        "tissue_necro_total": eschar_pct + slough_pct + necro_pct,
    }


def _morphology(mask_u8: np.ndarray) -> dict:
    """Return shape descriptors of the mask's largest external contour, or {} if it has none."""
    # Index ``[-2]`` selects the contour list under both the OpenCV 3.x
    # (image, contours, hierarchy) and 4.x (contours, hierarchy) return forms.
    contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return {}
    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    perim = cv2.arcLength(cnt, True)
    circ = 4 * np.pi * area / (perim ** 2) if perim > 0 else 0
    _, _, width, height = cv2.boundingRect(cnt)
    hull_area = cv2.contourArea(cv2.convexHull(cnt))
    return {
        "morph_area": float(area),
        "morph_perimeter": float(perim),
        "morph_circularity": float(circ),
        "morph_irregularity": float(1 - circ),
        "morph_aspect_ratio": float(width / (height + 1e-8)),
        "morph_extent": float(area / (width * height + 1e-8)),
        "morph_solidity": float(area / (hull_area + 1e-8)),
    }


def _texture(gray: np.ndarray, wound: np.ndarray) -> dict:
    """Return grey-level mean, std and histogram entropy over the wound pixels."""
    wound_gray = gray[wound]
    density = np.histogram(wound_gray, bins=64, density=True)[0]
    return {
        "texture_mean": float(np.mean(wound_gray)),
        "texture_std": float(np.std(wound_gray)),
        # NOTE(review): this is computed on density values, not on per-bin
        # probabilities, so it is not a normalised Shannon entropy. The
        # formula is left as-is because the trained models presumably
        # expect it; change it only together with retraining.
        "texture_entropy": float(-np.sum(density * np.log2(density + 1e-10))),
    }


def _edge_features(gray: np.ndarray, mask_u8: np.ndarray) -> dict:
    """Return the mean |Sobel-x| in a band around the wound boundary ({} if the band is empty)."""
    kernel = np.ones((5, 5), np.uint8)
    # The kernel contains its anchor, so dilate >= mask >= erode pixel-wise.
    # The uint8 difference therefore cannot wrap, and it is 255 exactly on
    # the boundary band.
    band = (cv2.dilate(mask_u8 * 255, kernel) - cv2.erode(mask_u8 * 255, kernel)) > 127
    if not np.any(band):
        return {}
    # NOTE(review): only the x-derivative is used, so horizontal boundaries
    # contribute little. Kept for compatibility with the trained models.
    grad_x = cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)
    return {"edge_gradient": float(np.mean(np.abs(grad_x[band])))}
|
|
|
|
class PWATPredictor:
    """Predicts PWAT items 3-8 using trained XGBoost models.

    ``models_dir`` is scanned for two kinds of file:

    * ``selected_features.json`` (optional): expected to be a JSON array of
      feature names, giving the columns, in order, of the matrix fed to
      every model.
    * ``xgb_pwat{item}.pkl`` for each item in ``ITEMS``: a joblib-pickled
      model exposing ``predict``.
    """

    def __init__(self, models_dir: str):
        """Load the feature list and the per-item models from ``models_dir``.

        Args:
            models_dir: Directory holding the model artefacts.

        Missing artefacts are tolerated, with a printed warning:

        * without the JSON, all extracted features are used, sorted by name;
        * an item without a model file is scored 0 by ``predict``.
        """
        self.models = {}
        models_path = Path(models_dir)

        features_json = models_path / "selected_features.json"
        if features_json.exists():
            with open(features_json, encoding="utf-8") as f:
                self.selected_features = json.load(f)
            print(f"PWAT: Loaded {len(self.selected_features)} selected features from JSON")
        else:
            self.selected_features = None
            print("PWAT: WARNING — selected_features.json not found, using all features")

        for item in ITEMS:
            pkl = models_path / f"xgb_pwat{item}.pkl"
            if pkl.exists():
                # joblib.load unpickles arbitrary objects and can execute
                # code, so point models_dir only at trusted artefacts.
                self.models[item] = joblib.load(pkl)
            else:
                # Without a model the item is scored 0 and still counts
                # toward the totals, so make the gap visible.
                print(
                    f"PWAT: WARNING — {pkl.name} not found; "
                    f"item {item} ({ITEM_NAMES[item]}) will be scored 0"
                )

    def predict(
        self,
        img_bgr: np.ndarray,
        ulcer_mask: np.ndarray,
        fitzpatrick_type: str = "III",
    ) -> PWATResult:
        """Predict PWAT scores for a single image.

        Args:
            img_bgr: BGR image, forwarded to ``extract_features``.
            ulcer_mask: Wound mask, forwarded to ``extract_features``.
            fitzpatrick_type: Key into ``CORRECTION_FACTORS`` ("I" to "VI").
                Case and surrounding whitespace are tolerated. An
                unrecognised value applies no correction.

        Returns:
            A PWATResult holding the raw model scores and the
            Fitzpatrick-adjusted scores (each clipped to [0, 4]), their
            totals, and the extracted features. If the wound is too small
            for feature extraction, only ``fitzpatrick_type`` is set.
        """
        feats = extract_features(img_bgr, ulcer_mask)
        if feats is None:
            return PWATResult(fitzpatrick_type=fitzpatrick_type)

        # A missing or empty feature list falls back to every extracted
        # feature in name order.
        # NOTE(review): in that fallback the column set depends on which
        # optional features (morph_*, edge_gradient) this image produced,
        # so column positions can shift between images.
        if self.selected_features:
            cols = self.selected_features
        else:
            cols = sorted(feats.keys())

        # Features missing for this image are filled with 0.0.
        X = np.array([[feats.get(c, 0.0) for c in cols]])

        offsets = CORRECTION_FACTORS.get(fitzpatrick_type)
        if offsets is None:
            # Accept variants such as "iv" or " IV " for the canonical key.
            offsets = CORRECTION_FACTORS.get(str(fitzpatrick_type).strip().upper(), {})

        scores_raw = {}
        scores_adj = {}
        for item in ITEMS:
            if item not in self.models:
                scores_raw[item] = 0
                scores_adj[item] = 0.0
                continue

            pred = int(self.models[item].predict(X)[0])
            scores_raw[item] = pred
            scores_adj[item] = float(np.clip(pred + offsets.get(item, 0.0), 0, 4))

        return PWATResult(
            scores_raw=scores_raw,
            scores_adjusted=scores_adj,
            total_raw=sum(scores_raw.values()),
            total_adjusted=round(sum(scores_adj.values()), 1),
            fitzpatrick_type=fitzpatrick_type,
            features=feats,
        )
|
|