jbobym commited on
Commit
d8c7fbf
·
verified ·
1 Parent(s): f53bdaa

Add user guidance: pre-upload tips, low-confidence warning, pressure-class skepticism note

Browse files
Files changed (1) hide show
  1. app.py +42 -12
app.py CHANGED
@@ -66,29 +66,56 @@ MODEL = _load_model(CKPT_PATH)
66
  TRANSFORM = _build_transform()
67
 
68
 
69
- def classify(image: Image.Image) -> dict[str, float]:
 
 
 
70
  if image is None:
71
- return {}
72
  rgb = image.convert("RGB")
73
  x = TRANSFORM(rgb).unsqueeze(0)
74
  with torch.inference_mode():
75
  logits = MODEL(x)
76
  probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
77
- return {CLASS_NAMES[IDX_TO_CLASS[i]]: float(probs[i]) for i in range(4)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  DESCRIPTION = """\
81
- Chronic wound photo classifier predicts which of four wound types an uploaded image shows
82
- (diabetic ulcer, pressure ulcer, surgical wound, or venous ulcer).
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  **This is a research demo, not a medical device.** It does not diagnose, triage, or replace clinician
85
- judgement. It was trained on the public AZH Chronic Wound Database (Anisuzzaman et al. 2022) using
86
  EfficientNet-B0 with patient-grouped 10-fold cross-validation; the deployed checkpoint
87
  (`cv_baseline_fold5_best.pt`) reaches **0.8152 top-1** on the held-out AZH Test set (n=184).
88
- Pressure ulcers are the underperforming class (~0.41 accuracy) — interpret pressure-class predictions
89
- with extra care.
90
-
91
- Upload a wound photograph below to see the model's per-class probabilities.
92
  """
93
 
94
  ARTICLE = """\
@@ -110,8 +137,11 @@ location by deep neural network*, Sci. Rep. 12:20057.
110
 
111
  demo = gr.Interface(
112
  fn=classify,
113
- inputs=gr.Image(type="pil", label="Wound photograph"),
114
- outputs=gr.Label(num_top_classes=4, label="Predicted wound type"),
 
 
 
115
  title="Chronic Wound Classifier — 4-class AZH demo",
116
  description=DESCRIPTION,
117
  article=ARTICLE,
 
66
  TRANSFORM = _build_transform()
67
 
68
 
69
# Top-1 probability below this triggers the low-confidence warning.
# NOTE(review): 0.5 is a UX choice, not a calibrated cutoff — chance level
# for 4 classes is 0.25; confirm against validation calibration if tuned.
LOW_CONFIDENCE_THRESHOLD = 0.5


def classify(image: Image.Image) -> tuple[dict[str, float], str]:
    """Classify a wound photo and return (per-class probabilities, advisory notes).

    Args:
        image: Uploaded photo from the Gradio Image component, or None when
            the user cleared the input.

    Returns:
        A pair of:
        - dict mapping human-readable class name -> softmax probability,
          suitable for ``gr.Label``;
        - a Markdown string of zero or more warnings (empty when none apply),
          suitable for ``gr.Markdown``.
    """
    # Gradio passes None when no image is present; return empty outputs
    # rather than erroring so the UI simply clears.
    if image is None:
        return {}, ""

    rgb = image.convert("RGB")
    x = TRANSFORM(rgb).unsqueeze(0)  # add batch dimension: (1, C, H, W)
    with torch.inference_mode():
        logits = MODEL(x)
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()

    # Derive the class count from the model output instead of hard-coding 4,
    # so a retrained checkpoint with a different head size stays consistent.
    label_probs = {
        CLASS_NAMES[IDX_TO_CLASS[i]]: float(probs[i]) for i in range(len(probs))
    }

    top_label, top_prob = max(label_probs.items(), key=lambda kv: kv[1])
    notes: list[str] = []
    if top_prob < LOW_CONFIDENCE_THRESHOLD:
        notes.append(
            f"⚠️ **Low confidence** (top class {top_label} at {top_prob:.0%}). "
            "This often means the photo isn't a clear close-up of a wound, or the wound type "
            "isn't one of the four the model was trained on. The model has no \"not a wound\" "
            "option — it will always pick one of D / P / S / V even when the image isn't a wound."
        )
    # NOTE(review): assumes CLASS_NAMES yields exactly "Pressure ulcer" for the
    # pressure class — verify against the CLASS_NAMES definition; a renamed
    # label would silently disable this warning.
    if top_label == "Pressure ulcer":
        notes.append(
            "⚠️ **Pressure-class predictions are the model's weak spot** "
            "(~0.41 accuracy on the held-out test set). Treat this prediction with extra skepticism."
        )
    return label_probs, "\n\n".join(notes)
97
 
98
 
99
  DESCRIPTION = """\
100
+ Upload a close-up photo of a chronic wound and the model returns its best guess at the wound type
101
+ (diabetic ulcer, pressure ulcer, surgical wound, or venous ulcer) with per-class probabilities.
102
+
103
+ **Before you upload — a few things worth knowing:**
104
+
105
+ - **Frame the wound in the center.** The model resizes and center-crops to 224×224 pixels, so a wound
106
+ off in the corner of a high-resolution phone photo will get cropped out. Close-up clinical-style
107
+ framing (the wound fills most of the frame) gives the cleanest predictions.
108
+ - **JPEG or PNG works fine.** Anything PIL can open.
109
+ - **There is no "not a wound" option.** Upload a face, a sandwich, a sunset — the model will
110
+ confidently place it into one of the four wound classes anyway. The low-confidence warning below
111
+ is the only signal you'll get that the image probably isn't what the model was trained on.
112
+ - **Pressure-ulcer predictions are weakest** (~41% accuracy on the held-out test set). When the model
113
+ says "Pressure ulcer," treat the call with extra skepticism.
114
 
115
  **This is a research demo, not a medical device.** It does not diagnose, triage, or replace clinician
116
+ judgement. Trained on the public AZH Chronic Wound Database (Anisuzzaman et al. 2022) using
117
  EfficientNet-B0 with patient-grouped 10-fold cross-validation; the deployed checkpoint
118
  (`cv_baseline_fold5_best.pt`) reaches **0.8152 top-1** on the held-out AZH Test set (n=184).
 
 
 
 
119
  """
120
 
121
  ARTICLE = """\
 
137
 
138
  demo = gr.Interface(
139
  fn=classify,
140
+ inputs=gr.Image(type="pil", label="Wound photograph (close-up, centered)"),
141
+ outputs=[
142
+ gr.Label(num_top_classes=4, label="Predicted wound type"),
143
+ gr.Markdown(label="Notes"),
144
+ ],
145
  title="Chronic Wound Classifier — 4-class AZH demo",
146
  description=DESCRIPTION,
147
  article=ARTICLE,