Benny-Tang commited on
Commit
ef2f177
·
verified ·
1 Parent(s): 4614266

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -30
app.py CHANGED
@@ -2,69 +2,84 @@ import json
2
  import numpy as np
3
  import gradio as gr
4
  from PIL import Image
5
-
6
  import torch
7
  import torchxrayvision as xrv
8
- import re
9
- import os
10
  from skimage.transform import resize as sk_resize
 
 
 
11
 
12
  # -----------------------------
13
- # Imaging Agent (Chest X-ray, proxy for lung cancer risk)
14
  # -----------------------------
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
17
  MODEL.eval()
18
  PATHOLOGIES = MODEL.pathologies
19
 
 
20
  def imaging_agent(image_path: str):
21
  if not image_path:
22
- return "No image provided.", None
23
  try:
24
  # Load grayscale X-ray
25
  img = Image.open(image_path).convert("L")
26
  arr = np.array(img).astype(np.float32)
27
-
28
- # Normalize to [0,1]
29
  if arr.max() > 1:
30
  arr /= 255.0
31
-
32
- # TorchXRayVision normalization
33
  arr = xrv.datasets.normalize(arr, 4096)
34
 
35
- # --- Manual preprocessing instead of dict transforms ---
36
  h, w = arr.shape
37
  min_dim = min(h, w)
38
  startx = w // 2 - (min_dim // 2)
39
  starty = h // 2 - (min_dim // 2)
40
- arr = arr[starty:starty+min_dim, startx:startx+min_dim] # center crop
41
-
42
- arr = sk_resize(arr, (224, 224), preserve_range=True) # resize to 224x224
43
 
44
- # Convert to tensor
45
- arr = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
46
 
47
  # Inference
48
  with torch.no_grad():
49
- preds = MODEL(arr)[0]
50
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
51
 
52
- # Focus on lung pathologies relevant to cancer
53
  focus_labels = ["Lung Opacity", "Mass", "Nodule"]
54
  focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Format outputs
57
  lines = [f"{name}: {p*100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
58
  table = {name: round(p, 4) for name, p in focus}
59
 
60
- return "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines), json.dumps(table, indent=2)
 
 
 
 
61
 
62
  except Exception as e:
63
- return f"Imaging agent error: {e}", None
64
 
65
 
66
  # -----------------------------
67
- # Lab Agent (tumor markers, thresholds stub)
68
  # -----------------------------
69
  CANCER_MARKERS = {
70
  "psa": {"unit": "ng/mL", "high": 4},
@@ -72,6 +87,7 @@ CANCER_MARKERS = {
72
  "afp": {"unit": "ng/mL", "high": 10},
73
  }
74
 
 
75
  def lab_agent(text: str):
76
  if not text.strip():
77
  return "No lab text provided."
@@ -85,11 +101,14 @@ def lab_agent(text: str):
85
  thr = CANCER_MARKERS[label]
86
  status = "ok"
87
  if v > thr["high"]:
88
- status = "elevated"; flags.append(f"{label.upper()} high")
 
89
  results.append(f"{label.upper()}: {v} {thr['unit']} → {status}")
90
  if not results:
91
  return "Could not parse tumor markers."
92
- return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + ("\nFlags: " + ", ".join(flags) if flags else "\nFlags: none")
 
 
93
 
94
 
95
  # -----------------------------
@@ -97,8 +116,10 @@ def lab_agent(text: str):
97
  # -----------------------------
98
  def coordinator(imaging_txt, lab_txt):
99
  summary = "📋 Coordinator Summary (Early Cancer Screening)\n"
100
- if imaging_txt: summary += "\n" + imaging_txt
101
- if lab_txt: summary += "\n" + lab_txt
 
 
102
  summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
103
  return summary
104
 
@@ -110,17 +131,21 @@ SAMPLES = {
110
  "Normal X-ray": "samples/sample_xray1.png",
111
  "Suspicious X-ray": "samples/sample_xray2.png",
112
  }
113
- SAMPLE_LABS = "PSA: 8 ng/mL\nCA125: 20 U/mL\nAFP: 15 ng/mL"
 
 
 
 
114
 
115
 
116
  # -----------------------------
117
  # Runner
118
  # -----------------------------
119
  def run_all(image, labs):
120
- txt, raw = imaging_agent(image) if image else ("No image.", None)
121
  lab = lab_agent(labs)
122
  coord = coordinator(txt, lab)
123
- return txt, raw, lab, coord
124
 
125
 
126
  # -----------------------------
@@ -140,20 +165,39 @@ with gr.Blocks(theme="soft") as demo:
140
  img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
141
  imaging_out = gr.Textbox(label="Imaging Agent Output")
142
  imaging_raw = gr.Code(label="Probabilities JSON", language="json")
 
143
  with gr.Column():
144
- lab_in = gr.Textbox(lines=6, label="Lab Results", value=SAMPLE_LABS)
 
 
 
 
 
145
  lab_out = gr.Textbox(label="Lab Agent Output")
146
 
147
  run_btn = gr.Button("Run Agents")
148
  coord_out = gr.Textbox(label="Coordinator Summary", lines=10)
149
 
150
- # Link dropdown to image input
151
  def load_sample(choice):
152
  return SAMPLES.get(choice, None)
 
 
 
 
 
 
 
 
153
  sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
 
154
 
155
  # Main button
156
- run_btn.click(run_all, inputs=[img_in, lab_in], outputs=[imaging_out, imaging_raw, lab_out, coord_out])
 
 
 
 
157
 
158
  demo.launch()
159
 
@@ -164,3 +208,4 @@ demo.launch()
164
 
165
 
166
 
 
 
2
  import numpy as np
3
  import gradio as gr
4
  from PIL import Image
 
5
  import torch
6
  import torchxrayvision as xrv
7
+ from torchvision import transforms
 
8
  from skimage.transform import resize as sk_resize
9
+ import matplotlib.pyplot as plt
10
+ import io
11
+ import re
12
 
13
  # -----------------------------
14
+ # Imaging Agent (Chest X-ray)
15
  # -----------------------------
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
18
  MODEL.eval()
19
  PATHOLOGIES = MODEL.pathologies
20
 
21
+
22
  def imaging_agent(image_path: str):
23
  if not image_path:
24
+ return "No image provided.", None, None
25
  try:
26
  # Load grayscale X-ray
27
  img = Image.open(image_path).convert("L")
28
  arr = np.array(img).astype(np.float32)
 
 
29
  if arr.max() > 1:
30
  arr /= 255.0
 
 
31
  arr = xrv.datasets.normalize(arr, 4096)
32
 
33
+ # Manual center crop & resize
34
  h, w = arr.shape
35
  min_dim = min(h, w)
36
  startx = w // 2 - (min_dim // 2)
37
  starty = h // 2 - (min_dim // 2)
38
+ arr = arr[starty:starty+min_dim, startx:startx+min_dim]
39
+ arr = sk_resize(arr, (224, 224), preserve_range=True)
 
40
 
41
+ tensor_img = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
 
42
 
43
  # Inference
44
  with torch.no_grad():
45
+ preds = MODEL(tensor_img)[0]
46
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
47
 
48
+ # Focus on lung pathologies
49
  focus_labels = ["Lung Opacity", "Mass", "Nodule"]
50
  focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
51
 
52
+ # Generate heatmap overlay (using feature maps)
53
+ fmap = MODEL.features(tensor_img).detach().cpu().numpy()[0]
54
+ heatmap = np.mean(fmap, axis=0)
55
+ heatmap = sk_resize(heatmap, arr.shape, preserve_range=True)
56
+
57
+ plt.figure(figsize=(4, 4))
58
+ plt.imshow(arr, cmap="gray")
59
+ plt.imshow(heatmap, cmap="jet", alpha=0.4)
60
+ plt.axis("off")
61
+ buf = io.BytesIO()
62
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
63
+ plt.close()
64
+ buf.seek(0)
65
+ heatmap_img = Image.open(buf)
66
+
67
  # Format outputs
68
  lines = [f"{name}: {p*100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
69
  table = {name: round(p, 4) for name, p in focus}
70
 
71
+ return (
72
+ "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines),
73
+ json.dumps(table, indent=2),
74
+ heatmap_img
75
+ )
76
 
77
  except Exception as e:
78
+ return f"Imaging agent error: {e}", None, None
79
 
80
 
81
  # -----------------------------
82
+ # Lab Agent (tumor markers)
83
  # -----------------------------
84
  CANCER_MARKERS = {
85
  "psa": {"unit": "ng/mL", "high": 4},
 
87
  "afp": {"unit": "ng/mL", "high": 10},
88
  }
89
 
90
+
91
  def lab_agent(text: str):
92
  if not text.strip():
93
  return "No lab text provided."
 
101
  thr = CANCER_MARKERS[label]
102
  status = "ok"
103
  if v > thr["high"]:
104
+ status = "elevated"
105
+ flags.append(f"{label.upper()} high")
106
  results.append(f"{label.upper()}: {v} {thr['unit']} → {status}")
107
  if not results:
108
  return "Could not parse tumor markers."
109
+ return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + (
110
+ "\nFlags: " + ", ".join(flags) if flags else "\nFlags: none"
111
+ )
112
 
113
 
114
  # -----------------------------
 
116
  # -----------------------------
117
  def coordinator(imaging_txt, lab_txt):
118
  summary = "📋 Coordinator Summary (Early Cancer Screening)\n"
119
+ if imaging_txt:
120
+ summary += "\n" + imaging_txt
121
+ if lab_txt:
122
+ summary += "\n" + lab_txt
123
  summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
124
  return summary
125
 
 
131
  "Normal X-ray": "samples/sample_xray1.png",
132
  "Suspicious X-ray": "samples/sample_xray2.png",
133
  }
134
+ SAMPLE_TEXTS = {
135
+ "Lab Results": "samples/sample_labs.txt",
136
+ "MRI Report": "samples/sample_mri.txt",
137
+ "CT Report": "samples/sample_ct.txt",
138
+ }
139
 
140
 
141
  # -----------------------------
142
  # Runner
143
  # -----------------------------
144
  def run_all(image, labs):
145
+ txt, raw, heatmap = imaging_agent(image) if image else ("No image.", None, None)
146
  lab = lab_agent(labs)
147
  coord = coordinator(txt, lab)
148
+ return txt, raw, heatmap, lab, coord
149
 
150
 
151
  # -----------------------------
 
165
  img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
166
  imaging_out = gr.Textbox(label="Imaging Agent Output")
167
  imaging_raw = gr.Code(label="Probabilities JSON", language="json")
168
+ imaging_heatmap = gr.Image(label="Heatmap Overlay")
169
  with gr.Column():
170
+ text_dropdown = gr.Dropdown(
171
+ choices=list(SAMPLE_TEXTS.keys()),
172
+ value="Lab Results",
173
+ label="Select Sample Report"
174
+ )
175
+ lab_in = gr.Textbox(lines=6, label="Lab / Report Input")
176
  lab_out = gr.Textbox(label="Lab Agent Output")
177
 
178
  run_btn = gr.Button("Run Agents")
179
  coord_out = gr.Textbox(label="Coordinator Summary", lines=10)
180
 
181
+ # Link dropdowns
182
  def load_sample(choice):
183
  return SAMPLES.get(choice, None)
184
+
185
+ def load_text(choice):
186
+ path = SAMPLE_TEXTS.get(choice, None)
187
+ if path and path.endswith(".txt"):
188
+ with open(path, "r") as f:
189
+ return f.read()
190
+ return ""
191
+
192
  sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
193
+ text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
194
 
195
  # Main button
196
+ run_btn.click(
197
+ run_all,
198
+ inputs=[img_in, lab_in],
199
+ outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out],
200
+ )
201
 
202
  demo.launch()
203
 
 
208
 
209
 
210
 
211
+