Add pipeline code, PWAT models, and Gradio app
Browse files- README.md +105 -12
- app.py +103 -0
- models/pwat/xgb_pwat3.pkl +3 -0
- models/pwat/xgb_pwat4.pkl +3 -0
- models/pwat/xgb_pwat5.pkl +3 -0
- models/pwat/xgb_pwat6.pkl +3 -0
- models/pwat/xgb_pwat7.pkl +3 -0
- models/pwat/xgb_pwat8.pkl +3 -0
- pipeline.py +169 -0
- requirements.txt +11 -0
- src/__init__.py +1 -0
- src/fitzpatrick_estimator.py +137 -0
- src/pwat_estimator.py +183 -0
- src/segmentation.py +256 -0
README.md
CHANGED
|
@@ -1,12 +1,105 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# WoundNetB7 — Automated DFU Assessment Pipeline
|
| 2 |
+
|
| 3 |
+
End-to-end pipeline for Diabetic Foot Ulcer (DFU) analysis:
|
| 4 |
+
|
| 5 |
+
**Image** -> **Multiclass Segmentation** -> **PWAT Scoring** + **Fitzpatrick/ITA Estimation**
|
| 6 |
+
|
| 7 |
+
## Pipeline
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
Input Image (DFU photograph)
|
| 11 |
+
|
|
| 12 |
+
v
|
| 13 |
+
[1] WoundNetB7 Segmentation (EfficientNet-B7 + ASPP + CBAM + TAM)
|
| 14 |
+
-> 4-class masks: background, foot, perilesion, ulcer
|
| 15 |
+
-> Ulcer Dice: 0.927 (95% CI: [0.917, 0.936])
|
| 16 |
+
|
|
| 17 |
+
+---> [2] PWAT Estimation (XGBoost per item)
|
| 18 |
+
| -> Items 3-8 scores (0-4 ordinal)
|
| 19 |
+
| -> 5-fold CV: MAE 0.61-0.87, Adjacent Match 75-89%
|
| 20 |
+
|
|
| 21 |
+
+---> [3] Fitzpatrick/ITA Estimation
|
| 22 |
+
| -> Healthy skin = foot - perilesion - ulcer
|
| 23 |
+
| -> ITA angle + Fitzpatrick I-VI classification
|
| 24 |
+
| -> Calibrated on 61 DFU images (86.9% exact match)
|
| 25 |
+
|
|
| 26 |
+
+---> [4] Debiasing
|
| 27 |
+
-> PWAT scores adjusted by Fitzpatrick type
|
| 28 |
+
-> 18% max group gap reduction (p < 10^-27)
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Quick Start
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
from pipeline import WoundNetB7Pipeline
|
| 35 |
+
|
| 36 |
+
pipe = WoundNetB7Pipeline("models/")
|
| 37 |
+
result = pipe.analyze("path/to/dfu_image.png")
|
| 38 |
+
print(result.summary())
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Models
|
| 42 |
+
|
| 43 |
+
| Component | Architecture | Metric | Value |
|
| 44 |
+
|-----------|-------------|--------|-------|
|
| 45 |
+
| Segmentation | EfficientNet-B7 + UNet + ASPP + CBAM + TAM | Ulcer Dice | 0.927 |
|
| 46 |
+
| PWAT Item 3 | XGBoost + SMOTE | MAE / Adjacent | 0.87 / 74.9% |
|
| 47 |
+
| PWAT Item 7 | XGBoost + SMOTE | MAE / Adjacent | 0.61 / 89.3% |
|
| 48 |
+
| Fitzpatrick | ITA calibrated thresholds | Exact match | 86.9% |
|
| 49 |
+
|
| 50 |
+
## Fitzpatrick/ITA Calibrated Thresholds
|
| 51 |
+
|
| 52 |
+
| Type | ITA Range | Description |
|
| 53 |
+
|------|-----------|-------------|
|
| 54 |
+
| I | > 46.86 | Very Light |
|
| 55 |
+
| II | 34.25 - 46.86 | Light |
|
| 56 |
+
| III | 20.87 - 34.25 | Intermediate |
|
| 57 |
+
| IV | 3.57 - 20.87 | Tan |
|
| 58 |
+
| V | -28.38 - 3.57 | Brown |
|
| 59 |
+
| VI | < -28.38 | Dark |
|
| 60 |
+
|
| 61 |
+
## PWAT Debiasing Factors
|
| 62 |
+
|
| 63 |
+
Darker skin tones tend to overestimate tissue severity (especially Item 8: Periulcer Skin).
|
| 64 |
+
Correction factors reduce this bias:
|
| 65 |
+
|
| 66 |
+
| Type | P3 | P4 | P5 | P6 | P7 | P8 |
|
| 67 |
+
|------|-----|-----|-----|-----|-----|-----|
|
| 68 |
+
| I-II | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
|
| 69 |
+
| III | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.1 |
|
| 70 |
+
| IV | -0.1 | -0.1 | 0.0 | 0.0 | 0.0 | -0.3 |
|
| 71 |
+
| V | -0.2 | -0.2 | -0.1 | 0.0 | 0.0 | -0.6 |
|
| 72 |
+
| VI | -0.3 | -0.3 | -0.2 | -0.1 | 0.0 | -0.9 |
|
| 73 |
+
|
| 74 |
+
## Dataset
|
| 75 |
+
|
| 76 |
+
- **Segmentation training**: 2,245 images (1,571 train / 336 val / 338 test)
|
| 77 |
+
- **Multiclass validation**: 461 images with 5-class expert masks
|
| 78 |
+
- **PWAT labels**: 1,321 images (920 expert GT + 401 AI consensus)
|
| 79 |
+
- **Fitzpatrick calibration**: 61 images with expert skin type annotations
|
| 80 |
+
|
| 81 |
+
## Deploy to Hugging Face
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
# Clone this repo to a HF Space
|
| 85 |
+
git clone https://huggingface.co/spaces/YOUR_USER/woundnetb7
|
| 86 |
+
# Copy models/ directory (or use git-lfs for large files)
|
| 87 |
+
git lfs install
|
| 88 |
+
git lfs track "*.pt" "*.pkl"
|
| 89 |
+
git add . && git commit -m "Initial deployment"
|
| 90 |
+
git push
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## License
|
| 94 |
+
|
| 95 |
+
Research use only. Clinical deployment requires regulatory approval.
|
| 96 |
+
|
| 97 |
+
## Citation
|
| 98 |
+
|
| 99 |
+
```bibtex
|
| 100 |
+
@article{marquez2026woundnetb7,
|
| 101 |
+
title={WoundNetB7: Automated PWAT Protocol using Topological AI for Diabetic Foot Ulcers},
|
| 102 |
+
author={M{\'a}rquez-Sandoval, Marcelo},
|
| 103 |
+
year={2026}
|
| 104 |
+
}
|
| 105 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gradio app for WoundNetB7 DFU Analysis — Hugging Face Spaces deployment.
|
| 2 |
+
|
| 3 |
+
Launch locally: python app.py
|
| 4 |
+
Deploy to HF: push this repo to a Hugging Face Space (GPU recommended).
|
| 5 |
+
"""
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import numpy as np
|
| 8 |
+
import cv2
|
| 9 |
+
import json
|
| 10 |
+
from pipeline import WoundNetB7Pipeline
|
| 11 |
+
|
| 12 |
+
# Initialize pipeline (loads all models once)
|
| 13 |
+
pipe = WoundNetB7Pipeline(models_dir="models", use_tta=True)
|
| 14 |
+
|
| 15 |
+
CLASS_COLORS_RGB = {
|
| 16 |
+
0: (0, 0, 0),
|
| 17 |
+
1: (0, 255, 0), # Foot: green
|
| 18 |
+
2: (255, 165, 0), # Perilesion: orange
|
| 19 |
+
3: (255, 0, 0), # Ulcer: red
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
FITZ_COLORS = {
|
| 23 |
+
"I": "#fef3c7", "II": "#fde68a", "III": "#fbbf24",
|
| 24 |
+
"IV": "#b45309", "V": "#78350f", "VI": "#451a03",
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def analyze_image(image):
|
| 29 |
+
"""Main analysis function called by Gradio."""
|
| 30 |
+
if image is None:
|
| 31 |
+
return None, "Please upload an image.", "{}"
|
| 32 |
+
|
| 33 |
+
# Gradio provides RGB numpy array
|
| 34 |
+
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 35 |
+
|
| 36 |
+
# Run pipeline
|
| 37 |
+
result = pipe.analyze(img_bgr, use_tta=True)
|
| 38 |
+
|
| 39 |
+
# Create overlay visualization
|
| 40 |
+
overlay = pipe.visualize(img_bgr, result)
|
| 41 |
+
|
| 42 |
+
# Build text summary
|
| 43 |
+
summary = result.summary()
|
| 44 |
+
|
| 45 |
+
# Build JSON output
|
| 46 |
+
json_out = json.dumps(result.to_dict(), indent=2, ensure_ascii=False)
|
| 47 |
+
|
| 48 |
+
return overlay, summary, json_out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Build Gradio interface
|
| 52 |
+
with gr.Blocks(
|
| 53 |
+
title="WoundNetB7 DFU Analysis",
|
| 54 |
+
theme=gr.themes.Soft(),
|
| 55 |
+
) as demo:
|
| 56 |
+
gr.Markdown(
|
| 57 |
+
"""
|
| 58 |
+
# WoundNetB7 — Diabetic Foot Ulcer Analysis
|
| 59 |
+
|
| 60 |
+
Upload a DFU image to get:
|
| 61 |
+
1. **Multiclass segmentation** (foot / perilesion / ulcer)
|
| 62 |
+
2. **Fitzpatrick skin type** via calibrated ITA (86.9% accuracy)
|
| 63 |
+
3. **PWAT scores** items 3-8 with Fitzpatrick debiasing
|
| 64 |
+
|
| 65 |
+
> Model: EfficientNet-B7 + ASPP + CBAM + TAM | Ulcer Dice: 0.927
|
| 66 |
+
"""
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
with gr.Row():
|
| 70 |
+
with gr.Column(scale=1):
|
| 71 |
+
input_image = gr.Image(label="DFU Image", type="numpy")
|
| 72 |
+
analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
|
| 73 |
+
|
| 74 |
+
with gr.Column(scale=1):
|
| 75 |
+
output_overlay = gr.Image(label="Segmentation Overlay")
|
| 76 |
+
|
| 77 |
+
with gr.Row():
|
| 78 |
+
with gr.Column(scale=1):
|
| 79 |
+
output_text = gr.Textbox(label="Analysis Summary", lines=25, max_lines=40)
|
| 80 |
+
with gr.Column(scale=1):
|
| 81 |
+
output_json = gr.Code(label="JSON Output", language="json")
|
| 82 |
+
|
| 83 |
+
analyze_btn.click(
|
| 84 |
+
fn=analyze_image,
|
| 85 |
+
inputs=[input_image],
|
| 86 |
+
outputs=[output_overlay, output_text, output_json],
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
gr.Markdown(
|
| 90 |
+
"""
|
| 91 |
+
---
|
| 92 |
+
**Legend:** Green = Foot | Orange = Perilesion | Red = Ulcer
|
| 93 |
+
|
| 94 |
+
**PWAT Items:** 3=Necrotic Type, 4=Necrotic Amount, 5=Granulation Type,
|
| 95 |
+
6=Granulation Amount, 7=Edges, 8=Periulcer Skin (0=best, 4=worst)
|
| 96 |
+
|
| 97 |
+
**Debiasing:** Scores adjusted by Fitzpatrick type to reduce skin-tone bias
|
| 98 |
+
(18% max group gap reduction, p < 10^-27).
|
| 99 |
+
"""
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
demo.launch(share=False)
|
models/pwat/xgb_pwat3.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94c73bf442d9605e557d5d28578efb3b8858ffbcd8df42033685cb5a9bf5b57e
|
| 3 |
+
size 1428233
|
models/pwat/xgb_pwat4.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7f74559882bb03542d9126386cec6c184cf50eb49340bbcfc2e197bed85cf44
|
| 3 |
+
size 1431895
|
models/pwat/xgb_pwat5.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:70ebfd69b4f41d5ef7c33b09b456cdbc194b4f5f37f8ddffcbb0050b0c8344f3
|
| 3 |
+
size 1856653
|
models/pwat/xgb_pwat6.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69b63396d339a95eeb853daf0fc8b60f02a3c7065b8ef4f879bcab2833d6e4b4
|
| 3 |
+
size 1650293
|
models/pwat/xgb_pwat7.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d32c766bfd592d643632425fe422826551c9d093d34b241237f986980b528cc9
|
| 3 |
+
size 1556397
|
models/pwat/xgb_pwat8.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f1dd0c58d57755d4459d0ceb71a2c2a8fd858735334516cba8e0d948ea5f2d9e
|
| 3 |
+
size 1570379
|
pipeline.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WoundNetB7 End-to-End Pipeline: Image -> Segmentation -> PWAT + Fitzpatrick/ITA.
|
| 2 |
+
|
| 3 |
+
Usage:
|
| 4 |
+
from pipeline import WoundNetB7Pipeline
|
| 5 |
+
pipe = WoundNetB7Pipeline("models/")
|
| 6 |
+
result = pipe.analyze("path/to/dfu_image.png")
|
| 7 |
+
print(result)
|
| 8 |
+
"""
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from dataclasses import dataclass, field, asdict
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
from src.segmentation import load_segmentation_model, segment, CLASS_NAMES, CLASS_COLORS
|
| 17 |
+
from src.fitzpatrick_estimator import estimate_fitzpatrick, FitzpatrickResult
|
| 18 |
+
from src.pwat_estimator import PWATPredictor, PWATResult, ITEM_NAMES
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class AnalysisResult:
|
| 23 |
+
"""Complete DFU analysis result from a single image."""
|
| 24 |
+
# Segmentation
|
| 25 |
+
class_distribution: dict = field(default_factory=dict) # {class_name: percentage}
|
| 26 |
+
# Fitzpatrick
|
| 27 |
+
fitzpatrick: Optional[FitzpatrickResult] = None
|
| 28 |
+
# PWAT
|
| 29 |
+
pwat: Optional[PWATResult] = None
|
| 30 |
+
# Metadata
|
| 31 |
+
image_size: tuple = (0, 0)
|
| 32 |
+
device: str = "cpu"
|
| 33 |
+
|
| 34 |
+
def summary(self) -> str:
|
| 35 |
+
lines = ["=" * 50, "WoundNetB7 DFU Analysis", "=" * 50]
|
| 36 |
+
lines.append(f"Image: {self.image_size[1]}x{self.image_size[0]}")
|
| 37 |
+
lines.append(f"Device: {self.device}")
|
| 38 |
+
|
| 39 |
+
lines.append("\n--- Segmentation ---")
|
| 40 |
+
for cls, pct in self.class_distribution.items():
|
| 41 |
+
lines.append(f" {cls:<15s}: {pct:5.1f}%")
|
| 42 |
+
|
| 43 |
+
if self.fitzpatrick:
|
| 44 |
+
f = self.fitzpatrick
|
| 45 |
+
lines.append("\n--- Fitzpatrick / ITA ---")
|
| 46 |
+
lines.append(f" Type: {f.fitzpatrick_type} ({f.fitzpatrick_label})")
|
| 47 |
+
lines.append(f" ITA: {f.ita_angle:.1f} +/- {f.ita_std:.1f}")
|
| 48 |
+
lines.append(f" L* mean: {f.l_skin_mean:.1f}")
|
| 49 |
+
lines.append(f" Confidence: {f.confidence:.2f}")
|
| 50 |
+
lines.append(f" Pixels: {f.healthy_pixels:,}")
|
| 51 |
+
|
| 52 |
+
if self.pwat and self.pwat.scores_raw:
|
| 53 |
+
p = self.pwat
|
| 54 |
+
lines.append("\n--- PWAT Scores (Items 3-8) ---")
|
| 55 |
+
lines.append(f" {'Item':<22s} {'Raw':>4s} {'Adj':>5s}")
|
| 56 |
+
lines.append(" " + "-" * 33)
|
| 57 |
+
for item in [3, 4, 5, 6, 7, 8]:
|
| 58 |
+
name = ITEM_NAMES.get(item, f"Item {item}")
|
| 59 |
+
raw = p.scores_raw.get(item, "-")
|
| 60 |
+
adj = p.scores_adjusted.get(item, "-")
|
| 61 |
+
lines.append(f" {name:<22s} {raw:>4} {adj:>5.1f}")
|
| 62 |
+
lines.append(f" {'TOTAL':<22s} {p.total_raw:>4} {p.total_adjusted:>5.1f}")
|
| 63 |
+
lines.append(f" Fitzpatrick correction: {p.fitzpatrick_type}")
|
| 64 |
+
|
| 65 |
+
return "\n".join(lines)
|
| 66 |
+
|
| 67 |
+
def to_dict(self) -> dict:
|
| 68 |
+
d = {"image_size": self.image_size, "device": self.device, "class_distribution": self.class_distribution}
|
| 69 |
+
if self.fitzpatrick:
|
| 70 |
+
d["fitzpatrick"] = asdict(self.fitzpatrick)
|
| 71 |
+
if self.pwat:
|
| 72 |
+
d["pwat"] = {
|
| 73 |
+
"scores_raw": self.pwat.scores_raw,
|
| 74 |
+
"scores_adjusted": self.pwat.scores_adjusted,
|
| 75 |
+
"total_raw": self.pwat.total_raw,
|
| 76 |
+
"total_adjusted": self.pwat.total_adjusted,
|
| 77 |
+
}
|
| 78 |
+
return d
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class WoundNetB7Pipeline:
|
| 82 |
+
"""End-to-end DFU analysis pipeline.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
models_dir: Path to models/ directory containing:
|
| 86 |
+
- segmentation/WoundNetB7_proposed_best.pt
|
| 87 |
+
- pwat/xgb_pwat{3-8}.pkl
|
| 88 |
+
device: "cuda" or "cpu" (auto-detected if None)
|
| 89 |
+
use_tta: Use test-time augmentation for segmentation (slower but more accurate)
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
def __init__(self, models_dir: str = "models", device: Optional[str] = None, use_tta: bool = True):
|
| 93 |
+
self.models_dir = Path(models_dir)
|
| 94 |
+
self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 95 |
+
self.use_tta = use_tta
|
| 96 |
+
|
| 97 |
+
# Load segmentation model
|
| 98 |
+
seg_path = self.models_dir / "segmentation" / "WoundNetB7_proposed_best.pt"
|
| 99 |
+
self.seg_model = load_segmentation_model(str(seg_path), self.device)
|
| 100 |
+
print(f"Segmentation model loaded ({sum(p.numel() for p in self.seg_model.parameters()) / 1e6:.1f}M params)")
|
| 101 |
+
|
| 102 |
+
# Load PWAT predictor
|
| 103 |
+
pwat_path = self.models_dir / "pwat"
|
| 104 |
+
self.pwat_predictor = PWATPredictor(str(pwat_path))
|
| 105 |
+
print(f"PWAT models loaded ({len(self.pwat_predictor.models)} items)")
|
| 106 |
+
|
| 107 |
+
print(f"Pipeline ready on {self.device}")
|
| 108 |
+
|
| 109 |
+
def analyze(self, image_input, use_tta: Optional[bool] = None) -> AnalysisResult:
|
| 110 |
+
"""Analyze a DFU image end-to-end.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
image_input: file path (str/Path), BGR numpy array, or RGB numpy array
|
| 114 |
+
use_tta: Override TTA setting for this call
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
AnalysisResult with segmentation, Fitzpatrick, and PWAT data.
|
| 118 |
+
"""
|
| 119 |
+
tta = use_tta if use_tta is not None else self.use_tta
|
| 120 |
+
|
| 121 |
+
# Load image
|
| 122 |
+
if isinstance(image_input, (str, Path)):
|
| 123 |
+
img_bgr = cv2.imread(str(image_input))
|
| 124 |
+
if img_bgr is None:
|
| 125 |
+
raise FileNotFoundError(f"Cannot read image: {image_input}")
|
| 126 |
+
elif isinstance(image_input, np.ndarray):
|
| 127 |
+
img_bgr = image_input if image_input.shape[2] == 3 else cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
|
| 128 |
+
else:
|
| 129 |
+
raise TypeError(f"Unsupported input type: {type(image_input)}")
|
| 130 |
+
|
| 131 |
+
h, w = img_bgr.shape[:2]
|
| 132 |
+
|
| 133 |
+
# Step 1: Segmentation
|
| 134 |
+
seg = segment(self.seg_model, img_bgr, self.device, use_tta=tta)
|
| 135 |
+
classmap = seg["classmap"]
|
| 136 |
+
|
| 137 |
+
class_dist = {}
|
| 138 |
+
for cid, name in CLASS_NAMES.items():
|
| 139 |
+
class_dist[name] = round(float(np.mean(classmap == cid) * 100), 1)
|
| 140 |
+
|
| 141 |
+
# Step 2: Fitzpatrick estimation (from healthy skin)
|
| 142 |
+
fitz = estimate_fitzpatrick(img_bgr, seg["masks"])
|
| 143 |
+
|
| 144 |
+
# Step 3: PWAT prediction (from ulcer mask)
|
| 145 |
+
ulcer_mask = (classmap == 3).astype(np.uint8) * 255
|
| 146 |
+
pwat = self.pwat_predictor.predict(img_bgr, ulcer_mask, fitzpatrick_type=fitz.fitzpatrick_type)
|
| 147 |
+
|
| 148 |
+
return AnalysisResult(
|
| 149 |
+
class_distribution=class_dist,
|
| 150 |
+
fitzpatrick=fitz,
|
| 151 |
+
pwat=pwat,
|
| 152 |
+
image_size=(h, w),
|
| 153 |
+
device=str(self.device),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def visualize(self, img_bgr: np.ndarray, result: AnalysisResult) -> np.ndarray:
|
| 157 |
+
"""Create overlay visualization of segmentation result."""
|
| 158 |
+
h, w = img_bgr.shape[:2]
|
| 159 |
+
seg = segment(self.seg_model, img_bgr, self.device, use_tta=False)
|
| 160 |
+
classmap = seg["classmap"]
|
| 161 |
+
|
| 162 |
+
overlay = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
|
| 163 |
+
for cid, color in CLASS_COLORS.items():
|
| 164 |
+
if cid == 0:
|
| 165 |
+
continue
|
| 166 |
+
mask = classmap == cid
|
| 167 |
+
overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
|
| 168 |
+
|
| 169 |
+
return overlay.astype(np.uint8)
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
torchvision>=0.15.0
|
| 3 |
+
segmentation-models-pytorch>=0.4.0
|
| 4 |
+
timm>=1.0.3
|
| 5 |
+
opencv-python-headless>=4.8.0
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
scikit-learn>=1.3.0
|
| 9 |
+
xgboost>=2.0.0
|
| 10 |
+
joblib>=1.3.0
|
| 11 |
+
gradio>=4.0.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""WoundNetB7 DFU Analysis Pipeline — Segmentation + PWAT + Fitzpatrick/ITA."""
|
src/fitzpatrick_estimator.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fitzpatrick skin type estimation via ITA (Individual Typology Angle).
|
| 2 |
+
|
| 3 |
+
Calibrated on 61 DFU images with expert ground truth.
|
| 4 |
+
Validation: 86.9% exact match, 98.4% adjacent, r=0.975.
|
| 5 |
+
"""
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
# Calibrated ITA thresholds for DFU clinical photography (61-image validation)
|
| 12 |
+
ITA_THRESHOLDS = {
|
| 13 |
+
"I": (46.86, float("inf")),
|
| 14 |
+
"II": (34.25, 46.86),
|
| 15 |
+
"III": (20.87, 34.25),
|
| 16 |
+
"IV": (3.57, 20.87),
|
| 17 |
+
"V": (-28.38, 3.57),
|
| 18 |
+
"VI": (float("-inf"), -28.38),
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
FITZPATRICK_LABELS = {
|
| 22 |
+
"I": "Very Light", "II": "Light", "III": "Intermediate",
|
| 23 |
+
"IV": "Tan", "V": "Brown", "VI": "Dark",
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class FitzpatrickResult:
|
| 29 |
+
fitzpatrick_type: str # "I" .. "VI"
|
| 30 |
+
fitzpatrick_int: int # 1 .. 6
|
| 31 |
+
fitzpatrick_label: str # "Very Light" .. "Dark"
|
| 32 |
+
ita_angle: float # ITA in degrees
|
| 33 |
+
ita_std: float # ITA standard deviation
|
| 34 |
+
l_skin_mean: float # Mean L* of healthy skin
|
| 35 |
+
b_skin_mean: float # Mean b* of healthy skin
|
| 36 |
+
healthy_pixels: int # Number of healthy skin pixels used
|
| 37 |
+
healthy_ratio: float # Healthy pixels / total image pixels
|
| 38 |
+
confidence: float # 0-1 confidence score
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def compute_ita(l_values: np.ndarray, b_values: np.ndarray) -> tuple:
|
| 42 |
+
"""Compute ITA angle from L* and b* values with robust trimming.
|
| 43 |
+
|
| 44 |
+
ITA = arctan((L* - 50) / b*) * (180 / pi)
|
| 45 |
+
Higher ITA = lighter skin, lower ITA = darker skin.
|
| 46 |
+
"""
|
| 47 |
+
ita_per_pixel = np.degrees(np.arctan2(l_values - 50.0, b_values))
|
| 48 |
+
# Robust trimming: 5th-95th percentile
|
| 49 |
+
p5, p95 = np.percentile(ita_per_pixel, [5, 95])
|
| 50 |
+
trimmed = ita_per_pixel[(ita_per_pixel >= p5) & (ita_per_pixel <= p95)]
|
| 51 |
+
if len(trimmed) < 10:
|
| 52 |
+
trimmed = ita_per_pixel
|
| 53 |
+
return float(np.mean(trimmed)), float(np.std(trimmed))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def classify_fitzpatrick(ita: float) -> tuple:
|
| 57 |
+
"""Classify ITA angle into Fitzpatrick type using calibrated DFU thresholds."""
|
| 58 |
+
for ftype, (lo, hi) in ITA_THRESHOLDS.items():
|
| 59 |
+
if lo <= ita < hi:
|
| 60 |
+
idx = list(ITA_THRESHOLDS.keys()).index(ftype) + 1
|
| 61 |
+
return ftype, idx
|
| 62 |
+
return "III", 3 # Default fallback
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def estimate_fitzpatrick(
|
| 66 |
+
img_bgr: np.ndarray,
|
| 67 |
+
masks: dict,
|
| 68 |
+
periulcer_dilation_px: int = 40,
|
| 69 |
+
) -> FitzpatrickResult:
|
| 70 |
+
"""Estimate Fitzpatrick type from a DFU image using segmentation masks.
|
| 71 |
+
|
| 72 |
+
Strategy: Healthy skin = foot region - perilesion zone - ulcer.
|
| 73 |
+
ITA is computed on the healthy skin pixels only.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
img_bgr: BGR image (H, W, 3)
|
| 77 |
+
masks: dict with keys 'foot', 'perilesion', 'ulcer' (bool arrays H, W)
|
| 78 |
+
periulcer_dilation_px: Extra dilation around wound for safety margin
|
| 79 |
+
"""
|
| 80 |
+
h, w = img_bgr.shape[:2]
|
| 81 |
+
foot = masks.get("foot", np.ones((h, w), dtype=bool))
|
| 82 |
+
peri = masks.get("perilesion", np.zeros((h, w), dtype=bool))
|
| 83 |
+
ulcer = masks.get("ulcer", np.zeros((h, w), dtype=bool))
|
| 84 |
+
|
| 85 |
+
# Dilate ulcer+perilesion for safety margin
|
| 86 |
+
exclusion = (peri | ulcer).astype(np.uint8)
|
| 87 |
+
if periulcer_dilation_px > 0:
|
| 88 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (periulcer_dilation_px, periulcer_dilation_px))
|
| 89 |
+
exclusion = cv2.dilate(exclusion, kernel)
|
| 90 |
+
exclusion = exclusion.astype(bool)
|
| 91 |
+
|
| 92 |
+
# Healthy skin = foot minus exclusion zone
|
| 93 |
+
healthy = foot & ~exclusion
|
| 94 |
+
healthy_pixels = int(np.sum(healthy))
|
| 95 |
+
|
| 96 |
+
if healthy_pixels < 100:
|
| 97 |
+
# Fallback: use all foot pixels minus wound
|
| 98 |
+
healthy = foot & ~ulcer
|
| 99 |
+
healthy_pixels = int(np.sum(healthy))
|
| 100 |
+
|
| 101 |
+
if healthy_pixels < 50:
|
| 102 |
+
return FitzpatrickResult(
|
| 103 |
+
fitzpatrick_type="III", fitzpatrick_int=3,
|
| 104 |
+
fitzpatrick_label="Intermediate",
|
| 105 |
+
ita_angle=0.0, ita_std=0.0,
|
| 106 |
+
l_skin_mean=0.0, b_skin_mean=0.0,
|
| 107 |
+
healthy_pixels=healthy_pixels,
|
| 108 |
+
healthy_ratio=healthy_pixels / (h * w),
|
| 109 |
+
confidence=0.0,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# Convert to L*a*b*
|
| 113 |
+
lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
|
| 114 |
+
l_values = lab[healthy, 0] * (100.0 / 255.0) # OpenCV L* is 0-255 -> 0-100
|
| 115 |
+
b_values = lab[healthy, 2] - 128.0 # OpenCV b* is 0-255 -> -128 to +127
|
| 116 |
+
|
| 117 |
+
ita_mean, ita_std = compute_ita(l_values, b_values)
|
| 118 |
+
ftype, fint = classify_fitzpatrick(ita_mean)
|
| 119 |
+
|
| 120 |
+
# Confidence: pixel count + ITA consistency + coverage
|
| 121 |
+
pixel_conf = min(healthy_pixels / 5000.0, 1.0)
|
| 122 |
+
ita_conf = max(0.0, 1.0 - (ita_std / 30.0))
|
| 123 |
+
coverage_conf = min((healthy_pixels / (h * w)) / 0.15, 1.0)
|
| 124 |
+
confidence = pixel_conf * 0.3 + ita_conf * 0.4 + coverage_conf * 0.3
|
| 125 |
+
|
| 126 |
+
return FitzpatrickResult(
|
| 127 |
+
fitzpatrick_type=ftype,
|
| 128 |
+
fitzpatrick_int=fint,
|
| 129 |
+
fitzpatrick_label=FITZPATRICK_LABELS[ftype],
|
| 130 |
+
ita_angle=round(ita_mean, 2),
|
| 131 |
+
ita_std=round(ita_std, 2),
|
| 132 |
+
l_skin_mean=round(float(np.mean(l_values)), 2),
|
| 133 |
+
b_skin_mean=round(float(np.mean(b_values)), 2),
|
| 134 |
+
healthy_pixels=healthy_pixels,
|
| 135 |
+
healthy_ratio=round(healthy_pixels / (h * w), 4),
|
| 136 |
+
confidence=round(confidence, 3),
|
| 137 |
+
)
|
src/pwat_estimator.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PWAT (Photographic Wound Assessment Tool) estimation — Items 3-8.
|
| 2 |
+
|
| 3 |
+
Uses segmentation masks to extract color, tissue, morphological, and texture
|
| 4 |
+
features, then predicts ordinal PWAT scores (0-4) via XGBoost classifiers.
|
| 5 |
+
|
| 6 |
+
Includes Fitzpatrick-aware debiasing correction factors.
|
| 7 |
+
"""
|
| 8 |
+
import numpy as np
|
| 9 |
+
import cv2
|
| 10 |
+
import joblib
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import Optional
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
ITEMS = [3, 4, 5, 6, 7, 8]
|
| 16 |
+
ITEM_NAMES = {
|
| 17 |
+
3: "Necrotic Type",
|
| 18 |
+
4: "Necrotic Amount",
|
| 19 |
+
5: "Granulation Type",
|
| 20 |
+
6: "Granulation Amount",
|
| 21 |
+
7: "Edges",
|
| 22 |
+
8: "Periulcer Skin",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# Debiasing correction factors (calibrated from 61 DFU images)
|
| 26 |
+
# Applied as: adjusted = clip(raw + factor, 0, 4)
|
| 27 |
+
CORRECTION_FACTORS = {
|
| 28 |
+
"I": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
|
| 29 |
+
"II": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0},
|
| 30 |
+
"III": {3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: -0.1},
|
| 31 |
+
"IV": {3: -0.1, 4: -0.1, 5: 0.0, 6: 0.0, 7: 0.0, 8: -0.3},
|
| 32 |
+
"V": {3: -0.2, 4: -0.2, 5: -0.1, 6: 0.0, 7: 0.0, 8: -0.6},
|
| 33 |
+
"VI": {3: -0.3, 4: -0.3, 5: -0.2, 6: -0.1, 7: 0.0, 8: -0.9},
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class PWATResult:
|
| 39 |
+
scores_raw: dict = field(default_factory=dict) # {item: int}
|
| 40 |
+
scores_adjusted: dict = field(default_factory=dict) # {item: float} (debiased)
|
| 41 |
+
total_raw: int = 0
|
| 42 |
+
total_adjusted: float = 0.0
|
| 43 |
+
fitzpatrick_type: str = ""
|
| 44 |
+
features: dict = field(default_factory=dict)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def extract_features(img_bgr: np.ndarray, ulcer_mask: np.ndarray) -> Optional[dict]:
|
| 48 |
+
"""Extract 63 features from the wound region for PWAT prediction.
|
| 49 |
+
|
| 50 |
+
Features: color (RGB/HSV/Lab), tissue composition, morphology, texture.
|
| 51 |
+
"""
|
| 52 |
+
b = ulcer_mask > 0 if ulcer_mask.dtype == bool else ulcer_mask > 127
|
| 53 |
+
npx = int(np.sum(b))
|
| 54 |
+
if npx < 50:
|
| 55 |
+
return None
|
| 56 |
+
|
| 57 |
+
feats = {}
|
| 58 |
+
|
| 59 |
+
# --- Color features (45) ---
|
| 60 |
+
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV).astype(np.float32)
|
| 61 |
+
lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2Lab).astype(np.float32)
|
| 62 |
+
rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
|
| 63 |
+
|
| 64 |
+
for cs, arr, names in [("rgb", rgb, ["R", "G", "B"]), ("hsv", hsv, ["H", "S", "V"]), ("lab", lab, ["L", "a", "b"])]:
|
| 65 |
+
for ci, cn in enumerate(names):
|
| 66 |
+
vals = arr[b, ci]
|
| 67 |
+
feats[f"{cs}_{cn}_mean"] = float(np.mean(vals))
|
| 68 |
+
feats[f"{cs}_{cn}_std"] = float(np.std(vals))
|
| 69 |
+
feats[f"{cs}_{cn}_median"] = float(np.median(vals))
|
| 70 |
+
feats[f"{cs}_{cn}_p25"] = float(np.percentile(vals, 25))
|
| 71 |
+
feats[f"{cs}_{cn}_p75"] = float(np.percentile(vals, 75))
|
| 72 |
+
|
| 73 |
+
# --- Tissue composition (5) ---
|
| 74 |
+
h, s, v = hsv[b, 0], hsv[b, 1], hsv[b, 2]
|
| 75 |
+
l_ch = lab[b, 0] * (100 / 255)
|
| 76 |
+
a_ch = lab[b, 1] - 128
|
| 77 |
+
|
| 78 |
+
eschar = ((v < 100) & (s < 60)) | (v < 60)
|
| 79 |
+
slough = (h >= 15) & (h <= 50) & (s > 25) & (v > 70) & ~eschar
|
| 80 |
+
gran = (((h < 15) | (h > 155)) & (s > 35) & (v > 60) & (a_ch > 5)) & ~eschar
|
| 81 |
+
necro = (s < 45) & (v >= 60) & (v < 160) & (l_ch < 55) & ~eschar & ~gran
|
| 82 |
+
|
| 83 |
+
feats["tissue_gran_pct"] = float(np.sum(gran) / npx * 100)
|
| 84 |
+
feats["tissue_eschar_pct"] = float(np.sum(eschar) / npx * 100)
|
| 85 |
+
feats["tissue_slough_pct"] = float(np.sum(slough) / npx * 100)
|
| 86 |
+
feats["tissue_necro_pct"] = float(np.sum(necro) / npx * 100)
|
| 87 |
+
feats["tissue_necro_total"] = feats["tissue_eschar_pct"] + feats["tissue_slough_pct"] + feats["tissue_necro_pct"]
|
| 88 |
+
|
| 89 |
+
# --- Morphological features (7) ---
|
| 90 |
+
mask_u8 = b.astype(np.uint8) if b.dtype == bool else (ulcer_mask > 127).astype(np.uint8)
|
| 91 |
+
cnts, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 92 |
+
if cnts:
|
| 93 |
+
cnt = max(cnts, key=cv2.contourArea)
|
| 94 |
+
area = cv2.contourArea(cnt)
|
| 95 |
+
perim = cv2.arcLength(cnt, True)
|
| 96 |
+
circ = 4 * np.pi * area / (perim**2) if perim > 0 else 0
|
| 97 |
+
feats["morph_area"] = float(area)
|
| 98 |
+
feats["morph_perimeter"] = float(perim)
|
| 99 |
+
feats["morph_circularity"] = float(circ)
|
| 100 |
+
feats["morph_irregularity"] = float(1 - circ)
|
| 101 |
+
x, y, w2, h2 = cv2.boundingRect(cnt)
|
| 102 |
+
feats["morph_aspect_ratio"] = float(w2 / (h2 + 1e-8))
|
| 103 |
+
feats["morph_extent"] = float(area / (w2 * h2 + 1e-8))
|
| 104 |
+
hull = cv2.convexHull(cnt)
|
| 105 |
+
feats["morph_solidity"] = float(area / (cv2.contourArea(hull) + 1e-8))
|
| 106 |
+
|
| 107 |
+
# --- Texture features (4) ---
|
| 108 |
+
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
|
| 109 |
+
wound_gray = gray[b]
|
| 110 |
+
feats["texture_mean"] = float(np.mean(wound_gray))
|
| 111 |
+
feats["texture_std"] = float(np.std(wound_gray))
|
| 112 |
+
hist_vals = np.histogram(wound_gray, bins=64, density=True)[0]
|
| 113 |
+
feats["texture_entropy"] = float(-np.sum(hist_vals * np.log2(hist_vals + 1e-10)))
|
| 114 |
+
|
| 115 |
+
dilated = cv2.dilate(mask_u8 * 255, np.ones((5, 5), np.uint8))
|
| 116 |
+
eroded = cv2.erode(mask_u8 * 255, np.ones((5, 5), np.uint8))
|
| 117 |
+
edge_zone = (dilated - eroded) > 127
|
| 118 |
+
if np.any(edge_zone):
|
| 119 |
+
feats["edge_gradient"] = float(np.mean(np.abs(cv2.Sobel(gray.astype(np.float32), cv2.CV_32F, 1, 0)[edge_zone])))
|
| 120 |
+
|
| 121 |
+
# --- ROI features (2) ---
|
| 122 |
+
feats["wound_npx"] = float(npx)
|
| 123 |
+
feats["wound_ratio"] = float(npx / (img_bgr.shape[0] * img_bgr.shape[1]))
|
| 124 |
+
|
| 125 |
+
return feats
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class PWATPredictor:
|
| 129 |
+
"""Predicts PWAT items 3-8 from wound features using trained XGBoost models."""
|
| 130 |
+
|
| 131 |
+
def __init__(self, models_dir: str):
|
| 132 |
+
self.models = {}
|
| 133 |
+
models_path = Path(models_dir)
|
| 134 |
+
for item in ITEMS:
|
| 135 |
+
pkl = models_path / f"xgb_pwat{item}.pkl"
|
| 136 |
+
if pkl.exists():
|
| 137 |
+
self.models[item] = joblib.load(pkl)
|
| 138 |
+
|
| 139 |
+
def predict(
|
| 140 |
+
self,
|
| 141 |
+
img_bgr: np.ndarray,
|
| 142 |
+
ulcer_mask: np.ndarray,
|
| 143 |
+
fitzpatrick_type: str = "III",
|
| 144 |
+
feature_cols: Optional[list] = None,
|
| 145 |
+
) -> PWATResult:
|
| 146 |
+
"""Predict PWAT scores for a single image.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
img_bgr: BGR image
|
| 150 |
+
ulcer_mask: Binary ulcer mask (H, W)
|
| 151 |
+
fitzpatrick_type: Fitzpatrick type for debiasing ("I" .. "VI")
|
| 152 |
+
feature_cols: Ordered feature column names (must match training order).
|
| 153 |
+
If None, uses all extracted features sorted alphabetically.
|
| 154 |
+
"""
|
| 155 |
+
feats = extract_features(img_bgr, ulcer_mask)
|
| 156 |
+
if feats is None:
|
| 157 |
+
return PWATResult(fitzpatrick_type=fitzpatrick_type)
|
| 158 |
+
|
| 159 |
+
# Build feature vector
|
| 160 |
+
if feature_cols is None:
|
| 161 |
+
feature_cols = sorted(feats.keys())
|
| 162 |
+
X = np.array([[feats.get(c, 0.0) for c in feature_cols]])
|
| 163 |
+
|
| 164 |
+
scores_raw = {}
|
| 165 |
+
scores_adj = {}
|
| 166 |
+
for item in ITEMS:
|
| 167 |
+
if item in self.models:
|
| 168 |
+
pred = int(self.models[item].predict(X)[0])
|
| 169 |
+
scores_raw[item] = pred
|
| 170 |
+
factor = CORRECTION_FACTORS.get(fitzpatrick_type, {}).get(item, 0.0)
|
| 171 |
+
scores_adj[item] = float(np.clip(pred + factor, 0, 4))
|
| 172 |
+
else:
|
| 173 |
+
scores_raw[item] = 0
|
| 174 |
+
scores_adj[item] = 0.0
|
| 175 |
+
|
| 176 |
+
return PWATResult(
|
| 177 |
+
scores_raw=scores_raw,
|
| 178 |
+
scores_adjusted=scores_adj,
|
| 179 |
+
total_raw=sum(scores_raw.values()),
|
| 180 |
+
total_adjusted=round(sum(scores_adj.values()), 1),
|
| 181 |
+
fitzpatrick_type=fitzpatrick_type,
|
| 182 |
+
features=feats,
|
| 183 |
+
)
|
src/segmentation.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WoundNetB7 multiclass segmentation model — 4 classes (bg, foot, perilesion, ulcer).
|
| 2 |
+
|
| 3 |
+
Architecture: EfficientNet-B7 encoder + ASPP + CBAM + TAM + UNet decoder.
|
| 4 |
+
Checkpoint: Track B multiclass, ulcer Dice = 0.927 (Bootstrap CI: [0.917, 0.936]).
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import segmentation_models_pytorch as smp
|
| 10 |
+
import numpy as np
|
| 11 |
+
import cv2
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
IMG_SIZE = 512
|
| 15 |
+
MEAN = np.array([0.485, 0.456, 0.406])
|
| 16 |
+
STD = np.array([0.229, 0.224, 0.225])
|
| 17 |
+
CLASS_NAMES = {0: "background", 1: "foot", 2: "perilesion", 3: "ulcer"}
|
| 18 |
+
CLASS_COLORS = {
|
| 19 |
+
0: (0, 0, 0),
|
| 20 |
+
1: (0, 255, 0),
|
| 21 |
+
2: (255, 165, 0),
|
| 22 |
+
3: (255, 0, 0),
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Architecture modules (match checkpoint weights exactly)
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
class ChannelAttention(nn.Module):
|
| 31 |
+
def __init__(self, channels, reduction=16):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.mlp = nn.Sequential(
|
| 34 |
+
nn.Linear(channels, channels // reduction, bias=False),
|
| 35 |
+
nn.ReLU(inplace=True),
|
| 36 |
+
nn.Linear(channels // reduction, channels, bias=False),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
avg_out = self.mlp(x.mean(dim=[2, 3]))
|
| 41 |
+
max_out = self.mlp(x.amax(dim=[2, 3]))
|
| 42 |
+
attn = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
|
| 43 |
+
return x * attn
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SpatialAttention(nn.Module):
|
| 47 |
+
def __init__(self, kernel_size=7):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
avg_out = x.mean(dim=1, keepdim=True)
|
| 53 |
+
max_out = x.amax(dim=1, keepdim=True)
|
| 54 |
+
attn = torch.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
|
| 55 |
+
return x * attn
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class CBAM(nn.Module):
|
| 59 |
+
def __init__(self, channels, reduction=16, kernel_size=7):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.ca = ChannelAttention(channels, reduction)
|
| 62 |
+
self.sa = SpatialAttention(kernel_size)
|
| 63 |
+
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
return self.sa(self.ca(x))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class DifferentiableFractalDimension(nn.Module):
|
| 69 |
+
def __init__(self, scales=None):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.scales = scales or [2, 4, 8, 16, 32]
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
B, C, H, W = x.shape
|
| 75 |
+
counts = []
|
| 76 |
+
for s in self.scales:
|
| 77 |
+
if s >= H or s >= W:
|
| 78 |
+
continue
|
| 79 |
+
pooled = F.avg_pool2d(x, kernel_size=s, stride=s)
|
| 80 |
+
n_boxes = torch.sigmoid(10.0 * (pooled - 0.1)).sum(dim=[2, 3])
|
| 81 |
+
counts.append(n_boxes)
|
| 82 |
+
if len(counts) < 2:
|
| 83 |
+
return torch.ones(B, C, device=x.device)
|
| 84 |
+
log_s = torch.log(torch.tensor([float(s) for s in self.scales[: len(counts)]], device=x.device))
|
| 85 |
+
log_c = torch.stack([torch.log(c + 1) for c in counts], dim=-1)
|
| 86 |
+
n = log_s.shape[0]
|
| 87 |
+
sx, sxx = log_s.sum(), (log_s**2).sum()
|
| 88 |
+
sy = log_c.sum(dim=-1)
|
| 89 |
+
sxy = (log_c * log_s.unsqueeze(0).unsqueeze(0)).sum(dim=-1)
|
| 90 |
+
slope = (n * sxy - sx * sy) / (n * sxx - sx**2 + 1e-8)
|
| 91 |
+
return -slope.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class DifferentiableEulerCharacteristic(nn.Module):
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
B, C, H, W = x.shape
|
| 97 |
+
b = torch.sigmoid(10.0 * (torch.sigmoid(x) - 0.5))
|
| 98 |
+
V = b.sum(dim=[2, 3])
|
| 99 |
+
E_h = (b[:, :, :, :-1] * b[:, :, :, 1:]).sum(dim=[2, 3])
|
| 100 |
+
E_v = (b[:, :, :-1, :] * b[:, :, 1:, :]).sum(dim=[2, 3])
|
| 101 |
+
F_val = (b[:, :, :-1, :-1] * b[:, :, :-1, 1:] * b[:, :, 1:, :-1] * b[:, :, 1:, 1:]).sum(dim=[2, 3])
|
| 102 |
+
euler = V - E_h - E_v + F_val
|
| 103 |
+
return euler.mean(dim=1, keepdim=True).unsqueeze(-1).unsqueeze(-1) / (H * W)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class TopologicalAttentionModule(nn.Module):
|
| 107 |
+
def __init__(self, in_channels):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self.fractal = DifferentiableFractalDimension()
|
| 110 |
+
self.euler = DifferentiableEulerCharacteristic()
|
| 111 |
+
self.alpha = nn.Parameter(torch.tensor(1.0))
|
| 112 |
+
self.beta = nn.Parameter(torch.tensor(1.0))
|
| 113 |
+
self.conv = nn.Sequential(
|
| 114 |
+
nn.Conv2d(in_channels + 2, in_channels, 1),
|
| 115 |
+
nn.BatchNorm2d(in_channels),
|
| 116 |
+
nn.ReLU(inplace=True),
|
| 117 |
+
nn.Conv2d(in_channels, in_channels, 1),
|
| 118 |
+
nn.Sigmoid(),
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
def forward(self, x):
|
| 122 |
+
B, C, H, W = x.shape
|
| 123 |
+
fm = self.fractal(x).expand(B, 1, H, W)
|
| 124 |
+
em = self.euler(x).expand(B, 1, H, W)
|
| 125 |
+
attn = self.conv(torch.cat([x, self.alpha * fm, self.beta * em], dim=1))
|
| 126 |
+
return x * attn + x
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ASPP(nn.Module):
|
| 130 |
+
def __init__(self, in_ch, out_ch, rates=None):
|
| 131 |
+
super().__init__()
|
| 132 |
+
rates = rates or [6, 12, 18]
|
| 133 |
+
self.conv1x1 = nn.Sequential(nn.Conv2d(in_ch, out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True))
|
| 134 |
+
self.atrous = nn.ModuleList(
|
| 135 |
+
[nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=r, dilation=r), nn.BatchNorm2d(out_ch), nn.ReLU(True)) for r in rates]
|
| 136 |
+
)
|
| 137 |
+
self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_ch, out_ch, 1), nn.ReLU(True))
|
| 138 |
+
self.project = nn.Sequential(
|
| 139 |
+
nn.Conv2d(out_ch * (2 + len(rates)), out_ch, 1), nn.BatchNorm2d(out_ch), nn.ReLU(True), nn.Dropout(0.5)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
size = x.shape[2:]
|
| 144 |
+
feats = [self.conv1x1(x)] + [a(x) for a in self.atrous]
|
| 145 |
+
feats.append(F.interpolate(self.pool(x), size=size, mode="bilinear", align_corners=False))
|
| 146 |
+
return self.project(torch.cat(feats, dim=1))
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class WoundNetB7(nn.Module):
|
| 150 |
+
"""WoundNetB7 matching the Track B checkpoint structure."""
|
| 151 |
+
|
| 152 |
+
NUM_CLASSES = 4
|
| 153 |
+
|
| 154 |
+
def __init__(self, num_classes=4):
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.backbone = smp.Unet(encoder_name="efficientnet-b7", encoder_weights=None, in_channels=3, classes=num_classes)
|
| 157 |
+
enc_ch = self.backbone.encoder.out_channels[-1]
|
| 158 |
+
self.aspp = ASPP(enc_ch, enc_ch)
|
| 159 |
+
self.cbam = CBAM(num_classes, reduction=max(1, num_classes // 2))
|
| 160 |
+
self.tam = TopologicalAttentionModule(num_classes)
|
| 161 |
+
self.diffusion_weight = nn.Parameter(torch.tensor(0.01))
|
| 162 |
+
|
| 163 |
+
def forward(self, x):
|
| 164 |
+
features = list(self.backbone.encoder(x))
|
| 165 |
+
features[-1] = self.aspp(features[-1])
|
| 166 |
+
try:
|
| 167 |
+
dec = self.backbone.decoder(features)
|
| 168 |
+
except TypeError:
|
| 169 |
+
dec = self.backbone.decoder(*features)
|
| 170 |
+
seg = self.backbone.segmentation_head(dec)
|
| 171 |
+
seg = self.cbam(seg)
|
| 172 |
+
seg = self.tam(seg)
|
| 173 |
+
return seg
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
# Inference helpers
|
| 178 |
+
# ---------------------------------------------------------------------------
|
| 179 |
+
|
| 180 |
+
def preprocess(img_bgr: np.ndarray) -> torch.Tensor:
|
| 181 |
+
"""BGR image -> normalized CHW tensor (1, 3, 512, 512)."""
|
| 182 |
+
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
| 183 |
+
img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
|
| 184 |
+
img = (img.astype(np.float32) / 255.0 - MEAN) / STD
|
| 185 |
+
return torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def tta_inference(model: nn.Module, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
|
| 189 |
+
"""6-fold TTA -> averaged softmax probabilities (1, C, H, W)."""
|
| 190 |
+
transforms = [
|
| 191 |
+
lambda x: x,
|
| 192 |
+
lambda x: torch.flip(x, [3]),
|
| 193 |
+
lambda x: torch.flip(x, [2]),
|
| 194 |
+
lambda x: torch.rot90(x, 1, [2, 3]),
|
| 195 |
+
lambda x: torch.rot90(x, 2, [2, 3]),
|
| 196 |
+
lambda x: torch.rot90(x, 3, [2, 3]),
|
| 197 |
+
]
|
| 198 |
+
inverse = [
|
| 199 |
+
lambda x: x,
|
| 200 |
+
lambda x: torch.flip(x, [3]),
|
| 201 |
+
lambda x: torch.flip(x, [2]),
|
| 202 |
+
lambda x: torch.rot90(x, 3, [2, 3]),
|
| 203 |
+
lambda x: torch.rot90(x, 2, [2, 3]),
|
| 204 |
+
lambda x: torch.rot90(x, 1, [2, 3]),
|
| 205 |
+
]
|
| 206 |
+
probs_sum = None
|
| 207 |
+
with torch.no_grad():
|
| 208 |
+
for tfm, inv in zip(transforms, inverse):
|
| 209 |
+
out = model(tfm(img_tensor).to(device))
|
| 210 |
+
if isinstance(out, (tuple, list)):
|
| 211 |
+
out = out[0]
|
| 212 |
+
if isinstance(out, dict):
|
| 213 |
+
out = out["seg"]
|
| 214 |
+
p = inv(F.softmax(out, dim=1))
|
| 215 |
+
probs_sum = p if probs_sum is None else probs_sum + p
|
| 216 |
+
return probs_sum / len(transforms)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def load_segmentation_model(checkpoint_path: str, device: torch.device) -> nn.Module:
|
| 220 |
+
"""Load WoundNetB7 from checkpoint."""
|
| 221 |
+
model = WoundNetB7(num_classes=4)
|
| 222 |
+
state = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 223 |
+
# Remove PWAT head keys if present
|
| 224 |
+
state = {k: v for k, v in state.items() if not k.startswith("pwat_head.")}
|
| 225 |
+
model.load_state_dict(state, strict=False)
|
| 226 |
+
model.to(device).eval()
|
| 227 |
+
return model
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def segment(model: nn.Module, img_bgr: np.ndarray, device: torch.device, use_tta: bool = True) -> dict:
|
| 231 |
+
"""Run segmentation on a BGR image.
|
| 232 |
+
|
| 233 |
+
Returns dict with:
|
| 234 |
+
classmap: (H, W) uint8 with class indices 0-3
|
| 235 |
+
masks: dict of per-class binary masks {cls_name: (H, W) bool}
|
| 236 |
+
probs: (4, H, W) float32 softmax probabilities
|
| 237 |
+
"""
|
| 238 |
+
h, w = img_bgr.shape[:2]
|
| 239 |
+
tensor = preprocess(img_bgr)
|
| 240 |
+
|
| 241 |
+
if use_tta:
|
| 242 |
+
probs = tta_inference(model, tensor, device)
|
| 243 |
+
else:
|
| 244 |
+
with torch.no_grad():
|
| 245 |
+
out = model(tensor.to(device))
|
| 246 |
+
if isinstance(out, (tuple, list)):
|
| 247 |
+
out = out[0]
|
| 248 |
+
if isinstance(out, dict):
|
| 249 |
+
out = out["seg"]
|
| 250 |
+
probs = F.softmax(out, dim=1)
|
| 251 |
+
|
| 252 |
+
probs_np = probs[0].cpu().numpy()
|
| 253 |
+
probs_resized = np.stack([cv2.resize(probs_np[c], (w, h), interpolation=cv2.INTER_LINEAR) for c in range(4)])
|
| 254 |
+
classmap = probs_resized.argmax(axis=0).astype(np.uint8)
|
| 255 |
+
masks = {name: (classmap == cid) for cid, name in CLASS_NAMES.items() if cid > 0}
|
| 256 |
+
return {"classmap": classmap, "masks": masks, "probs": probs_resized}
|