jbobym committed
Commit 814f7c0 · verified · Parent: 15824a6

Initial Space: Gradio app + cv_baseline_fold5 (Test top-1 0.8152)

Files changed (4):
  1. README.md +40 -7
  2. app.py +123 -0
  3. cv_baseline_fold5_best.pt +3 -0
  4. requirements.txt +4 -0
README.md CHANGED
@@ -1,13 +1,46 @@
 ---
-title: Wound Classifier
-emoji: 🐨
-colorFrom: gray
-colorTo: red
+title: Chronic Wound Classifier
+emoji: 🩹
+colorFrom: red
+colorTo: blue
 sdk: gradio
-sdk_version: 6.14.0
-python_version: '3.13'
+sdk_version: 4.44.0
 app_file: app.py
 pinned: false
+license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Chronic Wound Classifier — 4-class AZH demo
+
+Demo classifier for chronic wound photographs: predicts one of four wound types
+(diabetic ulcer, pressure ulcer, surgical wound, venous ulcer) from an uploaded image.
+
+**Not a medical device. Not for clinical use.** Research demonstration only.
+
+## Headline metric
+Top-1 accuracy on the held-out AZH Test set (n=184): **0.8152**
+(`cv_baseline_fold5_best.pt` — the highest single-fold checkpoint from
+patient-grouped 10-fold cross-validation).
+
+The 10-fold soft-vote ensemble of the same recipe scores 0.7989 on the same
+set; the single-checkpoint variant is shipped here for inference latency
+and footprint reasons.
+
+## Architecture
+EfficientNet-B0 (ImageNet-pretrained), two-phase fine-tune (head-only 5 epochs
+at lr=1e-3, then full network 15 epochs at lr=1e-4). Patient-grouped CV
+splits ensure the same patient's images never appear in both train and val.
+
+## Limitations
+- **Pressure-class accuracy is ~0.41** — interpret pressure-class predictions with care.
+- No fairness audit across skin tones (known gap).
+- English-only UI; no mobile or offline build.
+- Not validated on real patient cohorts outside AZH.
+
+## Source code & training pipeline
+The training, evaluation, and methodology code live in the project repo:
+[github.com — wound-classification](#) (full link to be added by user)
+
+## Citation
+Anisuzzaman et al. 2022. *Multi-modal wound classification using wound image
+and location by deep neural network.* Sci. Rep. 12:20057.
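Aside: the soft-vote ensemble figure above comes down to averaging per-fold class-probability vectors and taking the argmax. A minimal framework-free sketch (the fold count and probability vectors below are made up for illustration; the real pipeline averages softmax outputs from all 10 fold checkpoints):

```python
# Soft-vote ensembling: element-wise mean of per-fold probabilities, then argmax.
CLASSES = ["D", "P", "S", "V"]

# Hypothetical per-fold softmax outputs for a single image.
fold_probs = [
    [0.70, 0.10, 0.10, 0.10],  # fold 1
    [0.55, 0.25, 0.10, 0.10],  # fold 2
    [0.20, 0.15, 0.05, 0.60],  # fold 3
]

def soft_vote(prob_rows):
    """Average probability vectors across folds and pick the top class."""
    n = len(prob_rows)
    avg = [sum(col) / n for col in zip(*prob_rows)]
    return CLASSES[avg.index(max(avg))], avg

label, avg = soft_vote(fold_probs)
```

Note that soft voting can disagree with per-fold majority vote: here fold 3 prefers "V", but the averaged probabilities still favour "D".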
app.py ADDED
@@ -0,0 +1,123 @@
+"""HF Spaces Gradio app — chronic wound classifier (4-class).
+
+Self-contained: no wound_classifier package install required. The model
+architecture and transforms are inlined here so this file plus the .pt
+checkpoint and requirements.txt are everything the Space needs.
+
+If the architecture or transform here drifts from
+src/wound_classifier/{modeling/models.py, features.py} the Space and the
+training pipeline will silently disagree. Keep them in sync.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import gradio as gr
+import torch
+from PIL import Image
+from torch import nn
+from torchvision import transforms
+from torchvision.models import efficientnet_b0
+
+CKPT_PATH = Path(__file__).parent / "cv_baseline_fold5_best.pt"
+IMAGE_SIZE = 224
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+CLASS_NAMES = {
+    "D": "Diabetic ulcer",
+    "P": "Pressure ulcer",
+    "S": "Surgical wound",
+    "V": "Venous ulcer",
+}
+IDX_TO_CLASS = ["D", "P", "S", "V"]
+
+
+def _build_model(num_classes: int = 4) -> nn.Module:
+    model: nn.Module = efficientnet_b0(weights=None)
+    in_features = model.classifier[1].in_features  # type: ignore[index, union-attr]
+    model.classifier = nn.Sequential(
+        nn.Dropout(p=0.2, inplace=True),
+        nn.Linear(in_features, num_classes),
+    )
+    return model
+
+
+def _load_model(path: Path) -> nn.Module:
+    ckpt = torch.load(path, map_location="cpu", weights_only=False)
+    model = _build_model(num_classes=4)
+    model.load_state_dict(ckpt["state_dict"])
+    model.eval()
+    return model
+
+
+def _build_transform() -> transforms.Compose:
+    return transforms.Compose(
+        [
+            transforms.Resize(IMAGE_SIZE),
+            transforms.CenterCrop(IMAGE_SIZE),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+        ]
+    )
+
+
+MODEL = _load_model(CKPT_PATH)
+TRANSFORM = _build_transform()
+
+
+def classify(image: Image.Image) -> dict[str, float]:
+    if image is None:
+        return {}
+    rgb = image.convert("RGB")
+    x = TRANSFORM(rgb).unsqueeze(0)
+    with torch.inference_mode():
+        logits = MODEL(x)
+        probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
+    return {CLASS_NAMES[IDX_TO_CLASS[i]]: float(probs[i]) for i in range(4)}
+
+
+DESCRIPTION = """\
+Chronic wound photo classifier — predicts which of four wound types an uploaded image shows
+(diabetic ulcer, pressure ulcer, surgical wound, or venous ulcer).
+
+**This is a research demo, not a medical device.** It does not diagnose, triage, or replace clinician
+judgement. It was trained on the public AZH Chronic Wound Database (Anisuzzaman et al. 2022) using
+EfficientNet-B0 with patient-grouped 10-fold cross-validation; the deployed checkpoint
+(`cv_baseline_fold5_best.pt`) reaches **0.8152 top-1** on the held-out AZH Test set (n=184).
+Pressure ulcers are the underperforming class (~0.41 accuracy) — interpret pressure-class predictions
+with extra care.
+
+Upload a wound photograph below to see the model's per-class probabilities.
+"""
+
+ARTICLE = """\
+**Methodology in one paragraph.** EfficientNet-B0 (ImageNet-pretrained) two-phase fine-tuning:
+phase 1 freezes the backbone and trains the 4-class head for 5 epochs at lr=1e-3; phase 2 unfreezes
+and fine-tunes the full network for 15 epochs at lr=1e-4. Patient-grouped 10-fold CV ensures the same
+patient's images never appear in both train and val. Training and validation under random-stratified
+splits would inflate accuracy by ~3–30pp on AZH (we measured this directly); we report the patient-
+grouped numbers as the honest evaluation. The shipped checkpoint is the highest single fold on the
+held-out Test set; a 10-fold soft-vote ensemble averages 0.7989 on the same set.
+
+**Out of scope.** Not for clinical decision-making. No claim of diagnostic accuracy on real patient
+cohorts. No fairness audit across skin tones (known gap). English-only UI. No mobile / offline build.
+
+**Citations.** Anisuzzaman et al. 2022, *Multi-modal wound classification using wound image and
+location by deep neural network*, Sci. Rep. 12:20057.
+"""
+
+
+demo = gr.Interface(
+    fn=classify,
+    inputs=gr.Image(type="pil", label="Wound photograph"),
+    outputs=gr.Label(num_top_classes=4, label="Predicted wound type"),
+    title="Chronic Wound Classifier — 4-class AZH demo",
+    description=DESCRIPTION,
+    article=ARTICLE,
+    flagging_mode="never",
+)
+
+
+if __name__ == "__main__":
+    demo.launch()
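Aside: the tail of `classify` (softmax over the 4 logits, then a class-name → probability dict) can be sketched without torch. A minimal stdlib-only version, using the same `CLASS_NAMES`/`IDX_TO_CLASS` tables as the app; the logits below are made up for illustration:

```python
import math

# Same label tables as app.py.
CLASS_NAMES = {
    "D": "Diabetic ulcer",
    "P": "Pressure ulcer",
    "S": "Surgical wound",
    "V": "Venous ulcer",
}
IDX_TO_CLASS = ["D", "P", "S", "V"]

def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, -1.0, 0.5, 0.3]  # hypothetical raw model output
probs = softmax(logits)
result = {CLASS_NAMES[IDX_TO_CLASS[i]]: probs[i] for i in range(4)}
# result maps each human-readable label to a probability summing to 1.
```

Subtracting the max logit before exponentiating changes nothing mathematically but prevents overflow for large logits, which is the same guarantee `torch.softmax` provides internally.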
cv_baseline_fold5_best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f4b21f6204855827f33f2679be63d964d237a4d73a9940382ef4d49ed552e99
+size 16355126
requirements.txt ADDED
@@ -0,0 +1,4 @@
+torch>=2.0
+torchvision>=0.15
+gradio>=4.0
+pillow>=10.0