# AI Diagnostics Agent demo — Hugging Face Space (page status residue removed)
import io
import json
import os
import re
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torchxrayvision as xrv
from dotenv import load_dotenv
from PIL import Image
from skimage.transform import resize as sk_resize
from torchvision import transforms
from transformers import pipeline
# ====================================
# Setup & Config
# ====================================
print("===== Application Startup =====")

# Pull MOONSHOT_API_KEY (and any other secrets) from a local .env file, if present.
load_dotenv()

MOONSHOT_API_URL = "https://api.moonshot.ai/v1/chat/completions"
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY")
# Log only whether the key exists — never the key itself.
print(f"API Key loaded: {bool(MOONSHOT_API_KEY)}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Pretrained multi-dataset chest X-ray classifier (224x224 single-channel input).
MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
MODEL.eval()  # inference mode: freezes dropout / batch-norm statistics
PATHOLOGIES = MODEL.pathologies  # label order matches the model's output vector

# Local fallback LLM used when the Moonshot API is unreachable.
local_summarizer = pipeline("text2text-generation", model="t5-small")
| # ==================================== | |
| # Imaging Agent (Chest X-ray) | |
| # ==================================== | |
def imaging_agent(image_path: str):
    """Run the TorchXRayVision DenseNet on a chest X-ray image.

    Parameters
    ----------
    image_path : str
        Filesystem path to a PNG/JPG chest X-ray; a falsy value short-circuits.

    Returns
    -------
    tuple
        (report text, probability text or None, heatmap PIL image or None).
        On any failure the error is returned as text so the Gradio callback
        never crashes.
    """
    if not image_path:
        return "No image provided.", None, None
    try:
        # Single-channel grayscale, float32 pixel array.
        img = Image.open(image_path).convert("L")
        arr = np.array(img).astype(np.float32)
        # BUG FIX: the original scaled 8-bit pixels down to [0, 1] and then
        # called xrv.datasets.normalize(arr, 4096). normalize() expects raw
        # values in [0, maxval] and rescales them onto [-1024, 1024] (the range
        # the xrv models were trained on), so [0, 1] data with maxval=4096
        # collapsed the input to a nearly constant tensor. Normalize once,
        # using the image's actual value range.
        maxval = 255.0 if arr.max() > 1 else 1.0
        arr = xrv.datasets.normalize(arr, maxval)
        # Center-crop to a square so the 224x224 resize preserves aspect ratio.
        h, w = arr.shape
        min_dim = min(h, w)
        startx = w // 2 - (min_dim // 2)
        starty = h // 2 - (min_dim // 2)
        arr = arr[starty:starty + min_dim, startx:startx + min_dim]
        arr = sk_resize(arr, (224, 224), preserve_range=True)
        # Shape (1, 1, 224, 224); .float() guards against skimage returning float64.
        tensor_img = torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            preds = MODEL(tensor_img)[0]
        # NOTE(review): xrv pretrained models may already emit calibrated
        # probabilities (op_norm); the extra sigmoid here matches the original
        # code but should be confirmed against the torchxrayvision docs.
        probs = torch.sigmoid(preds).cpu().numpy().tolist()
        # Restrict the report to cancer-relevant labels present in this weight set.
        focus_labels = ["Lung Opacity", "Mass", "Nodule"]
        focus = [(l, probs[PATHOLOGIES.index(l)]) for l in focus_labels if l in PATHOLOGIES]
        # Crude saliency map: mean over the final feature maps, upscaled to image size.
        fmap = MODEL.features(tensor_img).detach().cpu().numpy()[0]
        heatmap = np.mean(fmap, axis=0)
        heatmap = sk_resize(heatmap, arr.shape, preserve_range=True)
        plt.figure(figsize=(4, 4))
        plt.imshow(arr, cmap="gray")
        plt.imshow(heatmap, cmap="jet", alpha=0.4)
        plt.axis("off")
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        plt.close()  # release the figure so repeated calls don't leak memory
        buf.seek(0)
        heatmap_img = Image.open(buf)
        # Highest-risk findings first.
        lines = [f"{name}: {p * 100:.1f}%" for name, p in sorted(focus, key=lambda x: x[1], reverse=True)]
        prob_text = "\n".join(lines)
        return (
            "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + "\n".join(lines),
            prob_text,
            heatmap_img
        )
    except Exception as e:
        # Surface the failure in the UI instead of crashing the callback.
        return f"Imaging agent error: {e}", None, None
| # ==================================== | |
| # Lab Agent (tumor markers) | |
| # ==================================== | |
# Tumor-marker reference data: a value above "high" is flagged as elevated.
CANCER_MARKERS = {
    "psa": {"unit": "ng/mL", "high": 4},     # prostate-specific antigen
    "ca125": {"unit": "U/mL", "high": 35},   # ovarian cancer marker
    "afp": {"unit": "ng/mL", "high": 10},    # alpha-fetoprotein
}

# Marker line pattern, compiled once (the original rebuilt it per line).
# Label: letters optionally followed by spaces/hyphens and digits, so common
# spellings like "CA-125" and "CA 125" parse (the original regex missed them).
# Value: a strict decimal — the original's [\d\.]+ matched strings like
# "5.2.3" and then crashed in float().
_MARKER_RE = re.compile(r'([a-z]+[\s\-]*\d*)\s*[:=]?\s*(\d+(?:\.\d+)?)')

def lab_agent(text: str):
    """Parse free-text lab results for known tumor markers.

    Parameters
    ----------
    text : str
        Free-form lab report text, one marker per line (e.g. "PSA: 5.2").

    Returns
    -------
    str
        A formatted report listing each recognized marker, its status, and
        any elevation flags; or an explanatory message when nothing parses.
    """
    if not text.strip():
        return "No lab text provided."
    results, flags = [], []
    for line in text.splitlines():
        for raw_label, val in _MARKER_RE.findall(line.lower()):
            # Normalize "ca-125" / "ca 125" -> "ca125" before the lookup.
            label = re.sub(r'[\s\-]', '', raw_label)
            if label not in CANCER_MARKERS:
                continue
            v = float(val)
            thr = CANCER_MARKERS[label]
            status = "ok"
            if v > thr["high"]:
                status = "elevated"
                flags.append(f"{label.upper()} high")
            results.append(f"{label.upper()}: {v} {thr['unit']} → {status}")
    if not results:
        return "Could not parse tumor markers."
    return "🧪 Lab Agent (Tumor Markers)\n" + "\n".join(results) + (
        "\nFlags: " + ", ".join(flags) if flags else "\nFlags: none"
    )
| # ==================================== | |
| # Moonshot + Local AI Coordinator | |
| # ==================================== | |
def moonshot_summary(prompt: str):
    """Ask the Moonshot chat API for a summary, falling back to local T5.

    Parameters
    ----------
    prompt : str
        The user prompt forwarded to the LLM.

    Returns
    -------
    str
        The model's reply, a rate-limit warning, or a locally generated
        fallback summary when the remote call fails for any reason.
    """
    request_headers = {
        "Authorization": f"Bearer {MOONSHOT_API_KEY}",
        "Content-Type": "application/json",
    }
    request_body = {
        "model": "moonshot-v1",
        "messages": [
            {"role": "system", "content": "You are a clinical research AI generating concise diagnostic summaries."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.4,
    }
    try:
        time.sleep(2)  # respect rate limits
        resp = requests.post(MOONSHOT_API_URL, headers=request_headers, json=request_body, timeout=30)
        if resp.status_code == 429:
            return "⚠️ Moonshot API rate limit reached. Please wait a few seconds and retry."
        resp.raise_for_status()
        reply = resp.json()["choices"][0]["message"]["content"]
        return reply.strip()
    except Exception as exc:
        # Any failure (network, HTTP error, malformed JSON) degrades to the
        # local summarizer rather than surfacing an exception to the caller.
        print(f"Moonshot API error: {exc}")
        fallback_text = local_summarizer(prompt, max_length=100)[0]["generated_text"]
        return f"🤖 Local AI Summary (Fallback)\n{fallback_text}"
def coordinator(imaging_txt, lab_txt):
    """Fuse both agents' outputs into one LLM-generated summary with disclaimer."""
    prompt = (
        "Summarize the following findings for early cancer risk assessment:"
        f"\n\n{imaging_txt}\n\n{lab_txt}"
    )
    llm_summary = moonshot_summary(prompt)
    return (
        "📋 AI Coordinator Summary (LLM-generated)\n\n"
        f"{llm_summary}\n\n"
        "⚠️ Disclaimer: Research demo only. Not for clinical use."
    )
| # ==================================== | |
| # Samples | |
| # ==================================== | |
# ====================================
# Samples
# ====================================
# Bundled demo inputs; paths are relative to the Space's working directory.
SAMPLES = {
    "Normal X-ray": "samples/sample_xray1.png",
    "Suspicious X-ray": "samples/sample_xray2.png",
}
SAMPLE_TEXTS = {
    "Lab Results": "samples/sample_labs.txt",
    "MRI Report": "samples/sample_mri.txt",
    "CT Report": "samples/sample_ct.txt",
}
| # ==================================== | |
| # Runner | |
| # ==================================== | |
def run_all(image, labs):
    """Execute both agents, then the coordinator; returns the five UI outputs."""
    if image:
        imaging_text, prob_text, heatmap = imaging_agent(image)
    else:
        imaging_text, prob_text, heatmap = "No image.", None, None
    lab_text = lab_agent(labs)
    summary = coordinator(imaging_text, lab_text)
    return imaging_text, prob_text, heatmap, lab_text, summary
| # ==================================== | |
| # Gradio UI | |
| # ==================================== | |
# ====================================
# Gradio UI
# ====================================
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
    gr.Markdown("Upload a chest X-ray or choose a demo sample. Paste tumor marker labs or MRI/CT reports.\n\n⚠️ Research demo only. Not for clinical use.")
    with gr.Row():
        with gr.Column():
            # Imaging side: sample picker, upload widget, and three outputs.
            sample_dropdown = gr.Dropdown(choices=list(SAMPLES.keys()), value="Normal X-ray", label="Select Sample X-ray")
            img_in = gr.Image(type="filepath", label="Chest X-ray (PNG/JPG)")
            imaging_out = gr.Textbox(label="Imaging Agent Output")
            imaging_raw = gr.Textbox(label="Probabilities (%)", lines=6)
            imaging_heatmap = gr.Image(label="Heatmap Overlay")
        with gr.Column():
            # Lab side: sample report picker and free-text input/output.
            text_dropdown = gr.Dropdown(choices=list(SAMPLE_TEXTS.keys()), value="Lab Results", label="Select Sample Report")
            lab_in = gr.Textbox(lines=6, label="Lab / Report Input")
            lab_out = gr.Textbox(label="Lab Agent Output")
    run_btn = gr.Button("Run Agents")
    coord_out = gr.Textbox(label="Coordinator Summary", lines=10)

    def load_sample(choice):
        # Map the dropdown choice to a bundled image path (None clears the widget).
        return SAMPLES.get(choice, None)

    def load_text(choice):
        # Load the selected sample report into the lab textbox.
        path = SAMPLE_TEXTS.get(choice, None)
        if path and path.endswith(".txt"):
            try:
                # BUG FIX: explicit encoding plus graceful handling of a
                # missing sample file — the original crashed the callback with
                # FileNotFoundError and used the locale-dependent default codec.
                with open(path, "r", encoding="utf-8") as f:
                    return f.read()
            except OSError as e:
                return f"(sample file unavailable: {e})"
        return ""

    # Wire dropdowns to their loaders and the button to the full pipeline.
    sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
    text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
    run_btn.click(run_all, inputs=[img_in, lab_in], outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out])

# Top-level launch is the Hugging Face Spaces convention.
demo.launch()