Benny-Tang commited on
Commit
750fc9e
·
verified ·
1 Parent(s): 766411f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -38
app.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  import numpy as np
2
  import gradio as gr
3
  from PIL import Image
@@ -6,9 +11,13 @@ import torchxrayvision as xrv
6
  from torchvision import transforms
7
  from skimage.transform import resize as sk_resize
8
  import matplotlib.pyplot as plt
9
- import io
10
- import re
11
- import json
 
 
 
 
12
 
13
  # -----------------------------
14
  # Imaging Agent (Chest X-ray)
@@ -23,33 +32,32 @@ def imaging_agent(image_path: str):
23
  if not image_path:
24
  return "No image provided.", None, None
25
  try:
26
- # Load grayscale X-ray
27
  img = Image.open(image_path).convert("L")
28
  arr = np.array(img).astype(np.float32)
29
  if arr.max() > 1:
30
  arr /= 255.0
31
  arr = xrv.datasets.normalize(arr, 4096)
32
 
33
- # Manual center crop & resize
34
  h, w = arr.shape
35
  min_dim = min(h, w)
36
  startx = w // 2 - (min_dim // 2)
37
  starty = h // 2 - (min_dim // 2)
38
- arr = arr[starty:starty+min_dim, startx:startx+min_dim]
39
  arr = sk_resize(arr, (224, 224), preserve_range=True)
40
 
41
  tensor_img = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
42
 
43
- # Inference
44
  with torch.no_grad():
45
  preds = MODEL(tensor_img)[0]
46
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
47
 
48
- # Focus on lung pathologies
49
  focus_labels = ["Lung Opacity", "Mass", "Nodule"]
50
  focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
51
 
52
- # Generate heatmap overlay (using feature maps)
53
  fmap = MODEL.features(tensor_img).detach().cpu().numpy()[0]
54
  heatmap = np.mean(fmap, axis=0)
55
  heatmap = sk_resize(heatmap, arr.shape, preserve_range=True)
@@ -64,10 +72,9 @@ def imaging_agent(image_path: str):
64
  buf.seek(0)
65
  heatmap_img = Image.open(buf)
66
 
67
- # Format outputs
68
- lines = [f"{name}: {p*100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
69
  prob_text = "\n".join(lines)
70
-
71
  return (
72
  "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines),
73
  prob_text,
@@ -112,14 +119,47 @@ def lab_agent(text: str):
112
 
113
 
114
  # -----------------------------
115
- # Coordinator
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # -----------------------------
117
  def coordinator(imaging_txt, lab_txt):
118
- summary = "📋 Coordinator Summary (Early Cancer Screening)\n"
119
- if imaging_txt:
120
- summary += "\n" + imaging_txt
121
- if lab_txt:
122
- summary += "\n" + lab_txt
 
 
 
 
 
 
 
 
 
123
  summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
124
  return summary
125
 
@@ -153,32 +193,24 @@ def run_all(image, labs):
153
  # -----------------------------
154
  with gr.Blocks(theme="soft") as demo:
155
  gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
156
- gr.Markdown("Upload a chest X-ray, or pick a demo sample. Paste tumor marker labs.\n\n⚠️ Research demo only. Not for clinical use.")
 
157
 
158
  with gr.Row():
159
  with gr.Column():
160
- sample_dropdown = gr.Dropdown(
161
- choices=list(SAMPLES.keys()),
162
- value="Normal X-ray",
163
- label="Select Sample X-ray"
164
- )
165
  img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
166
  imaging_out = gr.Textbox(label="Imaging Agent Output")
167
  imaging_raw = gr.Textbox(label="Probabilities (%)", lines=6)
168
  imaging_heatmap = gr.Image(label="Heatmap Overlay")
169
  with gr.Column():
170
- text_dropdown = gr.Dropdown(
171
- choices=list(SAMPLE_TEXTS.keys()),
172
- value="Lab Results",
173
- label="Select Sample Report"
174
- )
175
  lab_in = gr.Textbox(lines=6, label="Lab / Report Input")
176
  lab_out = gr.Textbox(label="Lab Agent Output")
177
 
178
  run_btn = gr.Button("Run Agents")
179
- coord_out = gr.Textbox(label="Coordinator Summary", lines=10)
180
 
181
- # Link dropdowns
182
  def load_sample(choice):
183
  return SAMPLES.get(choice, None)
184
 
@@ -191,13 +223,7 @@ with gr.Blocks(theme="soft") as demo:
191
 
192
  sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
193
  text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
194
-
195
- # Main button
196
- run_btn.click(
197
- run_all,
198
- inputs=[img_in, lab_in],
199
- outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out],
200
- )
201
 
202
  demo.launch()
203
 
@@ -210,3 +236,4 @@ demo.launch()
210
 
211
 
212
 
 
 
1
+ import os
2
+ import io
3
+ import re
4
+ import json
5
+ import requests
6
  import numpy as np
7
  import gradio as gr
8
  from PIL import Image
 
11
  from torchvision import transforms
12
  from skimage.transform import resize as sk_resize
13
  import matplotlib.pyplot as plt
14
+ from dotenv import load_dotenv
15
+
16
+ # -----------------------------
17
+ # Environment setup
18
+ # -----------------------------
19
+ load_dotenv()
20
+ MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY")
21
 
22
  # -----------------------------
23
  # Imaging Agent (Chest X-ray)
 
32
  if not image_path:
33
  return "No image provided.", None, None
34
  try:
 
35
  img = Image.open(image_path).convert("L")
36
  arr = np.array(img).astype(np.float32)
37
  if arr.max() > 1:
38
  arr /= 255.0
39
  arr = xrv.datasets.normalize(arr, 4096)
40
 
41
+ # Center crop and resize
42
  h, w = arr.shape
43
  min_dim = min(h, w)
44
  startx = w // 2 - (min_dim // 2)
45
  starty = h // 2 - (min_dim // 2)
46
+ arr = arr[starty:starty + min_dim, startx:startx + min_dim]
47
  arr = sk_resize(arr, (224, 224), preserve_range=True)
48
 
49
  tensor_img = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
50
 
51
+ # Predictions
52
  with torch.no_grad():
53
  preds = MODEL(tensor_img)[0]
54
  probs = torch.sigmoid(preds).cpu().numpy().tolist()
55
 
56
+ # Focus on key lung pathologies
57
  focus_labels = ["Lung Opacity", "Mass", "Nodule"]
58
  focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
59
 
60
+ # Generate heatmap
61
  fmap = MODEL.features(tensor_img).detach().cpu().numpy()[0]
62
  heatmap = np.mean(fmap, axis=0)
63
  heatmap = sk_resize(heatmap, arr.shape, preserve_range=True)
 
72
  buf.seek(0)
73
  heatmap_img = Image.open(buf)
74
 
75
+ # Output probabilities
76
+ lines = [f"{name}: {p * 100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
77
  prob_text = "\n".join(lines)
 
78
  return (
79
  "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines),
80
  prob_text,
 
119
 
120
 
121
  # -----------------------------
122
+ # Moonshot.ai API Integration
123
+ # -----------------------------
124
+ def moonshot_summary(prompt: str):
125
+ try:
126
+ url = "https://api.moonshot.cn/v1/chat/completions"
127
+ headers = {
128
+ "Authorization": f"Bearer {MOONSHOT_API_KEY}",
129
+ "Content-Type": "application/json"
130
+ }
131
+ payload = {
132
+ "model": "moonshot-v1",
133
+ "messages": [{"role": "user", "content": prompt}],
134
+ "temperature": 0.6,
135
+ "max_tokens": 600
136
+ }
137
+ response = requests.post(url, headers=headers, json=payload)
138
+ response.raise_for_status()
139
+ data = response.json()
140
+ return data["choices"][0]["message"]["content"].strip()
141
+ except Exception as e:
142
+ return f"Moonshot API error: {e}"
143
+
144
+
145
+ # -----------------------------
146
+ # Coordinator (LLM-enhanced)
147
  # -----------------------------
148
  def coordinator(imaging_txt, lab_txt):
149
+ prompt = f"""
150
+ You are an AI medical coordinator. Based on the imaging and lab findings below,
151
+ generate a clear, patient-friendly summary assessing potential cancer risk,
152
+ highlighting abnormal findings, and suggesting appropriate next steps.
153
+ Avoid technical jargon. Keep it concise and empathetic.
154
+
155
+ Imaging findings:
156
+ {imaging_txt}
157
+
158
+ Lab results:
159
+ {lab_txt}
160
+ """
161
+ ai_summary = moonshot_summary(prompt)
162
+ summary = "📋 AI Coordinator Summary (LLM-generated)\n\n" + ai_summary
163
  summary += "\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
164
  return summary
165
 
 
193
  # -----------------------------
194
  with gr.Blocks(theme="soft") as demo:
195
  gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
196
+ gr.Markdown("Upload a chest X-ray or choose a sample. Paste or load lab / MRI / CT reports. "
197
+ "\n\n⚠️ Research demo only. Not for clinical use.")
198
 
199
  with gr.Row():
200
  with gr.Column():
201
+ sample_dropdown = gr.Dropdown(choices=list(SAMPLES.keys()), value="Normal X-ray", label="Select Sample X-ray")
 
 
 
 
202
  img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
203
  imaging_out = gr.Textbox(label="Imaging Agent Output")
204
  imaging_raw = gr.Textbox(label="Probabilities (%)", lines=6)
205
  imaging_heatmap = gr.Image(label="Heatmap Overlay")
206
  with gr.Column():
207
+ text_dropdown = gr.Dropdown(choices=list(SAMPLE_TEXTS.keys()), value="Lab Results", label="Select Sample Report")
 
 
 
 
208
  lab_in = gr.Textbox(lines=6, label="Lab / Report Input")
209
  lab_out = gr.Textbox(label="Lab Agent Output")
210
 
211
  run_btn = gr.Button("Run Agents")
212
+ coord_out = gr.Textbox(label="AI Coordinator Summary", lines=10)
213
 
 
214
  def load_sample(choice):
215
  return SAMPLES.get(choice, None)
216
 
 
223
 
224
  sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
225
  text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
226
+ run_btn.click(run_all, inputs=[img_in, lab_in], outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out])
 
 
 
 
 
 
227
 
228
  demo.launch()
229
 
 
236
 
237
 
238
 
239
+