mmarquezsa committed on
Commit
21b7fe4
·
verified ·
1 Parent(s): 354bfe2

Fix: extract feature names from XGBoost model to match training feature selection (30 vs 63)

Browse files
Files changed (1) hide show
  1. src/pwat_estimator.py +36 -36
src/pwat_estimator.py CHANGED
@@ -23,7 +23,6 @@ ITEM_NAMES = {
23
  }
24
 
25
  # Debiasing correction factors (calibrated from 61 DFU images)
26
- # Applied as: adjusted = clip(raw + factor, 0, 4)
27
  CORRECTION_FACTORS = {
28
  "I": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
29
  "II": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
@@ -36,8 +35,8 @@ CORRECTION_FACTORS = {
36
 
37
  @dataclass
38
  class PWATResult:
39
- scores_raw: dict = field(default_factory=dict) # {item: int}
40
- scores_adjusted: dict = field(default_factory=dict) # {item: float} (debiased)
41
  total_raw: int = 0
42
  total_adjusted: float = 0.0
43
  fitzpatrick_type: str = ""
@@ -45,10 +44,7 @@ class PWATResult:
45
 
46
 
47
  def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
48
- """Extract 63 features from the wound region for PWAT prediction.
49
-
50
- Features: color (RGB/HSV/Lab), tissue composition, morphology, texture.
51
- """
52
  b = ulcer_mask > 0 if ulcer_mask.dtype == bool else ulcer_mask > 127
53
  npx = int(np.sum(b))
54
  if npx < 50:
@@ -56,7 +52,7 @@ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[di
56
 
57
  feats = {}
58
 
59
- # --- Color features (45) ---
60
  hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
61
  lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
62
  rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
@@ -70,7 +66,7 @@ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[di
70
  feats[f"{cs}_{cn}_p25"] = float(np.percentile(vals, 25))
71
  feats[f"{cs}_{cn}_p75"] = float(np.percentile(vals, 75))
72
 
73
- # --- Tissue composition (5) ---
74
  h, s, v = hsv[b, 0], hsv[b, 1], hsv[b, 2]
75
  l_ch = lab[b, 0] * (100 / 255)
76
  a_ch = lab[b, 1] - 128
@@ -86,14 +82,14 @@ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[di
86
  feats["tissue_necro_pct"] = float(np.sum(necro) / npx * 100)
87
  feats["tissue_necro_total"] = feats["tissue_eschar_pct"] + feats["tissue_slough_pct"] + feats["tissue_necro_pct"]
88
 
89
- # --- Morphological features (7) ---
90
  mask_u8 = b.astype(np.uint8) if b.dtype == bool else (ulcer_mask > 127).astype(np.uint8)
91
  cnts, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
92
  if cnts:
93
  cnt = max(cnts, key=cv2.contourArea)
94
  area = cv2.contourArea(cnt)
95
  perim = cv2.arcLength(cnt, True)
96
- circ = 4 * np.pi * area / (perim**2) if perim > 0 else 0
97
  feats["morph_area"] = float(area)
98
  feats["morph_perimeter"] = float(perim)
99
  feats["morph_circularity"] = float(circ)
@@ -104,7 +100,7 @@ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[di
104
  hull = cv2.convexHull(cnt)
105
  feats["morph_solidity"] = float(area / (cv2.contourArea(hull) + 1e-8))
106
 
107
- # --- Texture features (4) ---
108
  gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
109
  wound_gray = gray[b]
110
  feats["texture_mean"] = float(np.mean(wound_gray))
@@ -118,7 +114,7 @@ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[di
118
  if np.any(edge_zone):
119
  feats["edge_gradient"] = float(np.mean(np.abs(cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)[edge_zone])))
120
 
121
- # --- ROI features (2) ---
122
  feats["wound_npx"] = float(npx)
123
  feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
124
 
@@ -126,52 +122,56 @@ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[di
126
 
127
 
128
  class PWATPredictor:
129
- """Predicts PWAT items 3-8 from wound features using trained XGBoost models."""
130
 
131
  def __init__(self, models_dir: str):
132
  self.models = {}
 
133
  models_path = Path(models_dir)
134
  for item in ITEMS:
135
  pkl = models_path / f"xgb_pwat{item}.pkl"
136
  if pkl.exists():
137
- self.models[item] = joblib.load(pkl)
 
 
 
 
 
 
 
 
138
 
139
  def predict(
140
  self,
141
  img_bgr: np.ndarray,
142
  ulcer_mask: np.ndarray,
143
  fitzpatrick_type: str = "III",
144
- feature_cols: Optional[list] = None,
145
  ) -> PWATResult:
146
- """Predict PWAT scores for a single image.
147
-
148
- Args:
149
- img_bgr: BGR image
150
- ulcer_mask: Binary ulcer mask (H, W)
151
- fitzpatrick_type: Fitzpatrick type for debiasing ("I" .. "VI")
152
- feature_cols: Ordered feature column names (must match training order).
153
- If None, uses all extracted features sorted alphabetically.
154
- """
155
  feats = extract_features(img_bgr, ulcer_mask)
156
  if feats is None:
157
  return PWATResult(fitzpatrick_type=fitzpatrick_type)
158
 
159
- # Build feature vector
160
- if feature_cols is None:
161
- feature_cols = sorted(feats.keys())
162
- X = np.array([[feats.get(c, 0.0) for c in feature_cols]])
163
-
164
  scores_raw = {}
165
  scores_adj = {}
166
  for item in ITEMS:
167
- if item in self.models:
168
- pred = int(self.models[item].predict(X)[0])
169
- scores_raw[item] = pred
170
- factor = CORRECTION_FACTORS.get(fitzpatrick_type, {}).get(item, 0.0)
171
- scores_adj[item] = float(np.clip(pred + factor, 0, 4))
172
- else:
173
  scores_raw[item] = 0
174
  scores_adj[item] = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  return PWATResult(
177
  scores_raw=scores_raw,
 
23
  }
24
 
25
  # Debiasing correction factors (calibrated from 61 DFU images)
 
26
  CORRECTION_FACTORS = {
27
  "I": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
28
  "II": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
 
35
 
36
  @dataclass
37
  class PWATResult:
38
+ scores_raw: dict = field(default_factory=dict)
39
+ scores_adjusted: dict = field(default_factory=dict)
40
  total_raw: int = 0
41
  total_adjusted: float = 0.0
42
  fitzpatrick_type: str = ""
 
44
 
45
 
46
  def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
47
+ """Extract features from the wound region for PWAT prediction."""
 
 
 
48
  b = ulcer_mask > 0 if ulcer_mask.dtype == bool else ulcer_mask > 127
49
  npx = int(np.sum(b))
50
  if npx < 50:
 
52
 
53
  feats = {}
54
 
55
+ # Color features (45)
56
  hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
57
  lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
58
  rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
 
66
  feats[f"{cs}_{cn}_p25"] = float(np.percentile(vals, 25))
67
  feats[f"{cs}_{cn}_p75"] = float(np.percentile(vals, 75))
68
 
69
+ # Tissue composition (5)
70
  h, s, v = hsv[b, 0], hsv[b, 1], hsv[b, 2]
71
  l_ch = lab[b, 0] * (100 / 255)
72
  a_ch = lab[b, 1] - 128
 
82
  feats["tissue_necro_pct"] = float(np.sum(necro) / npx * 100)
83
  feats["tissue_necro_total"] = feats["tissue_eschar_pct"] + feats["tissue_slough_pct"] + feats["tissue_necro_pct"]
84
 
85
+ # Morphological features (7)
86
  mask_u8 = b.astype(np.uint8) if b.dtype == bool else (ulcer_mask > 127).astype(np.uint8)
87
  cnts, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
88
  if cnts:
89
  cnt = max(cnts, key=cv2.contourArea)
90
  area = cv2.contourArea(cnt)
91
  perim = cv2.arcLength(cnt, True)
92
+ circ = 4 * np.pi * area / (perim ** 2) if perim > 0 else 0
93
  feats["morph_area"] = float(area)
94
  feats["morph_perimeter"] = float(perim)
95
  feats["morph_circularity"] = float(circ)
 
100
  hull = cv2.convexHull(cnt)
101
  feats["morph_solidity"] = float(area / (cv2.contourArea(hull) + 1e-8))
102
 
103
+ # Texture features (4)
104
  gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
105
  wound_gray = gray[b]
106
  feats["texture_mean"] = float(np.mean(wound_gray))
 
114
  if np.any(edge_zone):
115
  feats["edge_gradient"] = float(np.mean(np.abs(cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)[edge_zone])))
116
 
117
+ # ROI features (2)
118
  feats["wound_npx"] = float(npx)
119
  feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
120
 
 
122
 
123
 
124
  class PWATPredictor:
125
+ """Predicts PWAT items 3-8 using trained XGBoost models."""
126
 
127
  def __init__(self, models_dir: str):
128
  self.models = {}
129
+ self.feature_names = {}
130
  models_path = Path(models_dir)
131
  for item in ITEMS:
132
  pkl = models_path / f"xgb_pwat{item}.pkl"
133
  if pkl.exists():
134
+ model = joblib.load(pkl)
135
+ self.models[item] = model
136
+ # Extract expected feature names from the trained model
137
+ try:
138
+ names = model.get_booster().feature_names
139
+ if names:
140
+ self.feature_names[item] = names
141
+ except Exception:
142
+ pass
143
 
144
  def predict(
145
  self,
146
  img_bgr: np.ndarray,
147
  ulcer_mask: np.ndarray,
148
  fitzpatrick_type: str = "III",
 
149
  ) -> PWATResult:
150
+ """Predict PWAT scores for a single image."""
 
 
 
 
 
 
 
 
151
  feats = extract_features(img_bgr, ulcer_mask)
152
  if feats is None:
153
  return PWATResult(fitzpatrick_type=fitzpatrick_type)
154
 
 
 
 
 
 
155
  scores_raw = {}
156
  scores_adj = {}
157
  for item in ITEMS:
158
+ if item not in self.models:
 
 
 
 
 
159
  scores_raw[item] = 0
160
  scores_adj[item] = 0.0
161
+ continue
162
+
163
+ # Use model's expected feature names if available
164
+ if item in self.feature_names:
165
+ cols = self.feature_names[item]
166
+ else:
167
+ cols = sorted(feats.keys())
168
+
169
+ X = np.array([[feats.get(c, 0.0) for c in cols]])
170
+
171
+ pred = int(self.models[item].predict(X)[0])
172
+ scores_raw[item] = pred
173
+ factor = CORRECTION_FACTORS.get(fitzpatrick_type, {}).get(item, 0.0)
174
+ scores_adj[item] = float(np.clip(pred + factor, 0, 4))
175
 
176
  return PWATResult(
177
  scores_raw=scores_raw,