mmarquezsa committed on
Commit 21ccfaf · verified · 1 Parent(s): b662645

Add pipeline code, PWAT models, and Gradio app
README.md CHANGED
@@ -1,12 +1,105 @@
- ---
- title: WoundNetB7 DFU Analysis
- emoji: 💻
- colorFrom: indigo
- colorTo: yellow
- sdk: gradio
- sdk_version: 6.11.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # WoundNetB7 — Automated DFU Assessment Pipeline
+
+ End-to-end pipeline for Diabetic Foot Ulcer (DFU) analysis:
+
+ **Image** -> **Multiclass Segmentation** -> **PWAT Scoring** + **Fitzpatrick/ITA Estimation**
+
+ ## Pipeline
+
+ ```
+ Input Image (DFU photograph)
+         |
+         v
+ [1] WoundNetB7 Segmentation (EfficientNet-B7 + ASPP + CBAM + TAM)
+     -> 4-class masks: background, foot, perilesion, ulcer
+     -> Ulcer Dice: 0.927 (95% CI: [0.917, 0.936])
+         |
+         +---> [2] PWAT Estimation (XGBoost per item)
+         |         -> Items 3-8 scores (0-4 ordinal)
+         |         -> 5-fold CV: MAE 0.61-0.87, Adjacent Match 75-89%
+         |
+         +---> [3] Fitzpatrick/ITA Estimation
+         |         -> Healthy skin = foot - perilesion - ulcer
+         |         -> ITA angle + Fitzpatrick I-VI classification
+         |         -> Calibrated on 61 DFU images (86.9% exact match)
+         |
+         +---> [4] Debiasing
+                   -> PWAT scores adjusted by Fitzpatrick type
+                   -> 18% max group gap reduction (p < 10^-27)
+ ```
+
+ ## Quick Start
+
+ ```python
+ from pipeline import WoundNetB7Pipeline
+
+ pipe = WoundNetB7Pipeline("models/")
+ result = pipe.analyze("path/to/dfu_image.png")
+ print(result.summary())
+ ```
+
+ ## Models
+
+ | Component | Architecture | Metric | Value |
+ |-----------|-------------|--------|-------|
+ | Segmentation | EfficientNet-B7 + UNet + ASPP + CBAM + TAM | Ulcer Dice | 0.927 |
+ | PWAT Item 3 | XGBoost + SMOTE | MAE / Adjacent | 0.87 / 74.9% |
+ | PWAT Item 7 | XGBoost + SMOTE | MAE / Adjacent | 0.61 / 89.3% |
+ | Fitzpatrick | ITA calibrated thresholds | Exact match | 86.9% |
+
+ ## Fitzpatrick/ITA Calibrated Thresholds
+
+ | Type | ITA Range (degrees) | Description |
+ |------|---------------------|-------------|
+ | I | > 46.86 | Very Light |
+ | II | 34.25 to 46.86 | Light |
+ | III | 20.87 to 34.25 | Intermediate |
+ | IV | 3.57 to 20.87 | Tan |
+ | V | -28.38 to 3.57 | Brown |
+ | VI | < -28.38 | Dark |
+
+ ## PWAT Debiasing Factors
+
+ The model tends to overestimate tissue severity on darker skin tones
+ (especially Item 8: Periulcer Skin). Correction factors reduce this bias:
+
+ | Type | P3 | P4 | P5 | P6 | P7 | P8 |
+ |------|-----|-----|-----|-----|-----|-----|
+ | I-II | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
+ | III | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.1 |
+ | IV | -0.1 | -0.1 | 0.0 | 0.0 | 0.0 | -0.3 |
+ | V | -0.2 | -0.2 | -0.1 | 0.0 | 0.0 | -0.6 |
+ | VI | -0.3 | -0.3 | -0.2 | -0.1 | 0.0 | -0.9 |
+
+ ## Dataset
+
+ - **Segmentation training**: 2,245 images (1,571 train / 336 val / 338 test)
+ - **Multiclass validation**: 461 images with 5-class expert masks
+ - **PWAT labels**: 1,321 images (920 expert GT + 401 AI consensus)
+ - **Fitzpatrick calibration**: 61 images with expert skin type annotations
+
+ ## Deploy to Hugging Face
+
+ ```bash
+ # Clone this repo to a HF Space
+ git clone https://huggingface.co/spaces/YOUR_USER/woundnetb7
+ # Copy the models/ directory (use git-lfs for large files)
+ git lfs install
+ git lfs track "*.pt" "*.pkl"
+ git add . && git commit -m "Initial deployment"
+ git push
+ ```
+
+ ## License
+
+ Research use only. Clinical deployment requires regulatory approval.
+
+ ## Citation
+
+ ```bibtex
+ @article{marquez2026woundnetb7,
+   title={WoundNetB7: Automated PWAT Protocol using Topological AI for Diabetic Foot Ulcers},
+   author={M{\'a}rquez-Sandoval, Marcelo},
+   year={2026}
+ }
+ ```
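The debiasing factors above are applied as a simple additive correction clipped to the valid 0-4 ordinal range (`adjusted = clip(raw + factor, 0, 4)`). A minimal sketch for a Type V wound, with factor values taken from the table (the `debias` helper name and the example raw scores are illustrative):

```python
import numpy as np

# Correction factors for Fitzpatrick Type V, from the debiasing table above
FACTORS_V = {3: -0.2, 4: -0.2, 5: -0.1, 6: 0.0, 7: 0.0, 8: -0.6}

def debias(raw_scores: dict, factors: dict) -> dict:
    """Apply additive correction per item, clipped to the 0-4 ordinal range."""
    return {item: float(np.clip(score + factors.get(item, 0.0), 0, 4))
            for item, score in raw_scores.items()}

raw = {3: 2, 4: 3, 5: 1, 6: 2, 7: 1, 8: 3}   # example raw item scores
adj = debias(raw, FACTORS_V)
print(round(adj[8], 1))                       # 2.4
print(round(sum(adj.values()), 1))            # 10.9 (vs. raw total 12)
```

Note the clip matters at the low end: a raw score of 0 stays 0 rather than going negative.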
app.py ADDED
@@ -0,0 +1,103 @@
+ """Gradio app for WoundNetB7 DFU Analysis — Hugging Face Spaces deployment.
+
+ Launch locally: python app.py
+ Deploy to HF: push this repo to a Hugging Face Space (GPU recommended).
+ """
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import json
+ from pipeline import WoundNetB7Pipeline
+
+ # Initialize pipeline (loads all models once)
+ pipe = WoundNetB7Pipeline(models_dir="models", use_tta=True)
+
+ CLASS_COLORS_RGB = {
+     0: (0, 0, 0),        # Background: black
+     1: (0, 255, 0),      # Foot: green
+     2: (255, 165, 0),    # Perilesion: orange
+     3: (255, 0, 0),      # Ulcer: red
+ }
+
+ FITZ_COLORS = {
+     "I": "#fef3c7", "II": "#fde68a", "III": "#fbbf24",
+     "IV": "#b45309", "V": "#78350f", "VI": "#451a03",
+ }
+
+
+ def analyze_image(image):
+     """Main analysis function called by Gradio."""
+     if image is None:
+         return None, "Please upload an image.", "{}"
+
+     # Gradio provides an RGB numpy array; the pipeline expects BGR
+     img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+     # Run pipeline
+     result = pipe.analyze(img_bgr, use_tta=True)
+
+     # Create overlay visualization
+     overlay = pipe.visualize(img_bgr, result)
+
+     # Build text summary
+     summary = result.summary()
+
+     # Build JSON output
+     json_out = json.dumps(result.to_dict(), indent=2, ensure_ascii=False)
+
+     return overlay, summary, json_out
+
+
+ # Build Gradio interface
+ with gr.Blocks(
+     title="WoundNetB7 DFU Analysis",
+     theme=gr.themes.Soft(),
+ ) as demo:
+     gr.Markdown(
+         """
+         # WoundNetB7 — Diabetic Foot Ulcer Analysis
+
+         Upload a DFU image to get:
+         1. **Multiclass segmentation** (foot / perilesion / ulcer)
+         2. **Fitzpatrick skin type** via calibrated ITA (86.9% accuracy)
+         3. **PWAT scores** items 3-8 with Fitzpatrick debiasing
+
+         > Model: EfficientNet-B7 + ASPP + CBAM + TAM | Ulcer Dice: 0.927
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_image = gr.Image(label="DFU Image", type="numpy")
+             analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             output_overlay = gr.Image(label="Segmentation Overlay")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             output_text = gr.Textbox(label="Analysis Summary", lines=25, max_lines=40)
+         with gr.Column(scale=1):
+             output_json = gr.Code(label="JSON Output", language="json")
+
+     analyze_btn.click(
+         fn=analyze_image,
+         inputs=[input_image],
+         outputs=[output_overlay, output_text, output_json],
+     )
+
+     gr.Markdown(
+         """
+         ---
+         **Legend:** Green = Foot | Orange = Perilesion | Red = Ulcer
+
+         **PWAT Items:** 3=Necrotic Type, 4=Necrotic Amount, 5=Granulation Type,
+         6=Granulation Amount, 7=Edges, 8=Periulcer Skin (0=best, 4=worst)
+
+         **Debiasing:** Scores adjusted by Fitzpatrick type to reduce skin-tone bias
+         (18% max group gap reduction, p < 10^-27).
+         """
+     )
+
+ if __name__ == "__main__":
+     demo.launch(share=False)
models/pwat/xgb_pwat3.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:94c73bf442d9605e557d5d28578efb3b8858ffbcd8df42033685cb5a9bf5b57e
+ size 1428233
models/pwat/xgb_pwat4.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7f74559882bb03542d9126386cec6c184cf50eb49340bbcfc2e197bed85cf44
+ size 1431895
models/pwat/xgb_pwat5.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:70ebfd69b4f41d5ef7c33b09b456cdbc194b4f5f37f8ddffcbb0050b0c8344f3
+ size 1856653
models/pwat/xgb_pwat6.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69b63396d339a95eeb853daf0fc8b60f02a3c7065b8ef4f879bcab2833d6e4b4
+ size 1650293
models/pwat/xgb_pwat7.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d32c766bfd592d643632425fe422826551c9d093d34b241237f986980b528cc9
+ size 1556397
models/pwat/xgb_pwat8.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1dd0c58d57755d4459d0ceb71a2c2a8fd858735334516cba8e0d948ea5f2d9e
+ size 1570379
pipeline.py ADDED
@@ -0,0 +1,169 @@
+ """WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA.
+
+ Usage:
+     from pipeline import WoundNetB7Pipeline
+     pipe = WoundNetB7Pipeline("models/")
+     result = pipe.analyze("path/to/dfu_image.png")
+     print(result)
+ """
+ import torch
+ import numpy as np
+ import cv2
+ from pathlib import Path
+ from dataclasses import dataclass, field, asdict
+ from typing import Optional
+
+ from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS
+ from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult
+ from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES
+
+
+ @dataclass
+ class AnalysisResult:
+     """Complete DFU analysis result from a single image."""
+     # Segmentation
+     class_distribution: dict = field(default_factory=dict)  # {class_name: percentage}
+     # Fitzpatrick
+     fitzpatrick: Optional[FitzpatrickResult] = None
+     # PWAT
+     pwat: Optional[PWATResult] = None
+     # Metadata
+     image_size: tuple = (0, 0)
+     device: str = "cpu"
+
+     def summary(self) -> str:
+         lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
+         lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
+         lines.append(f"Device: {self.device}")
+
+         lines.append("\n--- Segmentation ---")
+         for cls, pct in self.class_distribution.items():
+             lines.append(f"  {cls:<15s}: {pct:5.1f}%")
+
+         if self.fitzpatrick:
+             f = self.fitzpatrick
+             lines.append("\n--- Fitzpatrick / ITA ---")
+             lines.append(f"  Type: {f.fitzpatrick_type} ({f.fitzpatrick_label})")
+             lines.append(f"  ITA: {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
+             lines.append(f"  L* mean: {f.l_skin_mean:.1f}")
+             lines.append(f"  Confidence: {f.confidence:.2f}")
+             lines.append(f"  Pixels: {f.healthy_pixels:,}")
+
+         if self.pwat and self.pwat.scores_raw:
+             p = self.pwat
+             lines.append("\n--- PWAT Scores (Items 3-8) ---")
+             lines.append(f"  {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
+             lines.append("  " + "-" * 33)
+             for item in [3, 4, 5, 6, 7, 8]:
+                 name = ITEM_NAMES.get(item, f"Item {item}")
+                 raw = p.scores_raw.get(item, "-")
+                 adj = p.scores_adjusted.get(item)
+                 adj_str = f"{adj:.1f}" if adj is not None else "-"
+                 lines.append(f"  {name:<22s} {raw:>4} {adj_str:>5s}")
+             lines.append(f"  {'TOTAL':<22s} {p.total_raw:>4} {p.total_adjusted:>5.1f}")
+             lines.append(f"  Fitzpatrick correction: {p.fitzpatrick_type}")
+
+         return "\n".join(lines)
+
+     def to_dict(self) -> dict:
+         d = {"image_size": self.image_size, "device": self.device, "class_distribution": self.class_distribution}
+         if self.fitzpatrick:
+             d["fitzpatrick"] = asdict(self.fitzpatrick)
+         if self.pwat:
+             d["pwat"] = {
+                 "scores_raw": self.pwat.scores_raw,
+                 "scores_adjusted": self.pwat.scores_adjusted,
+                 "total_raw": self.pwat.total_raw,
+                 "total_adjusted": self.pwat.total_adjusted,
+             }
+         return d
+
+
+ class WoundNetB7Pipeline:
+     """End-to-end DFU analysis pipeline.
+
+     Args:
+         models_dir: Path to models/ directory containing:
+             - segmentation/WoundNetB7_proposed_best.pt
+             - pwat/xgb_pwat{3-8}.pkl
+         device: "cuda" or "cpu" (auto-detected if None)
+         use_tta: Use test-time augmentation for segmentation (slower but more accurate)
+     """
+
+     def __init__(self, models_dir: str = "models", device: Optional[str] = None, use_tta: bool = True):
+         self.models_dir = Path(models_dir)
+         self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+         self.use_tta = use_tta
+
+         # Load segmentation model
+         seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
+         self.seg_model = load_segmentation_model(str(seg_path), self.device)
+         print(f"Segmentation model loaded ({sum(p.numel() for p in self.seg_model.parameters()) / 1e6:.1f}M params)")
+
+         # Load PWAT predictor
+         pwat_path = self.models_dir / "pwat"
+         self.pwat_predictor = PWATPredictor(str(pwat_path))
+         print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")
+
+         print(f"Pipeline ready on {self.device}")
+
+     def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
+         """Analyze a DFU image end-to-end.
+
+         Args:
+             image_input: file path (str/Path) or BGR numpy array (H, W, 3);
+                 convert RGB arrays with cv2.cvtColor(img, cv2.COLOR_RGB2BGR) first
+             use_tta: Override TTA setting for this call
+
+         Returns:
+             AnalysisResult with segmentation, Fitzpatrick, and PWAT data.
+         """
+         tta = use_tta if use_tta is not None else self.use_tta
+
+         # Load image
+         if isinstance(image_input, (str, Path)):
+             img_bgr = cv2.imread(str(image_input))
+             if img_bgr is None:
+                 raise FileNotFoundError(f"Cannot read image: {image_input}")
+         elif isinstance(image_input, np.ndarray):
+             if image_input.ndim != 3 or image_input.shape[2] != 3:
+                 raise ValueError(f"Expected an (H, W, 3) array, got shape {image_input.shape}")
+             img_bgr = image_input  # assumed BGR; channel order cannot be detected from the array
+         else:
+             raise TypeError(f"Unsupported input type: {type(image_input)}")
+
+         h, w = img_bgr.shape[:2]
+
+         # Step 1: Segmentation
+         seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
+         classmap = seg["classmap"]
+
+         class_dist = {}
+         for cid, name in CLASS_NAMES.items():
+             class_dist[name] = round(float(np.mean(classmap == cid) * 100), 1)
+
+         # Step 2: Fitzpatrick estimation (from healthy skin)
+         fitz = estimate_fitzpatrick(img_bgr, seg["masks"])
+
+         # Step 3: PWAT prediction (from ulcer mask)
+         ulcer_mask = (classmap == 3).astype(np.uint8) * 255
+         pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask, fitzpatrick_type=fitz.fitzpatrick_type)
+
+         return AnalysisResult(
+             class_distribution=class_dist,
+             fitzpatrick=fitz,
+             pwat=pwat,
+             image_size=(h, w),
+             device=str(self.device),
+         )
+
+     def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
+         """Create an RGB overlay visualization of the segmentation result."""
+         seg = segment(self.seg_model, img_bgr, self.device, use_tta=False)
+         classmap = seg["classmap"]
+
+         overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
+         for cid, color in CLASS_COLORS.items():
+             if cid == 0:
+                 continue  # leave background unshaded
+             mask = classmap == cid
+             overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
+
+         return overlay.astype(np.uint8)
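The per-class percentages that `analyze` reports come from a straight mean over the class map. A minimal sketch of that step on a synthetic 10x10 class map (class IDs as in `CLASS_NAMES`; the pixel counts are made up for illustration):

```python
import numpy as np

CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}

# Synthetic class map: 50 background, 30 foot, 15 perilesion, 5 ulcer pixels
classmap = np.zeros((10, 10), dtype=np.int64)
classmap.ravel()[50:80] = 1
classmap.ravel()[80:95] = 2
classmap.ravel()[95:] = 3

# Same computation as in WoundNetB7Pipeline.analyze
class_dist = {name: round(float(np.mean(classmap == cid) * 100), 1)
              for cid, name in CLASS_NAMES.items()}
print(class_dist)
# {'background': 50.0, 'foot': 30.0, 'perilesion': 15.0, 'ulcer': 5.0}
```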
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ segmentation-models-pytorch>=0.4.0
+ timm>=1.0.3
+ opencv-python-headless>=4.8.0
+ numpy>=1.24.0
+ pandas>=2.0.0
+ scikit-learn>=1.3.0
+ xgboost>=2.0.0
+ joblib>=1.3.0
+ gradio>=4.0.0
src/__init__.py ADDED
@@ -0,0 +1 @@
+ """WoundNetB7 DFU Analysis Pipeline — Segmentation + PWAT + Fitzpatrick/ITA."""
src/fitzpatrick_estimator.py ADDED
@@ -0,0 +1,137 @@
+ """Fitzpatrick skin type estimation via ITA (Individual Typology Angle).
+
+ Calibrated on 61 DFU images with expert ground truth.
+ Validation: 86.9% exact match, 98.4% adjacent, r=0.975.
+ """
+ import numpy as np
+ import cv2
+ from dataclasses import dataclass
+
+ # Calibrated ITA thresholds for DFU clinical photography (61-image validation)
+ ITA_THRESHOLDS = {
+     "I": (46.86, float("inf")),
+     "II": (34.25, 46.86),
+     "III": (20.87, 34.25),
+     "IV": (3.57, 20.87),
+     "V": (-28.38, 3.57),
+     "VI": (float("-inf"), -28.38),
+ }
+
+ FITZPATRICK_LABELS = {
+     "I": "Very Light", "II": "Light", "III": "Intermediate",
+     "IV": "Tan", "V": "Brown", "VI": "Dark",
+ }
+
+
+ @dataclass
+ class FitzpatrickResult:
+     fitzpatrick_type: str    # "I" .. "VI"
+     fitzpatrick_int: int     # 1 .. 6
+     fitzpatrick_label: str   # "Very Light" .. "Dark"
+     ita_angle: float         # ITA in degrees
+     ita_std: float           # ITA standard deviation
+     l_skin_mean: float       # Mean L* of healthy skin
+     b_skin_mean: float       # Mean b* of healthy skin
+     healthy_pixels: int      # Number of healthy skin pixels used
+     healthy_ratio: float     # Healthy pixels / total image pixels
+     confidence: float        # 0-1 confidence score
+
+
+ def compute_ita(l_values: np.ndarray, b_values: np.ndarray) -> tuple:
+     """Compute ITA angle from L* and b* values with robust trimming.
+
+     ITA = arctan((L* - 50) / b*) * (180 / pi)
+     Higher ITA = lighter skin, lower ITA = darker skin.
+     """
+     ita_per_pixel = np.degrees(np.arctan2(l_values - 50.0, b_values))
+     # Robust trimming: keep only the 5th-95th percentile range
+     p5, p95 = np.percentile(ita_per_pixel, [5, 95])
+     trimmed = ita_per_pixel[(ita_per_pixel >= p5) & (ita_per_pixel <= p95)]
+     if len(trimmed) < 10:
+         trimmed = ita_per_pixel
+     return float(np.mean(trimmed)), float(np.std(trimmed))
+
+
+ def classify_fitzpatrick(ita: float) -> tuple:
+     """Classify ITA angle into Fitzpatrick type using calibrated DFU thresholds."""
+     for ftype, (lo, hi) in ITA_THRESHOLDS.items():
+         if lo <= ita < hi:
+             idx = list(ITA_THRESHOLDS.keys()).index(ftype) + 1
+             return ftype, idx
+     return "III", 3  # Default fallback
+
+
+ def estimate_fitzpatrick(
+     img_bgr: np.ndarray,
+     masks: dict,
+     periulcer_dilation_px: int = 40,
+ ) -> FitzpatrickResult:
+     """Estimate Fitzpatrick type from a DFU image using segmentation masks.
+
+     Strategy: Healthy skin = foot region - perilesion zone - ulcer.
+     ITA is computed on the healthy skin pixels only.
+
+     Args:
+         img_bgr: BGR image (H, W, 3)
+         masks: dict with keys 'foot', 'perilesion', 'ulcer' (bool arrays H, W)
+         periulcer_dilation_px: Extra dilation around wound for safety margin
+     """
+     h, w = img_bgr.shape[:2]
+     foot = masks.get("foot", np.ones((h, w), dtype=bool))
+     peri = masks.get("perilesion", np.zeros((h, w), dtype=bool))
+     ulcer = masks.get("ulcer", np.zeros((h, w), dtype=bool))
+
+     # Dilate ulcer+perilesion for safety margin
+     exclusion = (peri | ulcer).astype(np.uint8)
+     if periulcer_dilation_px > 0:
+         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (periulcer_dilation_px, periulcer_dilation_px))
+         exclusion = cv2.dilate(exclusion, kernel)
+     exclusion = exclusion.astype(bool)
+
+     # Healthy skin = foot minus exclusion zone
+     healthy = foot & ~exclusion
+     healthy_pixels = int(np.sum(healthy))
+
+     if healthy_pixels < 100:
+         # Fallback: use all foot pixels minus wound
+         healthy = foot & ~ulcer
+         healthy_pixels = int(np.sum(healthy))
+
+     if healthy_pixels < 50:
+         return FitzpatrickResult(
+             fitzpatrick_type="III", fitzpatrick_int=3,
+             fitzpatrick_label="Intermediate",
+             ita_angle=0.0, ita_std=0.0,
+             l_skin_mean=0.0, b_skin_mean=0.0,
+             healthy_pixels=healthy_pixels,
+             healthy_ratio=healthy_pixels / (h * w),
+             confidence=0.0,
+         )
+
+     # Convert to L*a*b*
+     lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
+     l_values = lab[healthy, 0] * (100.0 / 255.0)  # OpenCV L* is 0-255 -> 0-100
+     b_values = lab[healthy, 2] - 128.0            # OpenCV b* is 0-255 -> -128 to +127
+
+     ita_mean, ita_std = compute_ita(l_values, b_values)
+     ftype, fint = classify_fitzpatrick(ita_mean)
+
+     # Confidence: pixel count + ITA consistency + coverage
+     pixel_conf = min(healthy_pixels / 5000.0, 1.0)
+     ita_conf = max(0.0, 1.0 - (ita_std / 30.0))
+     coverage_conf = min((healthy_pixels / (h * w)) / 0.15, 1.0)
+     confidence = pixel_conf * 0.3 + ita_conf * 0.4 + coverage_conf * 0.3
+
+     return FitzpatrickResult(
+         fitzpatrick_type=ftype,
+         fitzpatrick_int=fint,
+         fitzpatrick_label=FITZPATRICK_LABELS[ftype],
+         ita_angle=round(ita_mean, 2),
+         ita_std=round(ita_std, 2),
+         l_skin_mean=round(float(np.mean(l_values)), 2),
+         b_skin_mean=round(float(np.mean(b_values)), 2),
+         healthy_pixels=healthy_pixels,
+         healthy_ratio=round(healthy_pixels / (h * w), 4),
+         confidence=round(confidence, 3),
+     )
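As a worked example of the ITA-to-Fitzpatrick mapping above (the L*/b* values are made up for illustration): a healthy-skin patch with mean L* = 70 and b* = 15 gives ITA = arctan2(20, 15) ≈ 53.1°, which falls in the Type I band (> 46.86), while L* = 55, b* = 20 gives ≈ 14.0°, Type IV:

```python
import numpy as np

# Thresholds as in ITA_THRESHOLDS above
ITA_THRESHOLDS = {
    "I": (46.86, float("inf")), "II": (34.25, 46.86), "III": (20.87, 34.25),
    "IV": (3.57, 20.87), "V": (-28.38, 3.57), "VI": (float("-inf"), -28.38),
}

def ita_degrees(l_star: float, b_star: float) -> float:
    # ITA = arctan((L* - 50) / b*) in degrees
    return float(np.degrees(np.arctan2(l_star - 50.0, b_star)))

def classify(ita: float) -> str:
    for ftype, (lo, hi) in ITA_THRESHOLDS.items():
        if lo <= ita < hi:
            return ftype
    return "III"  # fallback, as in classify_fitzpatrick

print(round(ita_degrees(70.0, 15.0), 1), classify(ita_degrees(70.0, 15.0)))  # 53.1 I
print(round(ita_degrees(55.0, 20.0), 1), classify(ita_degrees(55.0, 20.0)))  # 14.0 IV
```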
src/pwat_estimator.py ADDED
@@ -0,0 +1,183 @@
+ """PWAT (Photographic Wound Assessment Tool) estimation — Items 3-8.
+
+ Uses segmentation masks to extract color, tissue, morphological, and texture
+ features, then predicts ordinal PWAT scores (0-4) via XGBoost classifiers.
+
+ Includes Fitzpatrick-aware debiasing correction factors.
+ """
+ import numpy as np
+ import cv2
+ import joblib
+ from dataclasses import dataclass, field
+ from typing import Optional
+ from pathlib import Path
+
+ ITEMS = [3, 4, 5, 6, 7, 8]
+ ITEM_NAMES = {
+     3: "Necrotic Type",
+     4: "Necrotic Amount",
+     5: "Granulation Type",
+     6: "Granulation Amount",
+     7: "Edges",
+     8: "Periulcer Skin",
+ }
+
+ # Debiasing correction factors (calibrated from 61 DFU images)
+ # Applied as: adjusted = clip(raw + factor, 0, 4)
+ CORRECTION_FACTORS = {
+     "I": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
+     "II": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
+     "III": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: -0.1},
+     "IV": {3: -0.1, 4: -0.1, 5: 0.0, 6: 0.0, 7: 0.0, 8: -0.3},
+     "V": {3: -0.2, 4: -0.2, 5: -0.1, 6: 0.0, 7: 0.0, 8: -0.6},
+     "VI": {3: -0.3, 4: -0.3, 5: -0.2, 6: -0.1, 7: 0.0, 8: -0.9},
+ }
+
+
+ @dataclass
+ class PWATResult:
+     scores_raw: dict = field(default_factory=dict)       # {item: int}
+     scores_adjusted: dict = field(default_factory=dict)  # {item: float} (debiased)
+     total_raw: int = 0
+     total_adjusted: float = 0.0
+     fitzpatrick_type: str = ""
+     features: dict = field(default_factory=dict)
+
+
+ def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
+     """Extract 63 features from the wound region for PWAT prediction.
+
+     Features: color (RGB/HSV/Lab), tissue composition, morphology, texture.
+     """
+     b = ulcer_mask > 0 if ulcer_mask.dtype == bool else ulcer_mask > 127
+     npx = int(np.sum(b))
+     if npx < 50:
+         return None
+
+     feats = {}
+
+     # --- Color features (45) ---
+     hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
+     lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
+     rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
+
+     for cs, arr, names in [("rgb", rgb, ["R", "G", "B"]), ("hsv", hsv, ["H", "S", "V"]), ("lab", lab, ["L", "a", "b"])]:
+         for ci, cn in enumerate(names):
+             vals = arr[b, ci]
+             feats[f"{cs}_{cn}_mean"] = float(np.mean(vals))
+             feats[f"{cs}_{cn}_std"] = float(np.std(vals))
+             feats[f"{cs}_{cn}_median"] = float(np.median(vals))
+             feats[f"{cs}_{cn}_p25"] = float(np.percentile(vals, 25))
+             feats[f"{cs}_{cn}_p75"] = float(np.percentile(vals, 75))
+
+     # --- Tissue composition (5) ---
+     h, s, v = hsv[b, 0], hsv[b, 1], hsv[b, 2]
+     l_ch = lab[b, 0] * (100 / 255)
+     a_ch = lab[b, 1] - 128
+
+     eschar = ((v < 100) & (s < 60)) | (v < 60)
+     slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
+     gran = (((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_ch > 5)) & ~eschar
+     necro = (s < 45) & (v >= 60) & (v < 160) & (l_ch < 55) & ~eschar & ~gran
+
+     feats["tissue_gran_pct"] = float(np.sum(gran) / npx * 100)
+     feats["tissue_eschar_pct"] = float(np.sum(eschar) / npx * 100)
+     feats["tissue_slough_pct"] = float(np.sum(slough) / npx * 100)
+     feats["tissue_necro_pct"] = float(np.sum(necro) / npx * 100)
+     feats["tissue_necro_total"] = feats["tissue_eschar_pct"] + feats["tissue_slough_pct"] + feats["tissue_necro_pct"]
+
+     # --- Morphological features (7) ---
+     mask_u8 = b.astype(np.uint8)
+     cnts, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     if cnts:
+         cnt = max(cnts, key=cv2.contourArea)
+         area = cv2.contourArea(cnt)
+         perim = cv2.arcLength(cnt, True)
+         circ = 4 * np.pi * area / (perim**2) if perim > 0 else 0
+         feats["morph_area"] = float(area)
+         feats["morph_perimeter"] = float(perim)
+         feats["morph_circularity"] = float(circ)
+         feats["morph_irregularity"] = float(1 - circ)
+         x, y, w2, h2 = cv2.boundingRect(cnt)
+         feats["morph_aspect_ratio"] = float(w2 / (h2 + 1e-8))
+         feats["morph_extent"] = float(area / (w2 * h2 + 1e-8))
+         hull = cv2.convexHull(cnt)
+         feats["morph_solidity"] = float(area / (cv2.contourArea(hull) + 1e-8))
+
+     # --- Texture features (4) ---
+     gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
+     wound_gray = gray[b]
+     feats["texture_mean"] = float(np.mean(wound_gray))
+     feats["texture_std"] = float(np.std(wound_gray))
+     hist_vals = np.histogram(wound_gray, bins=64, density=True)[0]
+     feats["texture_entropy"] = float(-np.sum(hist_vals * np.log2(hist_vals + 1e-10)))
+
+     dilated = cv2.dilate(mask_u8 * 255, np.ones((5, 5), np.uint8))
+     eroded = cv2.erode(mask_u8 * 255, np.ones((5, 5), np.uint8))
+     edge_zone = (dilated - eroded) > 127
+     if np.any(edge_zone):
+         feats["edge_gradient"] = float(np.mean(np.abs(cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)[edge_zone])))
+
+     # --- ROI features (2) ---
+     feats["wound_npx"] = float(npx)
+     feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
+
+     return feats
+
+
+ class PWATPredictor:
+     """Predicts PWAT items 3-8 from wound features using trained XGBoost models."""
+
+     def __init__(self, models_dir: str):
+         self.models = {}
+         models_path = Path(models_dir)
+         for item in ITEMS:
+             pkl = models_path / f"xgb_pwat{item}.pkl"
+             if pkl.exists():
+                 self.models[item] = joblib.load(pkl)
+
+     def predict(
+         self,
+         img_bgr: np.ndarray,
+         ulcer_mask: np.ndarray,
+         fitzpatrick_type: str = "III",
+         feature_cols: Optional[list] = None,
+     ) -> PWATResult:
+         """Predict PWAT scores for a single image.
+
+         Args:
+             img_bgr: BGR image
+             ulcer_mask: Binary ulcer mask (H, W)
+             fitzpatrick_type: Fitzpatrick type for debiasing ("I" .. "VI")
+             feature_cols: Ordered feature column names (must match training order).
+                 If None, uses all extracted features sorted alphabetically.
+         """
+         feats = extract_features(img_bgr, ulcer_mask)
+         if feats is None:
+             return PWATResult(fitzpatrick_type=fitzpatrick_type)
+
+         # Build feature vector
+         if feature_cols is None:
+             feature_cols = sorted(feats.keys())
+         X = np.array([[feats.get(c, 0.0) for c in feature_cols]])
+
+         scores_raw = {}
+         scores_adj = {}
+         for item in ITEMS:
+             if item in self.models:
+                 pred = int(self.models[item].predict(X)[0])
+                 scores_raw[item] = pred
+                 factor = CORRECTION_FACTORS.get(fitzpatrick_type, {}).get(item, 0.0)
+                 scores_adj[item] = float(np.clip(pred + factor, 0, 4))
+             else:
+                 scores_raw[item] = 0
+                 scores_adj[item] = 0.0
+
+         return PWATResult(
+             scores_raw=scores_raw,
+             scores_adjusted=scores_adj,
+             total_raw=sum(scores_raw.values()),
+             total_adjusted=round(sum(scores_adj.values()), 1),
+             fitzpatrick_type=fitzpatrick_type,
+             features=feats,
+         )
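The tissue-composition rules in `extract_features` are fixed HSV/Lab thresholds, so each wound pixel falls into at most one of eschar/slough/granulation/necrotic. A minimal sketch on four hand-picked pixels, one satisfying exactly one rule each (the pixel values are made up to hit each band):

```python
import numpy as np

# Four synthetic wound pixels: columns are (eschar, slough, granulation, necrotic)
h = np.array([0.0, 30.0, 5.0, 0.0])       # hue
s = np.array([50.0, 100.0, 100.0, 30.0])  # saturation
v = np.array([50.0, 150.0, 150.0, 100.0]) # value
l_ch = np.array([20.0, 60.0, 50.0, 40.0]) # L* (0-100)
a_ch = np.array([0.0, 0.0, 20.0, 0.0])    # a* (signed)

# Same boolean rules as extract_features
eschar = ((v < 100) & (s < 60)) | (v < 60)
slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
gran = (((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_ch > 5)) & ~eschar
necro = (s < 45) & (v >= 60) & (v < 160) & (l_ch < 55) & ~eschar & ~gran

print([int(m.sum()) for m in (eschar, slough, gran, necro)])  # [1, 1, 1, 1]
```

Note the `~eschar`/`~gran` exclusions make the rules apply in priority order, so the per-class percentages cannot double-count a pixel.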
src/segmentation.py ADDED
@@ -0,0 +1,256 @@
1
+ """WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer).
2
+
3
+ Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder.
4
+ Checkpoint: Track B multiclass, ulcer Dice = 0.927 (Bootstrap CI: [0.917, 0.936]).
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import segmentation_models_pytorch as smp
10
+ import numpy as np
11
+ import cv2
12
+ from pathlib import Path
13
+
14
+ IMG_SIZE = 512
15
+ MEAN = np.array([0.485, 0.456, 0.406])
16
+ STD = np.array([0.229, 0.224, 0.225])
17
+ CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}
18
+ CLASS_COLORS = {
19
+ 0: (0, 0, 0),
20
+ 1: (0, 255, 0),
21
+ 2: (255, 165, 0),
22
+ 3: (255, 0, 0),
23
+ }
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Architecture modules (match checkpoint weights exactly)
28
+ # ---------------------------------------------------------------------------
29
+
30
+ class ChannelAttention(nn.Module):
31
+ def __init__(self, channels, reduction=16):
32
+ super().__init__()
33
+ self.mlp = nn.Sequential(
34
+ nn.Linear(channels, channels // reduction, bias=False),
35
+ nn.ReLU(inplace=True),
36
+ nn.Linear(channels // reduction, channels, bias=False),
37
+ )
38
+
39
+ def forward(self, x):
40
+ avg_out = self.mlp(x.mean(dim=[2, 3]))
41
+ max_out = self.mlp(x.amax(dim=[2, 3]))
42
+ attn = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
43
+ return x * attn
44
+
45
+
46
+ class SpatialAttention(nn.Module):
47
+ def __init__(self, kernel_size=7):
48
+ super().__init__()
49
+ self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
50
+
51
+ def forward(self, x):
52
+ avg_out = x.mean(dim=1, keepdim=True)
53
+ max_out = x.amax(dim=1, keepdim=True)
54
+ attn = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
55
+ return x * attn
56
+
57
+
58
+ class CBAM(nn.Module):
59
+ def __init__(self, channels, reduction=16, kernel_size=7):
60
+ super().__init__()
61
+ self.ca = ChannelAttention(channels, reduction)
62
+ self.sa = SpatialAttention(kernel_size)
63
+
64
+ def forward(self, x):
65
+ return self.sa(self.ca(x))
66
+
67
+
68
+ class DifferentiableFractalDimension(nn.Module):
+     def __init__(self, scales=None):
+         super().__init__()
+         self.scales = scales or [2, 4, 8, 16, 32]
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         counts = []
+         for s in self.scales:
+             if s >= H or s >= W:
+                 continue
+             pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
+             n_boxes = torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3])
+             counts.append(n_boxes)
+         if len(counts) < 2:
+             # Fallback must match the (B, 1, 1, 1) shape of the regular
+             # return below, or the expand() in the TAM will fail.
+             return torch.ones(B, 1, 1, 1, device=x.device)
+         log_s = torch.log(torch.tensor([float(s) for s in self.scales[: len(counts)]], device=x.device))
+         log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)
+         n = log_s.shape[0]
+         sx, sxx = log_s.sum(), (log_s**2).sum()
+         sy = log_c.sum(dim=-1)
+         sxy = (log_c * log_s.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
+         slope = (n * sxy - sx * sy) / (n * sxx - sx**2 + 1e-8)
+         return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
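The module above is a soft relaxation of classical box counting, where the fractal dimension is minus the slope of log N(s) versus log s. As a sanity check (not part of the pipeline), a hard NumPy version on a filled square, whose box-counting dimension is 2:

```python
import numpy as np

def box_counting_dimension(mask: np.ndarray, scales=(2, 4, 8, 16, 32)) -> float:
    """Estimate the fractal dimension of a binary mask via box counting.

    Counts boxes of side s containing any foreground pixel, then fits the
    slope of log N(s) vs. log s; the dimension is minus that slope.
    """
    h, w = mask.shape
    log_s, log_n = [], []
    for s in scales:
        if s >= h or s >= w:
            continue
        # Crop to a multiple of s, then group pixels into s x s boxes
        hh, ww = (h // s) * s, (w // s) * s
        boxes = mask[:hh, :ww].reshape(hh // s, s, ww // s, s)
        n = np.count_nonzero(boxes.any(axis=(1, 3)))
        if n > 0:
            log_s.append(np.log(s))
            log_n.append(np.log(n))
    slope = np.polyfit(log_s, log_n, 1)[0]
    return -slope

# A filled square covers (128/s)^2 boxes at every scale, so the
# log-log fit is exactly linear with slope -2.
solid = np.ones((128, 128), dtype=bool)
print(round(box_counting_dimension(solid), 2))  # -> 2.0
```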
+
+
+ class DifferentiableEulerCharacteristic(nn.Module):
+     def forward(self, x):
+         B, C, H, W = x.shape
+         b = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))
+         V = b.sum(dim=[2, 3])
+         E_h = (b[:, :, :, :-1] * b[:, :, :, 1:]).sum(dim=[2, 3])
+         E_v = (b[:, :, :-1, :] * b[:, :, 1:, :]).sum(dim=[2, 3])
+         F_val = (b[:, :, :-1, :-1] * b[:, :, :-1, 1:] * b[:, :, 1:, :-1] * b[:, :, 1:, 1:]).sum(dim=[2, 3])
+         euler = V - E_h - E_v + F_val
+         return euler.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) / (H * W)
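The soft module above follows the vertex/edge/face inclusion-exclusion V - E + F; on a hard binary image this equals the number of connected components minus the number of holes. A NumPy sketch (illustrative, not part of the pipeline) on a solid block and on a block with a hole punched in it:

```python
import numpy as np

def euler_characteristic(b: np.ndarray) -> int:
    """V - E + F on the cell complex induced by a binary image."""
    b = b.astype(bool)
    V = np.count_nonzero(b)
    # Horizontal and vertical edges between adjacent foreground pixels
    E = (np.count_nonzero(b[:, :-1] & b[:, 1:])
         + np.count_nonzero(b[:-1, :] & b[1:, :]))
    # Faces: 2x2 blocks that are entirely foreground
    F = np.count_nonzero(b[:-1, :-1] & b[:-1, 1:] & b[1:, :-1] & b[1:, 1:])
    return V - E + F

solid = np.zeros((8, 8), dtype=bool)
solid[1:7, 1:7] = True
print(euler_characteristic(solid))  # one component, no holes -> 1

ring = solid.copy()
ring[3:5, 3:5] = False              # punch a 2x2 hole
print(euler_characteristic(ring))   # one component, one hole -> 0
```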
+
+
+ class TopologicalAttentionModule(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.fractal = DifferentiableFractalDimension()
+         self.euler = DifferentiableEulerCharacteristic()
+         self.alpha = nn.Parameter(torch.tensor(1.0))
+         self.beta = nn.Parameter(torch.tensor(1.0))
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels + 2, in_channels, 1),
+             nn.BatchNorm2d(in_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels, in_channels, 1),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         fm = self.fractal(x).expand(B, 1, H, W)
+         em = self.euler(x).expand(B, 1, H, W)
+         attn = self.conv(torch.cat([x, self.alpha * fm, self.beta * em], dim=1))
+         return x * attn + x
+
+
+ class ASPP(nn.Module):
+     def __init__(self, in_ch, out_ch, rates=None):
+         super().__init__()
+         rates = rates or [6, 12, 18]
+         self.conv1x1 = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True))
+         self.atrous = nn.ModuleList(
+             [nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r), nn.BatchNorm2d(out_ch), nn.ReLU(True)) for r in rates]
+         )
+         self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_ch, out_ch, 1), nn.ReLU(True))
+         self.project = nn.Sequential(
+             nn.Conv2d(out_ch * (2 + len(rates)), out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True), nn.Dropout(0.5)
+         )
+
+     def forward(self, x):
+         size = x.shape[2:]
+         feats = [self.conv1x1(x)] + [a(x) for a in self.atrous]
+         feats.append(F.interpolate(self.pool(x), size=size, mode="bilinear", align_corners=False))
+         return self.project(torch.cat(feats, dim=1))
+
+
+ class WoundNetB7(nn.Module):
+     """WoundNetB7 matching the Track B checkpoint structure."""
+
+     NUM_CLASSES = 4
+
+     def __init__(self, num_classes=4):
+         super().__init__()
+         self.backbone = smp.Unet(encoder_name="efficientnet-b7", encoder_weights=None, in_channels=3, classes=num_classes)
+         enc_ch = self.backbone.encoder.out_channels[-1]
+         self.aspp = ASPP(enc_ch, enc_ch)
+         self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
+         self.tam = TopologicalAttentionModule(num_classes)
+         self.diffusion_weight = nn.Parameter(torch.tensor(0.01))
+
+     def forward(self, x):
+         features = list(self.backbone.encoder(x))
+         features[-1] = self.aspp(features[-1])
+         try:
+             dec = self.backbone.decoder(features)
+         except TypeError:
+             dec = self.backbone.decoder(*features)
+         seg = self.backbone.segmentation_head(dec)
+         seg = self.cbam(seg)
+         seg = self.tam(seg)
+         return seg
+
+
+ # ---------------------------------------------------------------------------
+ # Inference helpers
+ # ---------------------------------------------------------------------------
+
+ def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
+     """BGR image -> normalized CHW tensor (1, 3, 512, 512)."""
+     img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+     img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
+     img = (img.astype(np.float32) / 255.0 - MEAN) / STD
+     return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
+
+
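The ImageNet normalization in `preprocess` is invertible, which is useful when overlaying masks back onto the displayed image. A NumPy round-trip sketch with the same MEAN/STD constants (the `normalize`/`denormalize` helper names are ours, not part of the pipeline):

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize(img_rgb: np.ndarray) -> np.ndarray:
    """uint8 RGB image -> ImageNet-normalized float array."""
    return (img_rgb.astype(np.float32) / 255.0 - MEAN) / STD

def denormalize(t: np.ndarray) -> np.ndarray:
    """Invert normalize(), rounding back to uint8 pixel values."""
    return np.clip(np.rint((t * STD + MEAN) * 255.0), 0, 255).astype(np.uint8)

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
assert np.array_equal(denormalize(normalize(img)), img)
```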
+ def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
+     """6-fold TTA -> averaged softmax probabilities (1, C, H, W)."""
+     transforms = [
+         lambda x: x,
+         lambda x: torch.flip(x, [3]),
+         lambda x: torch.flip(x, [2]),
+         lambda x: torch.rot90(x, 1, [2, 3]),
+         lambda x: torch.rot90(x, 2, [2, 3]),
+         lambda x: torch.rot90(x, 3, [2, 3]),
+     ]
+     inverse = [
+         lambda x: x,
+         lambda x: torch.flip(x, [3]),
+         lambda x: torch.flip(x, [2]),
+         lambda x: torch.rot90(x, 3, [2, 3]),
+         lambda x: torch.rot90(x, 2, [2, 3]),
+         lambda x: torch.rot90(x, 1, [2, 3]),
+     ]
+     probs_sum = None
+     with torch.no_grad():
+         for tfm, inv in zip(transforms, inverse):
+             out = model(tfm(img_tensor).to(device))
+             if isinstance(out, (tuple, list)):
+                 out = out[0]
+             if isinstance(out, dict):
+                 out = out["seg"]
+             p = inv(F.softmax(out, dim=1))
+             probs_sum = p if probs_sum is None else probs_sum + p
+     return probs_sum / len(transforms)
+
+
+ def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
+     """Load WoundNetB7 from checkpoint."""
+     model = WoundNetB7(num_classes=4)
+     state = torch.load(checkpoint_path, map_location=device, weights_only=False)
+     # Remove PWAT head keys if present
+     state = {k: v for k, v in state.items() if not k.startswith("pwat_head.")}
+     model.load_state_dict(state, strict=False)
+     model.to(device).eval()
+     return model
+
+
+ def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
+     """Run segmentation on a BGR image.
+
+     Returns dict with:
+         classmap: (H, W) uint8 with class indices 0-3
+         masks: dict of per-class binary masks {cls_name: (H, W) bool}
+         probs: (4, H, W) float32 softmax probabilities
+     """
+     h, w = img_bgr.shape[:2]
+     tensor = preprocess(img_bgr)
+
+     if use_tta:
+         probs = tta_inference(model, tensor, device)
+     else:
+         with torch.no_grad():
+             out = model(tensor.to(device))
+             if isinstance(out, (tuple, list)):
+                 out = out[0]
+             if isinstance(out, dict):
+                 out = out["seg"]
+             probs = F.softmax(out, dim=1)
+
+     probs_np = probs[0].cpu().numpy()
+     probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
+     classmap = probs_resized.argmax(axis=0).astype(np.uint8)
+     masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
+     return {"classmap": classmap, "masks": masks, "probs": probs_resized}
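Downstream stages (PWAT features, ITA skin sampling) consume the `classmap` and per-class masks returned by `segment`. A minimal NumPy sketch of deriving per-class area fractions from a classmap (the `class_area_fractions` helper name is ours, for illustration only):

```python
import numpy as np

CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}

def class_area_fractions(classmap: np.ndarray) -> dict:
    """Fraction of image pixels assigned to each class."""
    total = classmap.size
    return {name: np.count_nonzero(classmap == cid) / total
            for cid, name in CLASS_NAMES.items()}

# Toy 4x4 classmap: 8 background, 4 foot, 2 perilesion, 2 ulcer pixels
cm = np.array([[0, 0, 0, 0],
               [0, 0, 0, 0],
               [1, 1, 1, 1],
               [2, 2, 3, 3]], dtype=np.uint8)
print(class_area_fractions(cm))
# -> {'background': 0.5, 'foot': 0.25, 'perilesion': 0.125, 'ulcer': 0.125}
```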