Benny-Tang commited on
Commit
707146a
·
verified ·
1 Parent(s): 4170577

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -53
app.py CHANGED
@@ -1,23 +1,21 @@
1
- import os
2
  import json
3
  import numpy as np
4
  import gradio as gr
5
  from PIL import Image
6
 
7
- # Imaging (Chest X-ray) — TorchXRayVision
8
  import torch
9
  import torchxrayvision as xrv
10
  from torchvision import transforms
11
 
12
  # -----------------------------
13
- # Imaging Agent (Chest X-ray)
14
  # -----------------------------
15
- _DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
- _MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(_DEVICE)
17
- _MODEL.eval()
18
- _PATHOLOGIES = _MODEL.pathologies
19
 
20
- _resize = transforms.Compose([
21
  xrv.datasets.XRayCenterCrop(),
22
  xrv.datasets.XRayResizer(224)
23
  ])
@@ -31,33 +29,30 @@ def imaging_agent(image_path: str):
31
  if arr.max() > 1:
32
  arr /= 255.0
33
  arr = xrv.datasets.normalize(arr, 4096)
34
- arr = _resize(arr)
35
- arr = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(_DEVICE)
36
 
37
  with torch.no_grad():
38
- preds = _MODEL(arr)[0]
39
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
40
 
41
- top = sorted(zip(_PATHOLOGIES, probs), key=lambda x: x[1], reverse=True)[:5]
42
- lines = [f"{name}: {p*100:.1f}%" for name, p in top]
43
- table = {name: round(p, 4) for name, p in zip(_PATHOLOGIES, probs)}
44
- return "🖼️ Imaging Agent (Chest X-ray)\n" + "\n".join(lines), json.dumps(table, indent=2)
 
 
 
45
  except Exception as e:
46
  return f"Imaging agent error: {e}", None
47
 
48
  # -----------------------------
49
- # Stubbed Signal Agent (ECG)
50
- # -----------------------------
51
- def signal_agent(file, sample_rate):
52
- return "💓 Signal Agent (stub): ECG analysis not enabled in demo."
53
-
54
- # -----------------------------
55
- # Lab Agent (simple parsing)
56
  # -----------------------------
57
- _THRESHOLDS = {
58
- "glucose": {"unit": "mg/dL", "high": 126, "low": 70},
59
- "hemoglobin": {"unit": "g/dL", "low": 12.0},
60
- "spo2": {"unit": "%", "low": 92},
61
  }
62
 
63
  import re
@@ -67,64 +62,57 @@ def lab_agent(text: str):
67
  results = []
68
  flags = []
69
  for line in text.splitlines():
70
- m = re.findall(r'([a-z]+)\s*[:=]?\s*([\d\.]+)', line.lower())
71
  for label, val in m:
72
- if label in _THRESHOLDS:
73
  v = float(val)
74
- thr = _THRESHOLDS[label]
75
  status = "ok"
76
- if "high" in thr and v > thr["high"]:
77
- status = "high"; flags.append(f"{label} high")
78
- if "low" in thr and v < thr["low"]:
79
- status = "low"; flags.append(f"{label} low")
80
- results.append(f"{label.capitalize()}: {v} {thr['unit']} → {status}")
81
  if not results:
82
- return "Could not parse labs."
83
- return "🧪 Lab Agent\n" + "\n".join(results) + ("\nFlags: " + ", ".join(flags) if flags else "\nFlags: none")
84
 
85
  # -----------------------------
86
  # Coordinator
87
  # -----------------------------
88
- def coordinator(imaging_txt, signal_txt, lab_txt):
89
- summary = "📋 Coordinator Summary\n"
90
  if imaging_txt: summary += "\n" + imaging_txt
91
- if signal_txt: summary += "\n" + signal_txt
92
  if lab_txt: summary += "\n" + lab_txt
93
- summary += "\n\n⚠️ Disclaimer: Demo only. Not for clinical use."
94
  return summary
95
 
96
  # -----------------------------
97
  # Gradio UI
98
  # -----------------------------
99
  with gr.Blocks(theme="soft") as demo:
100
- gr.Markdown("# 🏥 AI Diagnostics Agents (Demo)")
101
- gr.Markdown("Upload a chest X-ray or paste labs. ECG agent is stubbed for now.\n\n⚠️ Not for clinical use.")
102
 
103
  with gr.Row():
104
  with gr.Column():
105
  img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
106
- imaging_out = gr.Textbox(label="Imaging Agent")
107
  imaging_raw = gr.Code(label="Probabilities JSON", language="json")
108
  with gr.Column():
109
- ecg_in = gr.File(label="ECG / Biosignal (stubbed)")
110
- sr = gr.Number(label="Sampling Rate (Hz)", value=250)
111
- signal_out = gr.Textbox(label="Signal Agent Output")
112
- with gr.Column():
113
- lab_in = gr.Textbox(lines=6, label="Lab Results (e.g., 'glucose: 180')")
114
  lab_out = gr.Textbox(label="Lab Agent Output")
115
 
116
  run_btn = gr.Button("Run Agents")
117
  coord_out = gr.Textbox(label="Coordinator Summary", lines=10)
118
 
119
- def run_all(image, ecg, rate, labs):
120
  txt, raw = imaging_agent(image) if image else ("No image.", None)
121
- sig = signal_agent(ecg, rate)
122
  lab = lab_agent(labs)
123
- coord = coordinator(txt, sig, lab)
124
- return txt, raw, sig, lab, coord
125
 
126
- run_btn.click(run_all, inputs=[img_in, ecg_in, sr, lab_in], outputs=[imaging_out, imaging_raw, signal_out, lab_out, coord_out])
127
 
128
  demo.launch()
129
 
130
 
 
 
 
1
  import json
2
  import numpy as np
3
  import gradio as gr
4
  from PIL import Image
5
 
 
6
  import torch
7
  import torchxrayvision as xrv
8
  from torchvision import transforms
9
 
10
  # -----------------------------
11
+ # Imaging Agent (Chest X-ray, proxy for lung cancer risk)
12
  # -----------------------------
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
15
+ MODEL.eval()
16
+ PATHOLOGIES = MODEL.pathologies
17
 
18
+ resize = transforms.Compose([
19
  xrv.datasets.XRayCenterCrop(),
20
  xrv.datasets.XRayResizer(224)
21
  ])
 
29
  if arr.max() > 1:
30
  arr /= 255.0
31
  arr = xrv.datasets.normalize(arr, 4096)
32
+ arr = resize(arr)
33
+ arr = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
34
 
35
  with torch.no_grad():
36
+ preds = MODEL(arr)[0]
37
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
38
 
39
+ # Focus on lung pathologies relevant to cancer
40
+ focus_labels = ["Lung Opacity", "Mass", "Nodule"]
41
+ focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
42
+
43
+ lines = [f"{name}: {p*100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
44
+ table = {name: round(p, 4) for name, p in focus}
45
+ return "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines), json.dumps(table, indent=2)
46
  except Exception as e:
47
  return f"Imaging agent error: {e}", None
48
 
49
  # -----------------------------
50
+ # Lab Agent (tumor markers, stub thresholds)
 
 
 
 
 
 
51
  # -----------------------------
52
+ CANCER_MARKERS = {
53
+ "psa": {"unit": "ng/mL", "high": 4},
54
+ "ca125": {"unit": "U/mL", "high": 35},
55
+ "afp": {"unit": "ng/mL", "high": 10},
56
  }
57
 
58
  import re
 
62
  results = []
63
  flags = []
64
  for line in text.splitlines():
65
+ m = re.findall(r'([a-z0-9]+)\s*[:=]?\s*([\d\.]+)', line.lower())
66
  for label, val in m:
67
+ if label in CANCER_MARKERS:
68
  v = float(val)
69
+ thr = CANCER_MARKERS[label]
70
  status = "ok"
71
+ if v > thr["high"]:
72
+ status = "elevated"; flags.append(f"{label.upper()} high")
73
+ results.append(f"{label.upper()}: {v} {thr['unit']} → {status}")
 
 
74
  if not results:
75
+ return "Could not parse tumor markers."
76
+ return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + ("\nFlags: " + ", ".join(flags) if flags else "\nFlags: none")
77
 
78
  # -----------------------------
79
  # Coordinator
80
  # -----------------------------
81
+ def coordinator(imaging_txt, lab_txt):
82
+ summary = "📋 Coordinator Summary (Early Cancer Screening)\n"
83
  if imaging_txt: summary += "\n" + imaging_txt
 
84
  if lab_txt: summary += "\n" + lab_txt
85
+ summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
86
  return summary
87
 
88
  # -----------------------------
89
  # Gradio UI
90
  # -----------------------------
91
  with gr.Blocks(theme="soft") as demo:
92
+ gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
93
+ gr.Markdown("Upload a chest X-ray or paste tumor marker labs.\n\n⚠️ Research demo only. Not for clinical use.")
94
 
95
  with gr.Row():
96
  with gr.Column():
97
  img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
98
+ imaging_out = gr.Textbox(label="Imaging Agent Output")
99
  imaging_raw = gr.Code(label="Probabilities JSON", language="json")
100
  with gr.Column():
101
+ lab_in = gr.Textbox(lines=6, label="Lab Results (e.g., 'PSA: 8 ng/mL')")
 
 
 
 
102
  lab_out = gr.Textbox(label="Lab Agent Output")
103
 
104
  run_btn = gr.Button("Run Agents")
105
  coord_out = gr.Textbox(label="Coordinator Summary", lines=10)
106
 
107
+ def run_all(image, labs):
108
  txt, raw = imaging_agent(image) if image else ("No image.", None)
 
109
  lab = lab_agent(labs)
110
+ coord = coordinator(txt, lab)
111
+ return txt, raw, lab, coord
112
 
113
+ run_btn.click(run_all, inputs=[img_in, lab_in], outputs=[imaging_out, imaging_raw, lab_out, coord_out])
114
 
115
  demo.launch()
116
 
117
 
118
+