Benny-Tang commited on
Commit
8fc30da
·
verified ·
1 Parent(s): 76aec82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +271 -3
app.py CHANGED
@@ -1,7 +1,275 @@
 
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ import math
5
+ import numpy as np
6
+ import pandas as pd
7
  import gradio as gr
8
+ from PIL import Image
9
 
10
+ # Imaging (Chest X-ray) — TorchXRayVision
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torchxrayvision as xrv
14
+ from torchvision import transforms
15
+
16
+ # ECG / Signals — HeartPy
17
+ import heartpy as hp
18
+
19
+ # Optional DICOM support
20
+ try:
21
+ import pydicom
22
+ HAS_PYDICOM = True
23
+ except Exception:
24
+ HAS_PYDICOM = False
25
+
26
+ # -----------------------------
27
+ # Imaging Agent (Chest X-ray)
28
+ # -----------------------------
29
+ _IMAGING_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ # Pretrained DenseNet on multiple datasets (free)
32
+ _IMAGING_MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(_IMAGING_DEVICE)
33
+ _IMAGING_MODEL.eval()
34
+ _PATHOLOGY_LIST = _IMAGING_MODEL.pathologies # list of labels
35
+
36
+ # Preprocess pipeline for CXRs
37
+ _img_resize = transforms.Compose([
38
+ xrv.datasets.XRayCenterCrop(),
39
+ xrv.datasets.XRayResizer(224)
40
+ ])
41
+
42
+ def _load_cxr_image(filepath: str) -> Image.Image:
43
+ # Accept: DICOM, jpg, png
44
+ ext = os.path.splitext(filepath)[1].lower()
45
+ if ext in [".dcm", ".dicom"] and HAS_PYDICOM:
46
+ ds = pydicom.dcmread(filepath)
47
+ arr = ds.pixel_array.astype(np.float32)
48
+ # Normalize and to PIL
49
+ arr = arr - arr.min()
50
+ if arr.max() > 0:
51
+ arr = arr / arr.max()
52
+ arr = (arr * 255.0).clip(0, 255).astype(np.uint8)
53
+ return Image.fromarray(arr).convert("L")
54
+ else:
55
+ return Image.open(filepath).convert("L")
56
+
57
+ def imaging_agent(image_path: str):
58
+ if not image_path:
59
+ return "No image provided.", None
60
+
61
+ try:
62
+ img = _load_cxr_image(image_path)
63
+ # to numpy [1,1,H,W] normalized
64
+ arr = np.array(img).astype(np.float32)
65
+ if arr.max() > 1.0:
66
+ arr /= 255.0
67
+ arr = xrv.datasets.normalize(arr, 4096) # safe normalize
68
+ arr = _img_resize(arr)
69
+ arr = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(_IMAGING_DEVICE)
70
+
71
+ with torch.no_grad():
72
+ preds = _IMAGING_MODEL(arr)[0] # shape [num_labels]
73
+ probs = torch.sigmoid(preds).detach().cpu().numpy().tolist()
74
+
75
+ # Top 5 findings
76
+ top = sorted(zip(_PATHOLOGY_LIST, probs), key=lambda x: x[1], reverse=True)[:5]
77
+ lines = [f"{name}: {p*100:.1f}%" for name, p in top]
78
+ table = {name: round(p, 4) for name, p in zip(_PATHOLOGY_LIST, probs)}
79
+ return "🖼️ Imaging Agent (Chest X-ray)\n" + "\n".join(lines), json.dumps(table, indent=2)
80
+ except Exception as e:
81
+ return f"Imaging agent error: {e}", None
82
+
83
+ # -----------------------------
84
+ # Signal Agent (ECG)
85
+ # -----------------------------
86
+ def _read_signal(file_obj) -> np.ndarray:
87
+ # Expect CSV/TSV or plain text with one column
88
+ try:
89
+ file_obj.seek(0)
90
+ df = pd.read_csv(file_obj)
91
+ except Exception:
92
+ file_obj.seek(0)
93
+ try:
94
+ df = pd.read_csv(file_obj, sep="\t")
95
+ except Exception:
96
+ file_obj.seek(0)
97
+ # Plain text: one value per line
98
+ vals = [float(x.strip()) for x in file_obj.read().decode("utf-8").splitlines() if x.strip()]
99
+ return np.asarray(vals, dtype=np.float32)
100
+ # Prefer first numeric column
101
+ for col in df.columns:
102
+ if pd.api.types.is_numeric_dtype(df[col]):
103
+ return df[col].dropna().to_numpy(dtype=np.float32)
104
+ # Fallback: try to coerce first column
105
+ return pd.to_numeric(df.iloc[:, 0], errors="coerce").dropna().to_numpy(dtype=np.float32)
106
+
107
+ def signal_agent(file, sample_rate_hz: float):
108
+ if file is None:
109
+ return "No signal file uploaded."
110
+ if not sample_rate_hz or sample_rate_hz <= 0:
111
+ return "Please provide a valid sampling rate (Hz)."
112
+
113
+ try:
114
+ sig = _read_signal(file)
115
+ # Basic safety checks
116
+ if len(sig) < sample_rate_hz * 5:
117
+ return "Signal too short. Provide at least 5 seconds of data."
118
+ # HeartPy processing
119
+ wd, m = hp.process(sig, sample_rate=sample_rate_hz)
120
+ bpm = m.get('bpm', float('nan'))
121
+ rmssd = m.get('rmssd', float('nan'))
122
+ ibi = m.get('ibi', float('nan'))
123
+ sdnn = m.get('sdnn', float('nan'))
124
+
125
+ # Simple flags (not medical diagnoses)
126
+ flags = []
127
+ if not math.isnan(bpm):
128
+ if bpm < 50: flags.append("Bradycardia risk (low BPM)")
129
+ if bpm > 110: flags.append("Tachycardia risk (high BPM)")
130
+ if not math.isnan(sdnn) and sdnn > 100:
131
+ flags.append("High variability — possible irregular rhythm")
132
+ if not math.isnan(rmssd) and rmssd > 80:
133
+ flags.append("Elevated RMSSD — irregularity suspicion")
134
+
135
+ summary = [
136
+ "💓 Signal Agent (ECG-like biosignal)",
137
+ f"BPM: {bpm:.1f}" if not math.isnan(bpm) else "BPM: N/A",
138
+ f"SDNN: {sdnn:.1f} ms" if not math.isnan(sdnn) else "SDNN: N/A",
139
+ f"RMSSD: {rmssd:.1f} ms" if not math.isnan(rmssd) else "RMSSD: N/A",
140
+ f"IBI: {ibi:.1f} ms" if not math.isnan(ibi) else "IBI: N/A",
141
+ "Flags: " + (", ".join(flags) if flags else "none")
142
+ ]
143
+ return "\n".join(summary)
144
+ except Exception as e:
145
+ return f"Signal agent error: {e}"
146
+
147
+ # -----------------------------
148
+ # Lab Agent (text parsing)
149
+ # -----------------------------
150
+ # Simple thresholds (illustrative only; NOT medical advice)
151
+ _LAB_THRESHOLDS = {
152
+ "glucose": {"unit": "mg/dL", "high": 126, "low": 70},
153
+ "hemoglobin": {"unit": "g/dL", "low": 12.0},
154
+ "spo2": {"unit": "%", "low": 92},
155
+ "ldl": {"unit": "mg/dL", "high": 160},
156
+ "hdl": {"unit": "mg/dL", "low": 40},
157
+ "triglycerides": {"unit": "mg/dL", "high": 200"},
158
+ "creatinine": {"unit": "mg/dL", "high": 1.3},
159
+ }
160
+
161
+ import re
162
+ def _extract_labs(text: str):
163
+ # match patterns like "glucose: 180 mg/dL" or "glucose 180"
164
+ results = {}
165
+ for line in text.splitlines():
166
+ line_l = line.lower()
167
+ matches = re.findall(r'([a-z\%\d\/]+)\s*[:=]?\s*([\-]?\d+\.?\d*)\s*([a-z%\/]+)?', line_l)
168
+ for label, val, unit in matches:
169
+ label = label.strip().replace("%", "spo2") if label.strip() == "o2" else label.strip()
170
+ if label in _LAB_THRESHOLDS:
171
+ try:
172
+ v = float(val)
173
+ results[label] = {"value": v, "unit": unit or _LAB_THRESHOLDS[label]["unit"]}
174
+ except:
175
+ pass
176
+ return results
177
+
178
+ def lab_agent(lab_text: str):
179
+ if not lab_text or not lab_text.strip():
180
+ return "No lab text provided."
181
+ labs = _extract_labs(lab_text)
182
+ if not labs:
183
+ return "Could not parse lab values. Use lines like 'glucose: 180 mg/dL'."
184
+
185
+ lines = ["🧪 Lab Agent"]
186
+ flags = []
187
+ for k, v in labs.items():
188
+ value = v["value"]
189
+ unit = v["unit"]
190
+ thr = _LAB_THRESHOLDS.get(k, {})
191
+ status = "ok"
192
+ if "high" in thr and value > thr["high"]:
193
+ status = "high"
194
+ flags.append(f"{k} high")
195
+ if "low" in thr and value < thr["low"]:
196
+ status = "low"
197
+ flags.append(f"{k} low")
198
+ lines.append(f"{k.capitalize()}: {value} {unit} → {status}")
199
+ if flags:
200
+ lines.append("Flags: " + ", ".join(flags))
201
+ else:
202
+ lines.append("Flags: none")
203
+ return "\n".join(lines)
204
+
205
+ # -----------------------------
206
+ # Coordinator
207
+ # -----------------------------
208
+ def coordinator(imaging_txt: str, signal_txt: str, lab_txt: str):
209
+ parts = []
210
+ if imaging_txt and "Imaging Agent" in imaging_txt:
211
+ parts.append(imaging_txt)
212
+ if signal_txt and "Signal Agent" in signal_txt:
213
+ parts.append(signal_txt)
214
+ if lab_txt and "Lab Agent" in lab_txt:
215
+ parts.append(lab_txt)
216
+
217
+ assessment = []
218
+ if imaging_txt and ("pneumonia" in imaging_txt.lower() or "infiltration" in imaging_txt.lower()):
219
+ assessment.append("Possible pulmonary involvement")
220
+ if signal_txt and ("tachycardia" in signal_txt.lower() or "irregular" in signal_txt.lower()):
221
+ assessment.append("Cardiac rhythm irregularity risk")
222
+ if lab_txt and ("glucose" in lab_txt.lower() and "high" in lab_txt.lower()):
223
+ assessment.append("Hyperglycemia risk")
224
+
225
+ summary = "📋 Coordinator Summary\n"
226
+ summary += "\n\n".join(parts) if parts else "No agent outputs."
227
+ if assessment:
228
+ summary += "\n\n🧭 Integrated Assessment: " + "; ".join(assessment)
229
+ summary += "\n\n⚠️ Disclaimer: This demo is not a medical device. Do not use for diagnosis. Consult a qualified clinician."
230
+ return summary
231
+
232
+ # -----------------------------
233
+ # Gradio UI
234
+ # -----------------------------
235
+ with gr.Blocks(theme="soft") as demo:
236
+ gr.Markdown("# 🏥 AI Diagnostics Agents (Demo)")
237
+ gr.Markdown(
238
+ "Upload a chest X-ray (PNG/JPG/DICOM), an ECG-like signal (CSV/TSV or text with one value per line), "
239
+ "and/or paste lab values. Each agent analyzes its modality; the coordinator fuses results.\n\n"
240
+ "⚠️ **Not for clinical use.**"
241
+ )
242
+
243
+ with gr.Row():
244
+ with gr.Column():
245
+ img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG or DICOM)")
246
+ imaging_out = gr.Textbox(label="Imaging Agent (Top findings)")
247
+ imaging_raw = gr.Code(label="Imaging Probabilities (JSON)", language="json")
248
+ with gr.Column():
249
+ ecg_in = gr.File(label="ECG / Biosignal (CSV/TSV or txt)")
250
+ sr = gr.Number(label="Sampling Rate (Hz)", value=250, precision=1)
251
+ signal_out = gr.Textbox(label="Signal Agent Summary")
252
+ with gr.Column():
253
+ lab_in = gr.Textbox(lines=12, label="Lab Results (e.g., 'glucose: 180 mg/dL')")
254
+ lab_out = gr.Textbox(label="Lab Agent Summary")
255
+
256
+ run_btn = gr.Button("Run All Agents")
257
+ coord_out = gr.Textbox(label="Coordinator Summary", lines=14)
258
+
259
+ def run_imaging(image_path):
260
+ txt, raw = imaging_agent(image_path) if image_path else ("No image provided.", None)
261
+ return txt, raw
262
+
263
+ def run_signal(file, rate):
264
+ return signal_agent(file, rate)
265
+
266
+ def run_lab(text):
267
+ return lab_agent(text)
268
+
269
+ run_btn.click(run_imaging, inputs=img_in, outputs=[imaging_out, imaging_raw])\
270
+ .then(run_signal, inputs=[ecg_in, sr], outputs=signal_out)\
271
+ .then(run_lab, inputs=lab_in, outputs=lab_out)\
272
+ .then(coordinator, inputs=[imaging_out, signal_out, lab_out], outputs=coord_out)
273
 
 
274
  demo.launch()
275
+