Spaces:
Sleeping
Sleeping
File size: 8,042 Bytes
750fc9e 52f2bd6 750fc9e 8fc30da 76aec82 8fc30da ef2f177 52f2bd6 32f6cb5 ef2f177 52f2bd6 750fc9e 52f2bd6 750fc9e 52f2bd6 750fc9e 52f2bd6 5712b28 707146a 8fc30da 52f2bd6 ef2f177 52f2bd6 8fc30da ef2f177 8fc30da 8c5a430 8fc30da 8c5a430 8fc30da 8c5a430 4032829 32f6cb5 750fc9e ef2f177 69dcdac ef2f177 8fc30da ef2f177 8c5a430 8fc30da 707146a ef2f177 750fc9e f887276 52f2bd6 ef2f177 f887276 ef2f177 8fc30da ef2f177 8fc30da 52f2bd6 ef2f177 52f2bd6 707146a 8fc30da 8c5a430 8fc30da 52f2bd6 8c5a430 707146a 8c5a430 707146a 8c5a430 707146a 8c5a430 707146a ef2f177 707146a 8c5a430 707146a ef2f177 8fc30da 52f2bd6 750fc9e 52f2bd6 750fc9e 52f2bd6 750fc9e 52f2bd6 750fc9e 52f2bd6 750fc9e 707146a 52f2bd6 8fc30da 52f2bd6 4032829 ef2f177 37a088c 52f2bd6 4032829 52f2bd6 59a1955 f887276 59a1955 f887276 59a1955 52f2bd6 4032829 52f2bd6 37a088c 52f2bd6 37a088c 750fc9e 4032829 37a088c f887276 ef2f177 37a088c 750fc9e ef2f177 37a088c 52f2bd6 37a088c 4032829 ef2f177 4032829 ef2f177 750fc9e 37a088c 76aec82 8fc30da 8c5a430 707146a 59a1955 4032829 a2e2fdb 69dcdac 32f6cb5 ef2f177 f887276 750fc9e 52f2bd6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 | import os
import time
import json
import io
import re
import requests
import numpy as np
import gradio as gr
import torch
import torchxrayvision as xrv
from torchvision import transforms
from PIL import Image
from skimage.transform import resize as sk_resize
import matplotlib.pyplot as plt
from transformers import pipeline
from dotenv import load_dotenv
# ====================================
# Setup & Config
# ====================================
print("===== Application Startup =====")
# Load variables from a local .env file (if present) into os.environ.
load_dotenv()
# Chat-completions endpoint called by moonshot_summary().
MOONSHOT_API_URL = "https://api.moonshot.ai/v1/chat/completions"
# None when the variable is unset; only whether it is present is logged below.
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY")
print(f"API Key loaded: {bool(MOONSHOT_API_KEY)}")
# Run inference on the GPU when PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Pretrained torchxrayvision DenseNet-121, loaded once at import time.
# imaging_agent() resizes inputs to 224x224 to match the "res224" weights.
MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
MODEL.eval()  # inference mode: no dropout, batch-norm uses running stats
# Output label names; imaging_agent() maps label -> output index through this.
PATHOLOGIES = MODEL.pathologies
# Local fallback LLM, used by moonshot_summary() when the API call fails.
local_summarizer = pipeline("text2text-generation", model="t5-small")
# ====================================
# Imaging Agent (Chest X-ray)
# ====================================
# Pathologies reported by imaging_agent (those present in PATHOLOGIES).
_FOCUS_LABELS = ("Lung Opacity", "Mass", "Nodule")


def _center_crop_square(arr):
    """Return the centered square crop of a 2-D array (side = shorter edge)."""
    h, w = arr.shape[:2]
    side = min(h, w)
    top = h // 2 - side // 2
    left = w // 2 - side // 2
    return arr[top:top + side, left:left + side]


def _load_xray(image_path):
    """Load an image file as a model-ready (224, 224) float32 array plus a
    1x1x224x224 tensor on DEVICE."""
    with Image.open(image_path) as src:
        gray = np.asarray(src.convert("L"), dtype=np.float32)
    # Mode "L" pixels are 8-bit, so give xrv.datasets.normalize the true
    # maximum (255). It rescales [0, maxval] to the [-1024, 1024] range the
    # torchxrayvision models expect. The previous code pre-divided by 255 and
    # then passed 4096, which collapsed every image to roughly -1024.
    arr = xrv.datasets.normalize(gray, 255)
    arr = _center_crop_square(arr)
    # Force float32 so the tensor dtype matches the model weights no matter
    # which float type sk_resize returns.
    arr = sk_resize(arr, (224, 224), preserve_range=True).astype(np.float32)
    tensor = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
    return arr, tensor


def _predict(tensor_img):
    """Run the classifier and extract the conv feature map in one no-grad pass.

    Returns (probs, fmap): per-pathology scores as a list of floats aligned
    with PATHOLOGIES, and the feature map of the single input image.
    """
    with torch.no_grad():
        out = MODEL(tensor_img)[0]
        fmap = MODEL.features(tensor_img)[0]
    # torchxrayvision's DenseNet.forward already applies sigmoid (plus op_norm
    # calibration) when the weights ship operating thresholds, or when
    # apply_sigmoid is set. Applying sigmoid again would squash every score
    # into roughly [0.5, 0.73].
    already_probs = (
        getattr(MODEL, "apply_sigmoid", False)
        or getattr(MODEL, "op_threshs", None) is not None
    )
    if not already_probs:
        out = torch.sigmoid(out)
    return out.cpu().numpy().tolist(), fmap.cpu().numpy()


def _render_heatmap(arr, fmap):
    """Overlay the channel-mean of ``fmap`` (resized to arr's shape) on ``arr``
    and return the rendered PNG as a PIL image."""
    heatmap = sk_resize(fmap.mean(axis=0), arr.shape, preserve_range=True)
    # Use an explicit figure rather than pyplot's implicit "current" figure,
    # and always close it, so a failure mid-render cannot leak figures
    # between requests.
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(arr, cmap="gray")
        ax.imshow(heatmap, cmap="jet", alpha=0.4)
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def imaging_agent(image_path: str):
    """Score a chest X-ray for the focus pathologies and build a heatmap overlay.

    Args:
        image_path: path to an image file; falsy values are treated as "no image".

    Returns:
        A 3-tuple (report text, probability lines, heatmap PIL image). On
        missing input or any processing error, the tuple is
        (message, None, None).
    """
    if not image_path:
        return "No image provided.", None, None
    try:
        arr, tensor_img = _load_xray(image_path)
        probs, fmap = _predict(tensor_img)
        focus = [
            (label, probs[PATHOLOGIES.index(label)])
            for label in _FOCUS_LABELS
            if label in PATHOLOGIES
        ]
        focus.sort(key=lambda item: item[1], reverse=True)
        prob_text = "\n".join(f"{name}: {p * 100:.1f}%" for name, p in focus)
        heatmap_img = _render_heatmap(arr, fmap)
        return (
            "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + prob_text,
            prob_text,
            heatmap_img,
        )
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        return f"Imaging agent error: {e}", None, None
# ====================================
# Lab Agent (tumor markers)
# ====================================
CANCER_MARKERS = {
    "psa": {"unit": "ng/mL", "high": 4},
    "ca125": {"unit": "U/mL", "high": 35},
    "afp": {"unit": "ng/mL", "high": 10},
}

# Matches "<marker> [:|=] <number>" in lower-cased text. Keep the alternation
# in sync with CANCER_MARKERS. "ca-125" / "ca 125" are accepted and normalized
# to "ca125"; (?!\d) stops "ca 1250" from being read as CA-125. The value must
# be a well-formed decimal, so trailing punctuation such as "PSA: 4.5." can no
# longer reach float() and raise ValueError.
_MARKER_PATTERN = re.compile(
    r"\b(psa|afp|ca[\s-]?125(?!\d))\s*[:=]?\s*(\d+(?:\.\d+)?)"
)


def lab_agent(text: str):
    """Extract known tumor markers from free text and flag values above threshold.

    Args:
        text: lab report text, with one or more "MARKER: value" pairs per line.
            None or blank text is reported as missing input.

    Returns:
        A human-readable report: one line per marker found, plus a Flags line.
        Returns a fixed message when the input is empty or no marker is
        recognized.
    """
    if not text or not text.strip():
        return "No lab text provided."
    results, flags = [], []
    for line in text.splitlines():
        for raw_label, raw_value in _MARKER_PATTERN.findall(line.lower()):
            label = re.sub(r"[\s-]", "", raw_label)  # "ca-125" -> "ca125"
            marker = CANCER_MARKERS.get(label)
            if marker is None:
                continue
            value = float(raw_value)
            status = "ok"
            if value > marker["high"]:
                status = "elevated"
                flags.append(f"{label.upper()} high")
            results.append(f"{label.upper()}: {value} {marker['unit']} → {status}")
    if not results:
        return "Could not parse tumor markers."
    return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + (
        "\nFlags: " + ", ".join(flags) if flags else "\nFlags: none"
    )
# ====================================
# Moonshot + Local AI Coordinator
# ====================================
def moonshot_summary(prompt: str, model=None):
    """Summarize ``prompt`` with the Moonshot chat API, falling back to a local model.

    Args:
        prompt: user message sent to the LLM.
        model: Moonshot model id. Defaults to the MOONSHOT_MODEL environment
            variable, or "moonshot-v1-8k" when that is unset.

    Returns:
        The API completion text. On HTTP 429 it returns a rate-limit notice.
        When the key is missing or the call fails for any other reason, it
        returns the local fallback summary. Never raises.
    """
    model = model or os.getenv("MOONSHOT_MODEL", "moonshot-v1-8k")
    if not MOONSHOT_API_KEY:
        # Skip a request that can only fail with "Bearer None".
        print("Moonshot API key not set; using local fallback.")
        return _local_summary(prompt)
    headers = {
        "Authorization": f"Bearer {MOONSHOT_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a clinical research AI generating concise diagnostic summaries."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.4
    }
    try:
        time.sleep(2)  # respect rate limits
        response = requests.post(MOONSHOT_API_URL, headers=headers, json=payload, timeout=30)
        if response.status_code == 429:
            return "⚠️ Moonshot API rate limit reached. Please wait a few seconds and retry."
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"].strip()
    except Exception as e:
        # Network, HTTP, and response-shape errors all degrade to the local model.
        print(f"Moonshot API error: {e}")
    return _local_summary(prompt)


def _local_summary(prompt):
    """Summarize with the local t5-small pipeline. Never raises."""
    try:
        # T5 checkpoints are multi-task and pick the task from a text prefix.
        # Without "summarize: ", t5-small does not reliably summarize.
        local_output = local_summarizer("summarize: " + prompt, max_length=100)[0]["generated_text"]
    except Exception as e:
        print(f"Local summarizer error: {e}")
        return "🤖 Local AI Summary (Fallback)\nLocal summarizer unavailable."
    return f"🤖 Local AI Summary (Fallback)\n{local_output}"


def coordinator(imaging_txt, lab_txt):
    """Ask the LLM to combine imaging and lab findings; return the summary plus a disclaimer."""
    base_prompt = f"Summarize the following findings for early cancer risk assessment:\n\n{imaging_txt}\n\n{lab_txt}"
    llm_summary = moonshot_summary(base_prompt)
    return f"📋 AI Coordinator Summary (LLM-generated)\n\n{llm_summary}\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
# ====================================
# Samples
# ====================================
# Bundled demo X-ray images, keyed by the label shown in the sample dropdown.
# Paths are relative to the working directory the app is started from.
SAMPLES = dict(
    [
        ("Normal X-ray", "samples/sample_xray1.png"),
        ("Suspicious X-ray", "samples/sample_xray2.png"),
    ]
)
# Bundled demo report texts, keyed by the label shown in the report dropdown.
SAMPLE_TEXTS = dict(
    [
        ("Lab Results", "samples/sample_labs.txt"),
        ("MRI Report", "samples/sample_mri.txt"),
        ("CT Report", "samples/sample_ct.txt"),
    ]
)
# ====================================
# Runner
# ====================================
def run_all(image, labs):
    """Run the imaging and lab agents, then the coordinator.

    Returns the five UI outputs in order: imaging text, probability text,
    heatmap image, lab report, and coordinator summary.
    """
    if image:
        imaging_txt, probabilities, overlay = imaging_agent(image)
    else:
        imaging_txt, probabilities, overlay = "No image.", None, None
    lab_report = lab_agent(labs)
    summary = coordinator(imaging_txt, lab_report)
    return imaging_txt, probabilities, overlay, lab_report, summary
# ====================================
# Gradio UI
# ====================================
def load_sample(choice):
    """Return the file path of the demo X-ray labelled ``choice``, or None if unknown."""
    return SAMPLES.get(choice, None)


def load_text(choice):
    """Return the text of the demo report labelled ``choice``.

    Returns "" for unknown labels, non-.txt paths, or files that cannot be
    read, so the UI shows an empty box instead of an error.
    """
    path = SAMPLE_TEXTS.get(choice, None)
    if not path or not path.endswith(".txt"):
        return ""
    try:
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Could not read sample report {path}: {e}")
        return ""


with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
    gr.Markdown("Upload a chest X-ray or choose a demo sample. Paste tumor marker labs or MRI/CT reports.\n\n⚠️ Research demo only. Not for clinical use.")
    with gr.Row():
        with gr.Column():
            sample_dropdown = gr.Dropdown(choices=list(SAMPLES.keys()), value="Normal X-ray", label="Select Sample X-ray")
            img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
            imaging_out = gr.Textbox(label="Imaging Agent Output")
            imaging_raw = gr.Textbox(label="Probabilities (%)", lines=6)
            imaging_heatmap = gr.Image(label="Heatmap Overlay")
        with gr.Column():
            text_dropdown = gr.Dropdown(choices=list(SAMPLE_TEXTS.keys()), value="Lab Results", label="Select Sample Report")
            lab_in = gr.Textbox(lines=6, label="Lab / Report Input")
            lab_out = gr.Textbox(label="Lab Agent Output")
    run_btn = gr.Button("Run Agents")
    coord_out = gr.Textbox(label="Coordinator Summary", lines=10)

    sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
    text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
    # The dropdowns start with a preselected value, but .change only fires on
    # user edits, so the matching inputs would stay empty. Populate them once
    # on page load so the UI reflects the initial selection.
    demo.load(load_sample, inputs=sample_dropdown, outputs=img_in)
    demo.load(load_text, inputs=text_dropdown, outputs=lab_in)
    run_btn.click(run_all, inputs=[img_in, lab_in], outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out])

demo.launch()
|