File size: 7,024 Bytes
21ccfaf fee7e0c 21ccfaf 21b7fe4 21ccfaf 21b7fe4 21ccfaf 21b7fe4 21ccfaf 21b7fe4 21ccfaf fee7e0c 21ccfaf fee7e0c 21ccfaf 21b7fe4 21ccfaf fee7e0c 21ccfaf 21b7fe4 21ccfaf 21b7fe4 21ccfaf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | """PWAT (Photographic Wound Assessment Tool) estimation — Items 3-8.
Uses segmentation masks to extract color, tissue, morphological, and texture
features, then predicts ordinal PWAT scores (0-4) via XGBoost classifiers.
Includes Fitzpatrick-aware debiasing correction factors.
"""
import numpy as np
import cv2
import joblib
import json
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
# PWAT item numbers handled by this module (items 3-8), in scoring order.
ITEMS = list(range(3, 9))

# Human-readable label for each PWAT item number above.
ITEM_NAMES = dict(
    zip(
        ITEMS,
        (
            "Necrotic Type",
            "Necrotic Amount",
            "Granulation Type",
            "Granulation Amount",
            "Edges",
            "Periulcer Skin",
        ),
    )
)

# Additive offsets applied to each raw item score, keyed by Fitzpatrick skin
# type "I".."VI". Each row lists the offsets for items 3, 4, 5, 6, 7, 8 in that
# order. All offsets are zero or negative, and their magnitude grows from
# type I to type VI.
CORRECTION_FACTORS = {
    skin_type: dict(zip(ITEMS, offsets))
    for skin_type, offsets in (
        ("I", (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)),
        ("II", (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)),
        ("III", (0.0, 0.0, 0.0, 0.0, 0.0, -0.1)),
        ("IV", (-0.1, -0.1, 0.0, 0.0, 0.0, -0.3)),
        ("V", (-0.2, -0.2, -0.1, 0.0, 0.0, -0.6)),
        ("VI", (-0.3, -0.3, -0.2, -0.1, 0.0, -0.9)),
    )
}
@dataclass
class PWATResult:
    """Container for one PWAT prediction.

    Score dicts are keyed by PWAT item number. Every field has an empty/zero
    default, so ``PWATResult()`` is a valid "no result" value.
    """

    # Per-item scores as returned by the item models, before any correction.
    scores_raw: dict = field(default_factory=dict)
    # Per-item scores after the Fitzpatrick offset, as floats.
    scores_adjusted: dict = field(default_factory=dict)
    # Sum of the raw per-item scores.
    total_raw: int = 0
    # Sum of the adjusted per-item scores (the predictor rounds it to 1 decimal).
    total_adjusted: float = 0.0
    # Fitzpatrick skin-type label the adjustment was computed for (e.g. "III").
    fitzpatrick_type: str = ""
    # Feature dict the scores were computed from.
    features: dict = field(default_factory=dict)
def _color_stats(
    rgb: np.ndarray, hsv: np.ndarray, lab: np.ndarray, wound: np.ndarray
) -> dict:
    """Return mean/std/median/p25/p75 of every RGB, HSV and Lab channel over wound pixels."""
    stats = {}
    spaces = (
        ("rgb", rgb, ("R", "G", "B")),
        ("hsv", hsv, ("H", "S", "V")),
        ("lab", lab, ("L", "a", "b")),
    )
    for space, arr, channel_names in spaces:
        for ci, cname in enumerate(channel_names):
            vals = arr[wound, ci]
            prefix = f"{space}_{cname}"
            stats[f"{prefix}_mean"] = float(np.mean(vals))
            stats[f"{prefix}_std"] = float(np.std(vals))
            stats[f"{prefix}_median"] = float(np.median(vals))
            stats[f"{prefix}_p25"] = float(np.percentile(vals, 25))
            stats[f"{prefix}_p75"] = float(np.percentile(vals, 75))
    return stats


def _tissue_fractions(
    hsv: np.ndarray, lab: np.ndarray, wound: np.ndarray, npx: int
) -> dict:
    """Return the percentage of wound pixels in each colour-threshold tissue class.

    Thresholds assume OpenCV 8-bit conversions: H in 0-179 and S/V in 0-255.
    Lab L is rescaled from 0-255 to L* 0-100, and a is re-centred on 0.
    """
    h, s, v = hsv[wound, 0], hsv[wound, 1], hsv[wound, 2]
    l_star = lab[wound, 0] * (100 / 255)
    a_star = lab[wound, 1] - 128

    eschar = ((v < 100) & (s < 60)) | (v < 60)
    slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
    gran = ((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_star > 5) & ~eschar
    # NOTE(review): necro does not exclude slough, so a pixel can be counted in
    # both, and tissue_necro_total can exceed 100. Left unchanged so the
    # features stay identical to what the models were presumably trained on.
    necro = (s < 45) & (v >= 60) & (v < 160) & (l_star < 55) & ~eschar & ~gran

    def pct(mask: np.ndarray) -> float:
        return float(np.sum(mask) / npx * 100)

    out = {
        "tissue_gran_pct": pct(gran),
        "tissue_eschar_pct": pct(eschar),
        "tissue_slough_pct": pct(slough),
        "tissue_necro_pct": pct(necro),
    }
    out["tissue_necro_total"] = (
        out["tissue_eschar_pct"] + out["tissue_slough_pct"] + out["tissue_necro_pct"]
    )
    return out


def _morphology(mask_u8: np.ndarray) -> dict:
    """Return shape descriptors of the largest external contour of a 0/1 mask.

    If no contour is found, all descriptors are reported as 0.0, so the key set
    does not depend on the input.
    """
    keys = (
        "morph_area",
        "morph_perimeter",
        "morph_circularity",
        "morph_irregularity",
        "morph_aspect_ratio",
        "morph_extent",
        "morph_solidity",
    )
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return dict.fromkeys(keys, 0.0)

    cnt = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    perim = cv2.arcLength(cnt, True)
    # Isoperimetric circularity: 1.0 for a perfect circle.
    circ = 4 * np.pi * area / (perim ** 2) if perim > 0 else 0
    _, _, box_w, box_h = cv2.boundingRect(cnt)
    hull_area = cv2.contourArea(cv2.convexHull(cnt))
    return {
        "morph_area": float(area),
        "morph_perimeter": float(perim),
        "morph_circularity": float(circ),
        "morph_irregularity": float(1 - circ),
        "morph_aspect_ratio": float(box_w / (box_h + 1e-8)),
        "morph_extent": float(area / (box_w * box_h + 1e-8)),
        "morph_solidity": float(area / (hull_area + 1e-8)),
    }


def _texture(gray: np.ndarray, wound: np.ndarray) -> dict:
    """Return grey-level mean, std and histogram "entropy" over wound pixels."""
    wound_gray = gray[wound]
    hist_vals = np.histogram(wound_gray, bins=64, density=True)[0]
    return {
        "texture_mean": float(np.mean(wound_gray)),
        "texture_std": float(np.std(wound_gray)),
        # NOTE(review): with density=True the histogram holds bin-width-normalised
        # densities, not bin probabilities, so this is not a true Shannon
        # entropy. Kept as-is because changing it would presumably require
        # retraining the models.
        "texture_entropy": float(-np.sum(hist_vals * np.log2(hist_vals + 1e-10))),
    }


def _edge_gradient(gray: np.ndarray, mask_u8: np.ndarray) -> float:
    """Return the mean |horizontal Sobel| response in a band around the wound boundary.

    The band is the dilation minus the erosion of the mask, using a 5x5 kernel.
    Returns 0.0 if the band is empty (e.g. a mask covering the whole image).
    """
    kernel = np.ones((5, 5), np.uint8)
    # Dilation >= erosion pixel-wise, so this uint8 subtraction cannot wrap.
    band = (cv2.dilate(mask_u8 * 255, kernel) - cv2.erode(mask_u8 * 255, kernel)) > 127
    if not np.any(band):
        return 0.0
    # NOTE(review): only the x-derivative is used, not the gradient magnitude.
    # This is kept for parity with the existing feature definition.
    sobel_x = cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)
    return float(np.mean(np.abs(sobel_x[band])))


def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
    """Extract the hand-crafted wound features used as PWAT model inputs.

    Args:
        img_bgr: Colour image in OpenCV BGR channel order. The HSV/Lab
            thresholds assume 8-bit input — TODO confirm callers never pass
            float images.
        ulcer_mask: Wound mask with the same height/width as ``img_bgr``.
            For boolean masks, ``True`` marks the wound. For any other dtype,
            only values > 127 count as wound. A 0/1 or 0.0-1.0 mask therefore
            selects nothing and the function returns ``None``.

    Returns:
        A flat ``{feature_name: float}`` dict, or ``None`` when the mask covers
        fewer than 50 pixels. Every returned dict has the same key set.
        Features that cannot be measured (no contour, empty edge band) are
        reported as 0.0, so the feature-vector layout does not depend on the
        image.
    """
    wound = ulcer_mask > 0 if ulcer_mask.dtype == bool else ulcer_mask > 127
    npx = int(np.sum(wound))
    if npx < 50:
        return None

    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # ``wound`` is always a boolean array here, so it converts directly to 0/1.
    mask_u8 = wound.astype(np.uint8)

    # Insertion order (colour, tissue, morphology, texture, edge, size) is kept
    # stable for anything that displays or serialises the dict.
    feats = {}
    feats.update(_color_stats(rgb, hsv, lab, wound))
    feats.update(_tissue_fractions(hsv, lab, wound, npx))
    feats.update(_morphology(mask_u8))
    feats.update(_texture(gray, wound))
    feats["edge_gradient"] = _edge_gradient(gray, mask_u8)
    feats["wound_npx"] = float(npx)
    feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
    return feats
class PWATPredictor:
    """Predict PWAT items 3-8 with one pre-trained classifier per item.

    Models are read from ``<models_dir>/xgb_pwat{item}.pkl``. The feature
    column order is read from ``<models_dir>/selected_features.json`` when
    that file exists.
    """

    def __init__(self, models_dir: str):
        """Load the feature-column list and the per-item models from *models_dir*.

        Missing files are not treated as errors. Without
        ``selected_features.json``, every extracted feature is used, in name
        order. An item without a ``.pkl`` file is scored 0 by ``predict``.
        Both situations are reported on stdout.
        """
        self.models = {}
        models_path = Path(models_dir)

        # Column list the models were trained on (noted as 30 features after a
        # variance + correlation filter); its order must be preserved.
        features_json = models_path / "selected_features.json"
        if features_json.exists():
            with open(features_json, encoding="utf-8") as f:
                self.selected_features = json.load(f)
            print(f"PWAT: Loaded {len(self.selected_features)} selected features from JSON")
        else:
            self.selected_features = None
            print("PWAT: WARNING — selected_features.json not found, using all features")

        for item in ITEMS:
            pkl = models_path / f"xgb_pwat{item}.pkl"
            if pkl.exists():
                # SECURITY: joblib.load unpickles arbitrary objects. Only point
                # models_dir at trusted model files.
                self.models[item] = joblib.load(pkl)

        # Items without a model are silently scored 0 by predict(). Say so up front.
        missing = [item for item in ITEMS if item not in self.models]
        if missing:
            print(f"PWAT: WARNING — no model for items {missing}; they will be scored 0")

    def predict(
        self,
        img_bgr: np.ndarray,
        ulcer_mask: np.ndarray,
        fitzpatrick_type: str = "III",
    ) -> PWATResult:
        """Predict PWAT scores for a single image.

        Args:
            img_bgr: Image passed unchanged to ``extract_features``.
            ulcer_mask: Wound mask passed unchanged to ``extract_features``.
            fitzpatrick_type: Key into ``CORRECTION_FACTORS`` ("I".."VI"). Any
                other value, including a different letter case, applies no
                correction.

        Returns:
            A ``PWATResult``. If feature extraction returns ``None`` (wound too
            small), the result is empty apart from ``fitzpatrick_type``.
        """
        feats = extract_features(img_bgr, ulcer_mask)
        if feats is None:
            return PWATResult(fitzpatrick_type=fitzpatrick_type)

        # Use the training column order when available. Otherwise fall back to
        # every extracted feature, sorted by name.
        cols = self.selected_features if self.selected_features else sorted(feats)
        # Columns the extractor did not produce are fed to the models as 0.0.
        X = np.array([[feats.get(c, 0.0) for c in cols]])

        factors = CORRECTION_FACTORS.get(fitzpatrick_type, {})
        scores_raw = {}
        scores_adj = {}
        for item in ITEMS:
            if item not in self.models:
                # No model for this item: report 0, which is indistinguishable
                # from a real 0 prediction.
                scores_raw[item] = 0
                scores_adj[item] = 0.0
                continue
            pred = int(self.models[item].predict(X)[0])
            scores_raw[item] = pred
            # Additive skin-type offset, clamped to the 0-4 item scale.
            scores_adj[item] = float(np.clip(pred + factors.get(item, 0.0), 0, 4))

        return PWATResult(
            scores_raw=scores_raw,
            scores_adjusted=scores_adj,
            total_raw=sum(scores_raw.values()),
            total_adjusted=round(sum(scores_adj.values()), 1),
            fitzpatrick_type=fitzpatrick_type,
            features=feats,
        )
|