mmarquezsa commited on
Commit
354bfe2
·
verified ·
1 Parent(s): 4f67156

Fix: lazy loading, disable TTA on CPU, error handling, remove double inference

Browse files
Files changed (1) hide show
  1. app.py +41 -42
app.py CHANGED
@@ -1,28 +1,31 @@
1
- """Gradio app for WoundNetB7 DFU Analysis — Hugging Face Spaces deployment.
2
-
3
- Launch locally: python app.py
4
- Deploy to HF: push this repo to a Hugging Face Space (GPU recommended).
5
- """
6
  import gradio as gr
7
  import numpy as np
8
  import cv2
9
  import json
10
- from pipeline import WoundNetB7Pipeline
 
 
 
 
11
 
12
- # Initialize pipeline (loads all models once)
13
- pipe = WoundNetB7Pipeline(models_dir="models", use_tta=True)
 
 
 
 
14
 
15
- CLASS_COLORS_RGB = {
16
- 0: (0, 0, 0),
17
- 1: (0, 255, 0), # Foot: green
18
- 2: (255, 165, 0), # Perilesion: orange
19
- 3: (255, 0, 0), # Ulcer: red
20
- }
21
 
22
- FITZ_COLORS = {
23
- "I": "#fef3c7", "II": "#fde68a", "III": "#fbbf24",
24
- "IV": "#b45309", "V": "#78350f", "VI": "#451a03",
25
- }
 
 
 
 
 
26
 
27
 
28
  def analyze_image(image):
@@ -30,29 +33,31 @@ def analyze_image(image):
30
  if image is None:
31
  return None, "Please upload an image.", "{}"
32
 
33
- # Gradio provides RGB numpy array
34
- img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
 
 
35
 
36
- # Run pipeline
37
- result = pipe.analyze(img_bgr, use_tta=True)
 
 
38
 
39
- # Create overlay visualization
40
- overlay = pipe.visualize(img_bgr, result)
41
 
42
- # Build text summary
43
- summary = result.summary()
 
44
 
45
- # Build JSON output
46
- json_out = json.dumps(result.to_dict(), indent=2, ensure_ascii=False)
47
 
48
- return overlay, summary, json_out
 
 
49
 
50
 
51
- # Build Gradio interface
52
- with gr.Blocks(
53
- title="WoundNetB7 DFU Analysis",
54
- theme=gr.themes.Soft(),
55
- ) as demo:
56
  gr.Markdown(
57
  """
58
  # WoundNetB7 — Diabetic Foot Ulcer Analysis
@@ -70,7 +75,6 @@ with gr.Blocks(
70
  with gr.Column(scale=1):
71
  input_image = gr.Image(label="DFU Image", type="numpy")
72
  analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
73
-
74
  with gr.Column(scale=1):
75
  output_overlay = gr.Image(label="Segmentation Overlay")
76
 
@@ -80,11 +84,7 @@ with gr.Blocks(
80
  with gr.Column(scale=1):
81
  output_json = gr.Code(label="JSON Output", language="json")
82
 
83
- analyze_btn.click(
84
- fn=analyze_image,
85
- inputs=[input_image],
86
- outputs=[output_overlay, output_text, output_json],
87
- )
88
 
89
  gr.Markdown(
90
  """
@@ -94,8 +94,7 @@ with gr.Blocks(
94
  **PWAT Items:** 3=Necrotic Type, 4=Necrotic Amount, 5=Granulation Type,
95
  6=Granulation Amount, 7=Edges, 8=Periulcer Skin (0=best, 4=worst)
96
 
97
- **Debiasing:** Scores adjusted by Fitzpatrick type to reduce skin-tone bias
98
- (18% max group gap reduction, p < 10^-27).
99
  """
100
  )
101
 
 
1
+ """Gradio app for WoundNetB7 DFU Analysis — Hugging Face Spaces deployment."""
 
 
 
 
2
  import gradio as gr
3
  import numpy as np
4
  import cv2
5
  import json
6
+ import traceback
7
+
8
+ # Lazy loading — don't crash at import time
9
+ pipe = None
10
+
11
 
12
+ def get_pipeline():
13
+ global pipe
14
+ if pipe is None:
15
+ from pipeline import WoundNetB7Pipeline
16
+ pipe = WoundNetB7Pipeline(models_dir="models", use_tta=False)
17
+ return pipe
18
 
 
 
 
 
 
 
19
 
20
+ def create_overlay(img_rgb, classmap):
21
+ """Create segmentation overlay on RGB image."""
22
+ colors = {1: (0, 255, 0), 2: (255, 165, 0), 3: (255, 0, 0)}
23
+ overlay = img_rgb.astype(np.float32).copy()
24
+ for cid, color in colors.items():
25
+ mask = classmap == cid
26
+ if np.any(mask):
27
+ overlay[mask] = overlay[mask] * 0.5 + np.array(color, dtype=np.float32) * 0.5
28
+ return overlay.astype(np.uint8)
29
 
30
 
31
  def analyze_image(image):
 
33
  if image is None:
34
  return None, "Please upload an image.", "{}"
35
 
36
+ try:
37
+ pipeline = get_pipeline()
38
+ img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
39
+ result = pipeline.analyze(img_bgr, use_tta=False)
40
 
41
+ # Create overlay from the segmentation already done (no re-run)
42
+ from src.segmentation import segment, CLASS_NAMES
43
+ seg = segment(pipeline.seg_model, img_bgr, pipeline.device, use_tta=False)
44
+ classmap = seg["classmap"]
45
 
46
+ if classmap.shape[:2] != image.shape[:2]:
47
+ classmap = cv2.resize(classmap, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
48
 
49
+ overlay = create_overlay(image, classmap)
50
+ summary = result.summary()
51
+ json_out = json.dumps(result.to_dict(), indent=2, ensure_ascii=False)
52
 
53
+ return overlay, summary, json_out
 
54
 
55
+ except Exception as e:
56
+ error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
57
+ return None, error_msg, "{}"
58
 
59
 
60
+ with gr.Blocks(title="WoundNetB7 DFU Analysis", theme=gr.themes.Soft()) as demo:
 
 
 
 
61
  gr.Markdown(
62
  """
63
  # WoundNetB7 — Diabetic Foot Ulcer Analysis
 
75
  with gr.Column(scale=1):
76
  input_image = gr.Image(label="DFU Image", type="numpy")
77
  analyze_btn = gr.Button("Analyze", variant="primary", size="lg")
 
78
  with gr.Column(scale=1):
79
  output_overlay = gr.Image(label="Segmentation Overlay")
80
 
 
84
  with gr.Column(scale=1):
85
  output_json = gr.Code(label="JSON Output", language="json")
86
 
87
+ analyze_btn.click(fn=analyze_image, inputs=[input_image], outputs=[output_overlay, output_text, output_json])
 
 
 
 
88
 
89
  gr.Markdown(
90
  """
 
94
  **PWAT Items:** 3=Necrotic Type, 4=Necrotic Amount, 5=Granulation Type,
95
  6=Granulation Amount, 7=Edges, 8=Periulcer Skin (0=best, 4=worst)
96
 
97
+ **Debiasing:** Scores adjusted by Fitzpatrick type to reduce skin-tone bias.
 
98
  """
99
  )
100