mmarquezsa commited on
Commit
1b5f5d8
·
verified ·
1 Parent(s): 04c9d0a

feat: upgrade pipeline UI — binary seg + multiclass + PWAT raw vs adjusted

Browse files
Files changed (1) hide show
  1. pipeline.py +34 -12
pipeline.py CHANGED
@@ -23,6 +23,9 @@ class AnalysisResult:
23
  """Complete DFU analysis result from a single image."""
24
  # Segmentation
25
  class_distribution: dict = field(default_factory=dict) # {class_name: percentage}
 
 
 
26
  # Fitzpatrick
27
  fitzpatrick: Optional[FitzpatrickResult] = None
28
  # PWAT
@@ -147,23 +150,42 @@ class WoundNetB7Pipeline:
147
 
148
  return AnalysisResult(
149
  class_distribution=class_dist,
 
 
 
150
  fitzpatrick=fitz,
151
  pwat=pwat,
152
  image_size=(h, w),
153
  device=str(self.device),
154
  )
155
 
156
- def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
157
- """Create overlay visualization of segmentation result."""
158
- h, w = img_bgr.shape[:2]
159
- seg = segment(self.seg_model, img_bgr, self.device, use_tta=False)
160
- classmap = seg["classmap"]
161
-
 
 
 
 
 
 
 
 
 
 
 
 
162
  overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
163
- for cid, color in CLASS_COLORS.items():
164
- if cid == 0:
165
- continue
166
- mask = classmap == cid
167
- overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
168
-
169
  return overlay.astype(np.uint8)
 
 
 
 
 
23
  """Complete DFU analysis result from a single image."""
24
  # Segmentation
25
  class_distribution: dict = field(default_factory=dict) # {class_name: percentage}
26
+ classmap: Optional[np.ndarray] = field(default=None, repr=False) # (H,W) uint8
27
+ probs: Optional[np.ndarray] = field(default=None, repr=False) # (4,H,W) float32
28
+ ulcer_mask: Optional[np.ndarray] = field(default=None, repr=False) # (H,W) uint8
29
  # Fitzpatrick
30
  fitzpatrick: Optional[FitzpatrickResult] = None
31
  # PWAT
 
150
 
151
  return AnalysisResult(
152
  class_distribution=class_dist,
153
+ classmap=classmap,
154
+ probs=seg["probs"],
155
+ ulcer_mask=ulcer_mask,
156
  fitzpatrick=fitz,
157
  pwat=pwat,
158
  image_size=(h, w),
159
  device=str(self.device),
160
  )
161
 
162
+ def visualize_binary(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
163
+ """Create binary ulcer segmentation overlay (ulcer only, red mask)."""
164
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
165
+ overlay = img_rgb.copy()
166
+ if result.ulcer_mask is not None:
167
+ ulcer_bool = result.ulcer_mask > 127
168
+ if np.any(ulcer_bool):
169
+ overlay[ulcer_bool] = overlay[ulcer_bool] * 0.4 + np.array([255, 0, 0], dtype=np.float32) * 0.6
170
+ overlay_u8 = np.clip(overlay, 0, 255).astype(np.uint8)
171
+ contours, _ = cv2.findContours(
172
+ ulcer_bool.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
173
+ )
174
+ cv2.drawContours(overlay_u8, contours, -1, (255, 255, 255), 2)
175
+ return overlay_u8
176
+ return np.clip(overlay, 0, 255).astype(np.uint8)
177
+
178
+ def visualize_multiclass(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
179
+ """Create multiclass segmentation overlay using cached classmap."""
180
  overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
181
+ if result.classmap is not None:
182
+ for cid, color in CLASS_COLORS.items():
183
+ if cid == 0:
184
+ continue
185
+ mask = result.classmap == cid
186
+ overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
187
  return overlay.astype(np.uint8)
188
+
189
+ def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
190
+ """Create overlay visualization (backward compatible)."""
191
+ return self.visualize_multiclass(img_bgr, result)