# aidiagnostics / app.py
# Last change: "Update app.py" by Benny-Tang (commit 52f2bd6, verified)
import os
import time
import json
import io
import re
import requests
import numpy as np
import gradio as gr
import torch
import torchxrayvision as xrv
from torchvision import transforms
from PIL import Image
from skimage.transform import resize as sk_resize
import matplotlib.pyplot as plt
from transformers import pipeline
from dotenv import load_dotenv
# ====================================
# Setup & Config
# ====================================
print("===== Application Startup =====")
# Load variables from a local .env file (if present) so the os.getenv call below can see them.
load_dotenv()
# Chat-completions endpoint and credential used by moonshot_summary().
MOONSHOT_API_URL = "https://api.moonshot.ai/v1/chat/completions"
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY")
# Log only whether a key was found, never the key itself.
print(f"API Key loaded: {bool(MOONSHOT_API_KEY)}")
# Run the X-ray model on the GPU when one is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Pretrained torchxrayvision DenseNet ("densenet121-res224-all"), built once at import time
# and switched to inference mode.
MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
MODEL.eval()
# Label names, index-aligned with the model's per-image prediction vector.
PATHOLOGIES = MODEL.pathologies
# Local fallback model (t5-small), used by moonshot_summary() when the Moonshot API call fails.
local_summarizer = pipeline("text2text-generation", model="t5-small")
# ====================================
# Imaging Agent (Chest X-ray)
# ====================================
def _load_xray(image_path: str) -> np.ndarray:
    """Load *image_path* as a 224x224 float32 array in torchxrayvision's input range.

    The image is converted to 8-bit grayscale, rescaled with
    ``xrv.datasets.normalize(arr, 255)``, centre-cropped to a square and resized.
    """
    with Image.open(image_path) as img:
        # "L" mode is 8-bit grayscale, so pixel values lie in [0, 255].
        arr = np.asarray(img.convert("L"), dtype=np.float32)
    # normalize() expects raw intensities in [0, maxval]. Dividing by 255 first and
    # then passing maxval=4096 (the previous behaviour) mapped every pixel into
    # roughly [-1024, -1023.5], i.e. an almost blank image for the model.
    arr = xrv.datasets.normalize(arr, 255)
    h, w = arr.shape
    side = min(h, w)
    top = h // 2 - side // 2
    left = w // 2 - side // 2
    arr = arr[top:top + side, left:left + side]
    # Depending on the scikit-image version, resize may return float64; the model
    # weights are float32, so cast explicitly.
    return sk_resize(arr, (224, 224), preserve_range=True).astype(np.float32)


def _render_overlay(base: np.ndarray, heatmap: np.ndarray):
    """Draw *heatmap* (jet colormap, 40% alpha) over grayscale *base*; return a PIL image."""
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(base, cmap="gray")
        ax.imshow(heatmap, cmap="jet", alpha=0.4)
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def imaging_agent(image_path: str):
    """Score a chest X-ray for cancer-related findings and render a feature heatmap.

    Args:
        image_path: Path to the uploaded image (the UI uses ``gr.Image(type="filepath")``).

    Returns:
        ``(report_text, prob_text, heatmap_img)``. ``prob_text`` lists the focus
        findings ("Lung Opacity", "Mass", "Nodule") as percentages, highest first;
        ``heatmap_img`` is a PIL image. When no path is given or any step fails,
        returns ``(message, None, None)`` instead of raising.
    """
    if not image_path:
        return "No image provided.", None, None
    try:
        arr = _load_xray(image_path)
        tensor_img = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)
        # Both passes are inference-only; no autograd graph is needed.
        with torch.no_grad():
            preds = MODEL(tensor_img)[0]
            fmap = MODEL.features(tensor_img)[0].cpu().numpy()
        # torchxrayvision's DenseNet.forward already applies sigmoid (+ op_norm)
        # when the pretrained weights provide op_threshs, or when apply_sigmoid is
        # set. A second sigmoid squashed every score into ~[0.5, 0.73], so only
        # apply it when the model returned raw logits.
        outputs_are_probs = (
            getattr(MODEL, "op_threshs", None) is not None
            or getattr(MODEL, "apply_sigmoid", False)
        )
        if not outputs_are_probs:
            preds = torch.sigmoid(preds)
        probs = preds.cpu().numpy().tolist()

        focus_labels = ["Lung Opacity", "Mass", "Nodule"]
        focus = [
            (label, probs[PATHOLOGIES.index(label)])
            for label in focus_labels
            if label in PATHOLOGIES
        ]

        # Crude saliency map: channel-mean of the MODEL.features() activations,
        # upsampled to the 224x224 input.
        heatmap = sk_resize(fmap.mean(axis=0), arr.shape, preserve_range=True)
        heatmap_img = _render_overlay(arr, heatmap)

        lines = [
            f"{name}: {p * 100:.1f}%"
            for name, p in sorted(focus, key=lambda item: item[1], reverse=True)
        ]
        prob_text = "\n".join(lines)
        return (
            "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + prob_text,
            prob_text,
            heatmap_img,
        )
    except Exception as e:
        # UI boundary: surface the failure in the output box rather than crash the app.
        return f"Imaging agent error: {e}", None, None
# ====================================
# Lab Agent (tumor markers)
# ====================================
# Tumor markers recognised by lab_agent: display unit, and the upper limit above
# which a parsed value is reported as "elevated".
CANCER_MARKERS = dict(
    psa=dict(unit="ng/mL", high=4),
    ca125=dict(unit="U/mL", high=35),
    afp=dict(unit="ng/mL", high=10),
)
# Finds "<name> [:|=] <number>" pairs in a lower-cased line. Names start with a
# letter and may contain hyphens (e.g. "ca-125"); values are plain decimals, so
# stray dots such as "..." are never handed to float().
_LAB_VALUE_RE = re.compile(r"([a-z][a-z0-9-]*)\s*[:=]?\s*(\d+(?:\.\d+)?)")


def lab_agent(text: str, markers=None):
    """Parse tumor-marker values from free text and flag those above threshold.

    Args:
        text: Lab report text. Every line is scanned case-insensitively for
            ``name value`` pairs with an optional ``:``/``=`` separator; hyphens
            in names are ignored, so "CA-125 = 40" matches ``ca125``.
        markers: Optional mapping ``name -> {"unit": str, "high": number}``.
            Defaults to ``CANCER_MARKERS``.

    Returns:
        A report with one ``"NAME: value unit (ok|elevated)"`` line per recognised
        marker followed by a ``Flags:`` line, or a short message when the text is
        empty or contains no recognised markers.
    """
    if markers is None:
        markers = CANCER_MARKERS
    if not text or not text.strip():
        return "No lab text provided."
    results, flags = [], []
    for line in text.lower().splitlines():
        for raw_label, raw_value in _LAB_VALUE_RE.findall(line):
            label = raw_label.replace("-", "")
            spec = markers.get(label)
            if spec is None:
                continue
            value = float(raw_value)
            status = "elevated" if value > spec["high"] else "ok"
            if status == "elevated":
                flags.append(f"{label.upper()} high")
            # Separator restored: the old format glued unit and status together
            # ("ng/mLelevated").
            results.append(f"{label.upper()}: {value} {spec['unit']} ({status})")
    if not results:
        return "Could not parse tumor markers."
    flag_text = ", ".join(flags) if flags else "none"
    return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + f"\nFlags: {flag_text}"
# ====================================
# Moonshot + Local AI Coordinator
# ====================================
def moonshot_summary(prompt: str):
    """Summarize *prompt* with the Moonshot chat API, falling back to a local model.

    Returns:
        The assistant's reply on success; a rate-limit notice on HTTP 429;
        otherwise (no API key configured, network/HTTP error, or malformed
        response) the local t5-small summary under a "Fallback" heading.
    """
    if not MOONSHOT_API_KEY:
        # No credential: skip the request (and the 2 s delay) that would only
        # fail with "Bearer None".
        print("Moonshot API error: MOONSHOT_API_KEY is not set; using local fallback.")
        return _local_fallback_summary(prompt)

    headers = {
        "Authorization": f"Bearer {MOONSHOT_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        # NOTE(review): "moonshot-v1" (the old value) does not appear to be a published
        # model ID; the v1 family is listed with a context-size suffix. Confirm
        # against Moonshot's model list; override with MOONSHOT_MODEL.
        "model": os.getenv("MOONSHOT_MODEL", "moonshot-v1-8k"),
        "messages": [
            {"role": "system", "content": "You are a clinical research AI generating concise diagnostic summaries."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.4
    }
    try:
        time.sleep(2)  # crude client-side throttle before every request
        response = requests.post(MOONSHOT_API_URL, headers=headers, json=payload, timeout=30)
        if response.status_code == 429:
            return "⚠️ Moonshot API rate limit reached. Please wait a few seconds and retry."
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"].strip()
    except (requests.RequestException, ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
        # Network/HTTP failures, non-JSON bodies, or an unexpected response shape.
        print(f"Moonshot API error: {e}")
    return _local_fallback_summary(prompt)


def _local_fallback_summary(prompt: str) -> str:
    """Summarize *prompt* with the local t5-small pipeline (best effort, never raises)."""
    try:
        # T5 selects its task from a text prefix; without "summarize: " the generic
        # text2text pipeline does not summarize. truncation=True clips prompts longer
        # than the tokenizer's max length.
        generated = local_summarizer(
            "summarize: " + prompt, max_length=100, truncation=True
        )[0]["generated_text"]
    except Exception as e:
        # Last-resort path: report instead of crashing the request handler.
        print(f"Local summarizer error: {e}")
        generated = "Local summarizer unavailable."
    return f"🤖 Local AI Summary (Fallback)\n{generated}"
def coordinator(imaging_txt, lab_txt):
    """Ask the LLM to combine the imaging and lab agent outputs into one summary.

    The reply from moonshot_summary() is wrapped in a heading and followed by a
    research-only disclaimer.
    """
    prompt = (
        "Summarize the following findings for early cancer risk assessment:\n\n"
        f"{imaging_txt}\n\n{lab_txt}"
    )
    return (
        "📋 AI Coordinator Summary (LLM-generated)\n\n"
        f"{moonshot_summary(prompt)}\n\n"
        "⚠️ Disclaimer: Research demo only. Not for clinical use."
    )
# ====================================
# Samples
# ====================================
# Demo inputs offered by the UI dropdowns; paths are resolved relative to the
# current working directory.
# Sample chest X-rays: dropdown label -> image file.
SAMPLES = dict(
    [
        ("Normal X-ray", "samples/sample_xray1.png"),
        ("Suspicious X-ray", "samples/sample_xray2.png"),
    ]
)
# Sample reports: dropdown label -> text file loaded into the lab/report box.
SAMPLE_TEXTS = dict(
    [
        ("Lab Results", "samples/sample_labs.txt"),
        ("MRI Report", "samples/sample_mri.txt"),
        ("CT Report", "samples/sample_ct.txt"),
    ]
)
# ====================================
# Runner
# ====================================
def run_all(image, labs):
    """Run every agent on the UI inputs and return the five Gradio outputs.

    Returns:
        ``(imaging_text, prob_text, heatmap, lab_text, coordinator_summary)``,
        in the order wired to the output components.
    """
    if image:
        imaging_txt, prob_text, heatmap = imaging_agent(image)
    else:
        imaging_txt, prob_text, heatmap = "No image.", None, None
    lab_txt = lab_agent(labs)
    return imaging_txt, prob_text, heatmap, lab_txt, coordinator(imaging_txt, lab_txt)
# ====================================
# Gradio UI
# ====================================
def load_sample(choice):
    """Return the file path of the selected sample X-ray, or None if unavailable."""
    path = SAMPLES.get(choice)
    # Handing gr.Image a path that does not exist fails; clear the input instead.
    return path if path and os.path.exists(path) else None


def load_text(choice):
    """Return the text of the selected sample report, or "" if it cannot be read."""
    path = SAMPLE_TEXTS.get(choice)
    if not path or not path.endswith(".txt"):
        return ""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        print(f"Could not read sample report {path}: {e}")
        return ""


with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
    gr.Markdown("Upload a chest X-ray or choose a demo sample. Paste tumor marker labs or MRI/CT reports.\n\n⚠️ Research demo only. Not for clinical use.")
    with gr.Row():
        with gr.Column():
            sample_dropdown = gr.Dropdown(choices=list(SAMPLES.keys()), value="Normal X-ray", label="Select Sample X-ray")
            img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
            imaging_out = gr.Textbox(label="Imaging Agent Output")
            imaging_raw = gr.Textbox(label="Probabilities (%)", lines=6)
            imaging_heatmap = gr.Image(label="Heatmap Overlay")
        with gr.Column():
            text_dropdown = gr.Dropdown(choices=list(SAMPLE_TEXTS.keys()), value="Lab Results", label="Select Sample Report")
            lab_in = gr.Textbox(lines=6, label="Lab / Report Input")
            lab_out = gr.Textbox(label="Lab Agent Output")
    run_btn = gr.Button("Run Agents")
    coord_out = gr.Textbox(label="Coordinator Summary", lines=10)

    sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
    text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
    # .change() fires only when the user picks a different option, so populate the
    # inputs once on page load to match the dropdowns' initial selections.
    demo.load(load_sample, inputs=sample_dropdown, outputs=img_in)
    demo.load(load_text, inputs=text_dropdown, outputs=lab_in)
    run_btn.click(run_all, inputs=[img_in, lab_in], outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out])

if __name__ == "__main__":
    demo.launch()