Spaces:
Running
Running
Initial Space: Gradio app + cv_baseline_fold5 (Test top-1 0.8152)
Browse files- README.md +40 -7
- app.py +123 -0
- cv_baseline_fold5_best.pt +3 -0
- requirements.txt +4 -0
README.md
CHANGED
|
@@ -1,13 +1,46 @@
|
|
| 1 |
---
|
| 2 |
-
title: Wound Classifier
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
python_version: '3.13'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Chronic Wound Classifier
|
| 3 |
+
emoji: 🩹
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# Chronic Wound Classifier — 4-class AZH demo
|
| 14 |
+
|
| 15 |
+
Demo classifier for chronic wound photographs: predicts one of four wound types
|
| 16 |
+
(diabetic ulcer, pressure ulcer, surgical wound, venous ulcer) from an uploaded image.
|
| 17 |
+
|
| 18 |
+
**Not a medical device. Not for clinical use.** Research demonstration only.
|
| 19 |
+
|
| 20 |
+
## Headline metric
|
| 21 |
+
Top-1 accuracy on the held-out AZH Test set (n=184): **0.8152**
|
| 22 |
+
(`cv_baseline_fold5_best.pt` — the highest single-fold checkpoint from
|
| 23 |
+
patient-grouped 10-fold cross-validation).
|
| 24 |
+
|
| 25 |
+
The 10-fold soft-vote ensemble of the same recipe scores 0.7989 on the same
|
| 26 |
+
set; the single-checkpoint variant is shipped here for inference latency
|
| 27 |
+
and footprint reasons.
|
| 28 |
+
|
| 29 |
+
## Architecture
|
| 30 |
+
EfficientNet-B0 (ImageNet-pretrained), two-phase fine-tune (head-only 5 epochs
|
| 31 |
+
at lr=1e-3, then full network 15 epochs at lr=1e-4). Patient-grouped CV
|
| 32 |
+
splits ensure the same patient's images never appear in both train and val.
|
| 33 |
+
|
| 34 |
+
## Limitations
|
| 35 |
+
- **Pressure-class accuracy is ~0.41** — interpret pressure-class predictions with care.
|
| 36 |
+
- No fairness audit across skin tones (known gap).
|
| 37 |
+
- English-only UI; no mobile or offline build.
|
| 38 |
+
- Not validated on real patient cohorts outside AZH.
|
| 39 |
+
|
| 40 |
+
## Source code & training pipeline
|
| 41 |
+
The training, evaluation, and methodology code live in the project repo:
|
| 42 |
+
[github.com — wound-classification](#) (full link to be added by user)
|
| 43 |
+
|
| 44 |
+
## Citation
|
| 45 |
+
Anisuzzaman et al. 2022. *Multi-modal wound classification using wound image
|
| 46 |
+
and location by deep neural network.* Sci. Rep. 12:20057.
|
app.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HF Spaces Gradio app — chronic wound classifier (4-class).
|
| 2 |
+
|
| 3 |
+
Self-contained: no wound_classifier package install required. The model
|
| 4 |
+
architecture and transforms are inlined here so this file plus the .pt
|
| 5 |
+
checkpoint and requirements.txt are everything the Space needs.
|
| 6 |
+
|
| 7 |
+
If the architecture or transform here drifts from
|
| 8 |
+
src/wound_classifier/{modeling/models.py, features.py} the Space and the
|
| 9 |
+
training pipeline will silently disagree. Keep them in sync.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import gradio as gr
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from torch import nn
|
| 20 |
+
from torchvision import transforms
|
| 21 |
+
from torchvision.models import efficientnet_b0
|
| 22 |
+
|
| 23 |
+
# --- Inference configuration -------------------------------------------------
# Checkpoint ships alongside this file in the Space repo (git-LFS pointer).
CKPT_PATH = Path(__file__).parent / "cv_baseline_fold5_best.pt"
# Input geometry and normalization must match the training-time transform
# (224 px, ImageNet statistics) — see the sync warning in the module docstring.
IMAGE_SIZE = 224
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Short label code -> human-readable class name shown in the Gradio UI.
CLASS_NAMES = {
    "D": "Diabetic ulcer",
    "P": "Pressure ulcer",
    "S": "Surgical wound",
    "V": "Venous ulcer",
}
# Output-index -> label-code order. NOTE(review): this must match the class
# ordering the training pipeline used when encoding labels — confirm against
# src/wound_classifier before changing.
IDX_TO_CLASS = ["D", "P", "S", "V"]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _build_model(num_classes: int = 4) -> nn.Module:
    """Recreate the training-time EfficientNet-B0 with a fresh classifier head.

    Args:
        num_classes: Width of the final linear layer (4 wound classes).

    Returns:
        An uninitialized-weights EfficientNet-B0 whose stock 1000-way
        ImageNet head is replaced by a dropout + ``num_classes``-way head.
        The structure must match training exactly or the checkpoint's
        state_dict will not load.
    """
    net: nn.Module = efficientnet_b0(weights=None)
    head_in = net.classifier[1].in_features  # type: ignore[index, union-attr]
    # Swap the ImageNet classifier for the wound-type head used at training.
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(head_in, num_classes),
    )
    return net
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _load_model(path: Path) -> nn.Module:
    """Load the fine-tuned checkpoint and return the model in eval mode.

    Args:
        path: Location of a ``.pt`` checkpoint containing a ``"state_dict"``
            entry for the architecture built by ``_build_model``.

    Returns:
        The model with weights restored, switched to eval mode (disables
        dropout for inference).
    """
    import pickle  # local: only needed for the narrow fallback below

    try:
        # Prefer the safe loader: we only need the tensor state_dict, and
        # weights_only=True refuses to unpickle arbitrary objects (the
        # arbitrary-code-execution risk of a full pickle load).
        ckpt = torch.load(path, map_location="cpu", weights_only=True)
    except (RuntimeError, pickle.UnpicklingError):
        # Fallback for checkpoints that embed non-tensor objects. Only
        # acceptable because this .pt ships inside our own Space repo —
        # never do this on an untrusted file.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
    model = _build_model(num_classes=4)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    return model
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _build_transform() -> transforms.Compose:
    """Build the eval-time preprocessing pipeline.

    Resize shortest side to 224, center-crop to 224x224, convert to a
    float tensor, then normalize with ImageNet statistics. This mirrors
    the validation transform used in training; any drift here silently
    degrades accuracy (see module docstring).
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# Loaded once at import time so every request reuses the same weights and
# transform; Spaces restarts the process on redeploy, so no reload logic needed.
MODEL = _load_model(CKPT_PATH)
TRANSFORM = _build_transform()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def classify(image: Image.Image) -> dict[str, float]:
    """Return per-class softmax probabilities for an uploaded wound photo.

    Args:
        image: PIL image from the Gradio widget; ``None`` when the user
            clears the input.

    Returns:
        Mapping of human-readable class name -> probability, suitable for
        ``gr.Label``. Empty dict when no image was provided.
    """
    if image is None:
        return {}
    # Force 3-channel RGB (uploads may be RGBA/grayscale), add batch dim.
    tensor = TRANSFORM(image.convert("RGB")).unsqueeze(0)
    with torch.inference_mode():
        scores = torch.softmax(MODEL(tensor), dim=1)
    probs = scores.squeeze(0).numpy()
    return {
        CLASS_NAMES[code]: float(p)
        for code, p in zip(IDX_TO_CLASS, probs)
    }
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
DESCRIPTION = """\
|
| 81 |
+
Chronic wound photo classifier — predicts which of four wound types an uploaded image shows
|
| 82 |
+
(diabetic ulcer, pressure ulcer, surgical wound, or venous ulcer).
|
| 83 |
+
|
| 84 |
+
**This is a research demo, not a medical device.** It does not diagnose, triage, or replace clinician
|
| 85 |
+
judgement. It was trained on the public AZH Chronic Wound Database (Anisuzzaman et al. 2022) using
|
| 86 |
+
EfficientNet-B0 with patient-grouped 10-fold cross-validation; the deployed checkpoint
|
| 87 |
+
(`cv_baseline_fold5_best.pt`) reaches **0.8152 top-1** on the held-out AZH Test set (n=184).
|
| 88 |
+
Pressure ulcers are the underperforming class (~0.41 accuracy) — interpret pressure-class predictions
|
| 89 |
+
with extra care.
|
| 90 |
+
|
| 91 |
+
Upload a wound photograph below to see the model's per-class probabilities.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
ARTICLE = """\
|
| 95 |
+
**Methodology in one paragraph.** EfficientNet-B0 (ImageNet-pretrained) two-phase fine-tuning:
|
| 96 |
+
phase 1 freezes the backbone and trains the 4-class head for 5 epochs at lr=1e-3; phase 2 unfreezes
|
| 97 |
+
and fine-tunes the full network for 15 epochs at lr=1e-4. Patient-grouped 10-fold CV ensures the same
|
| 98 |
+
patient's images never appear in both train and val. Training and validation under random-stratified
|
| 99 |
+
splits would inflate accuracy by ~3–30pp on AZH (we measured this directly); we report the patient-
|
| 100 |
+
grouped numbers as the honest evaluation. The shipped checkpoint is the highest single fold on the
|
| 101 |
+
held-out Test set; a 10-fold soft-vote ensemble averages 0.7989 on the same set.
|
| 102 |
+
|
| 103 |
+
**Out of scope.** Not for clinical decision-making. No claim of diagnostic accuracy on real patient
|
| 104 |
+
cohorts. No fairness audit across skin tones (known gap). English-only UI. No mobile / offline build.
|
| 105 |
+
|
| 106 |
+
**Citations.** Anisuzzaman et al. 2022, *Multi-modal wound classification using wound image and
|
| 107 |
+
location by deep neural network*, Sci. Rep. 12:20057.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# UI wiring: single image in -> Label (per-class probabilities) out.
# NOTE: README.md pins sdk_version 4.44.0, and Gradio 4.x spells the flagging
# option `allow_flagging` — `flagging_mode` only exists from Gradio 5, where
# `allow_flagging` is still accepted (with a deprecation warning). Use the
# spelling that works on both major versions to avoid a TypeError at startup.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil", label="Wound photograph"),
    outputs=gr.Label(num_top_classes=4, label="Predicted wound type"),
    title="Chronic Wound Classifier — 4-class AZH demo",
    description=DESCRIPTION,
    article=ARTICLE,
    allow_flagging="never",
)


if __name__ == "__main__":
    # Spaces imports this module and serves `demo`; launch() is for local runs.
    demo.launch()
|
cv_baseline_fold5_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f4b21f6204855827f33f2679be63d964d237a4d73a9940382ef4d49ed552e99
|
| 3 |
+
size 16355126
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0
|
| 2 |
+
torchvision>=0.15
|
| 3 |
+
gradio>=4.0
|
| 4 |
+
pillow>=10.0
|