import os
import time
import json
import io
import re

import requests
import numpy as np
import gradio as gr
import torch
import torchxrayvision as xrv
from torchvision import transforms
from PIL import Image
from skimage.transform import resize as sk_resize
import matplotlib.pyplot as plt
from transformers import pipeline
from dotenv import load_dotenv

# ====================================
# Setup & Config
# ====================================
print("===== Application Startup =====")
load_dotenv()

MOONSHOT_API_URL = "https://api.moonshot.ai/v1/chat/completions"
MOONSHOT_API_KEY = os.getenv("MOONSHOT_API_KEY")
# The chat-completions endpoint takes a context-size-specific model id
# (e.g. "moonshot-v1-8k"); override with the MOONSHOT_MODEL env var.
MOONSHOT_MODEL = os.getenv("MOONSHOT_MODEL", "moonshot-v1-8k")
print(f"API Key loaded: {bool(MOONSHOT_API_KEY)}")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE)
MODEL.eval()
PATHOLOGIES = MODEL.pathologies

# Local fallback LLM
local_summarizer = pipeline("text2text-generation", model="t5-small")


# ====================================
# Imaging Agent (Chest X-ray)
# ====================================
FOCUS_LABELS = ["Lung Opacity", "Mass", "Nodule"]


def _preprocess_xray(image_path: str) -> np.ndarray:
    """Load an image file as a 224x224 float32 array scaled for TorchXRayVision."""
    # convert("L") always yields 8-bit grayscale, so the source max value is 255.
    img = Image.open(image_path).convert("L")
    arr = np.asarray(img, dtype=np.float32)
    # xrv models expect pixels in roughly [-1024, 1024]; normalize() performs
    # that mapping given the source bit depth (255 for 8-bit input).
    arr = xrv.datasets.normalize(arr, 255)

    # Center-crop to the largest square, then resize to the model's input size.
    h, w = arr.shape
    side = min(h, w)
    top = h // 2 - side // 2
    left = w // 2 - side // 2
    arr = arr[top:top + side, left:left + side]
    # Cast explicitly: the model weights are float32 and resize may upcast.
    return sk_resize(arr, (224, 224), preserve_range=True).astype(np.float32)


def _outputs_are_probabilities(model) -> bool:
    """Return True when ``model(x)`` already produces scores in [0, 1]."""
    # xrv's DenseNet.forward applies sigmoid itself when `apply_sigmoid` is set
    # or when operating-point thresholds (`op_threshs`) are attached; applying
    # sigmoid a second time would compress every score towards 0.5-0.73.
    return (
        getattr(model, "op_threshs", None) is not None
        or bool(getattr(model, "apply_sigmoid", False))
    )


def _render_heatmap(arr: np.ndarray, fmap: np.ndarray) -> Image.Image:
    """Overlay the channel-mean feature activation on the X-ray; return a PNG image."""
    heatmap = sk_resize(fmap.mean(axis=0), arr.shape, preserve_range=True)
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(arr, cmap="gray")
        ax.imshow(heatmap, cmap="jet", alpha=0.4)
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Release the figure even if drawing fails, so figures don't accumulate.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now rather than lazily from the buffer
    return img


def imaging_agent(image_path: str):
    """Score a chest X-ray for the focus findings and build a heatmap overlay.

    Args:
        image_path: Path to a PNG/JPG image; falsy means "no image".

    Returns:
        ``(report_text, probability_text, heatmap_image)``. When no image is
        given or processing fails, the first element describes the problem
        and the other two are ``None``.
    """
    if not image_path:
        return "No image provided.", None, None
    try:
        arr = _preprocess_xray(image_path)
        tensor_img = torch.from_numpy(arr)[None, None].to(DEVICE)  # (1, 1, 224, 224)
        with torch.no_grad():
            outputs = MODEL(tensor_img)[0]
            fmap = MODEL.features(tensor_img)[0].cpu().numpy()
        if not _outputs_are_probabilities(MODEL):
            outputs = torch.sigmoid(outputs)
        probs = outputs.cpu().numpy().tolist()

        focus = [
            (label, probs[PATHOLOGIES.index(label)])
            for label in FOCUS_LABELS
            if label in PATHOLOGIES
        ]
        heatmap_img = _render_heatmap(arr, fmap)

        ranked = sorted(focus, key=lambda item: item[1], reverse=True)
        prob_text = "\n".join(f"{name}: {p * 100:.1f}%" for name, p in ranked)
        return (
            "🖼️ Imaging Agent (Chest X-ray for cancer risk)\n" + prob_text,
            prob_text,
            heatmap_img,
        )
    except Exception as e:  # UI boundary: report the error instead of crashing
        return f"Imaging agent error: {e}", None, None


# ====================================
# Lab Agent (tumor markers)
# ====================================
CANCER_MARKERS = {
    "psa": {"unit": "ng/mL", "high": 4},
    "ca125": {"unit": "U/mL", "high": 35},
    "afp": {"unit": "ng/mL", "high": 10},
}


def _build_marker_regex(markers) -> "re.Pattern[str]":
    """Compile a regex matching ``<marker name> [:|=] <number>`` for the given keys."""
    names = []
    for key in sorted(markers, key=len, reverse=True):
        # Allow an optional space/hyphen at a letter->digit boundary so that
        # "ca125", "CA 125" and "CA-125" all match the "ca125" key.
        names.append(
            re.sub(r"(?<=[a-z])(?=\d)", lambda _m: r"[\s-]?", re.escape(key))
        )
    # The value group only accepts a well-formed decimal, so a trailing full
    # stop (e.g. "PSA: 4.5.") is not captured and float() cannot fail.
    return re.compile(
        rf"\b({'|'.join(names)})\s*[:=]?\s*(\d*\.?\d+)", re.IGNORECASE
    )


_MARKER_RE = _build_marker_regex(CANCER_MARKERS)


def lab_agent(text: str):
    """Extract tumor-marker values from free text and flag those above threshold.

    Each line is scanned independently for the markers in ``CANCER_MARKERS``.

    Returns:
        A report string listing each parsed marker with its unit and status,
        followed by a ``Flags:`` line; or a short message when the input is
        empty or no marker could be parsed.
    """
    if not text or not text.strip():
        return "No lab text provided."
    results, flags = [], []
    for line in text.splitlines():
        for raw_label, raw_value in _MARKER_RE.findall(line):
            label = re.sub(r"[\s-]", "", raw_label.lower())  # "CA-125" -> "ca125"
            value = float(raw_value)
            marker = CANCER_MARKERS[label]
            status = "ok"
            if value > marker["high"]:
                status = "elevated"
                flags.append(f"{label.upper()} high")
            results.append(f"{label.upper()}: {value} {marker['unit']} → {status}")
    if not results:
        return "Could not parse tumor markers."
    return (
        "🧪 Lab Agent (Tumor Markers)\n"
        + "\n".join(results)
        + "\nFlags: "
        + (", ".join(flags) if flags else "none")
    )


# ====================================
# Moonshot + Local AI Coordinator
# ====================================
def _local_summary(prompt: str) -> str:
    """Summarize ``prompt`` with the bundled t5-small pipeline."""
    try:
        # T5 selects its task from a text prefix; "summarize: " requests a
        # summary. Truncate because t5-small was trained on 512-token inputs.
        generated = local_summarizer(
            "summarize: " + prompt, max_length=100, truncation=True
        )[0]["generated_text"]
    except Exception as e:  # last-resort fallback: never crash the UI handler
        print(f"Local summarizer error: {e}")
        return f"⚠️ Summary unavailable (local summarizer error: {e})"
    return f"🤖 Local AI Summary (Fallback)\n{generated}"


def moonshot_summary(prompt: str):
    """Summarize ``prompt`` via the Moonshot chat API, falling back to local T5.

    Returns:
        The summary text. An HTTP 429 yields a retry notice. A missing API
        key, a transport error, a non-2xx status, or an unexpected response
        shape uses the local model instead.
    """
    if not MOONSHOT_API_KEY:
        # No credentials: skip the rate-limit pause and the doomed request.
        return _local_summary(prompt)

    headers = {
        "Authorization": f"Bearer {MOONSHOT_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MOONSHOT_MODEL,
        "messages": [
            {"role": "system", "content": "You are a clinical research AI generating concise diagnostic summaries."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.4,
    }
    try:
        time.sleep(2)  # respect rate limits
        response = requests.post(MOONSHOT_API_URL, headers=headers, json=payload, timeout=30)
        if response.status_code == 429:
            return "⚠️ Moonshot API rate limit reached. Please wait a few seconds and retry."
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"].strip()
    except (requests.RequestException, ValueError, KeyError, IndexError,
            TypeError, AttributeError) as e:
        # Network/HTTP failure or a response body not shaped as expected.
        print(f"Moonshot API error: {e}")
    return _local_summary(prompt)


def coordinator(imaging_txt, lab_txt):
    """Combine the agent outputs into one LLM-written summary with a disclaimer."""
    base_prompt = f"Summarize the following findings for early cancer risk assessment:\n\n{imaging_txt}\n\n{lab_txt}"
    llm_summary = moonshot_summary(base_prompt)
    summary = f"📋 AI Coordinator Summary (LLM-generated)\n\n{llm_summary}\n\n⚠️ Disclaimer: Research demo only. Not for clinical use."
    return summary


# ====================================
# Samples
# ====================================
SAMPLES = {
    "Normal X-ray": "samples/sample_xray1.png",
    "Suspicious X-ray": "samples/sample_xray2.png",
}

SAMPLE_TEXTS = {
    "Lab Results": "samples/sample_labs.txt",
    "MRI Report": "samples/sample_mri.txt",
    "CT Report": "samples/sample_ct.txt",
}


def load_sample(choice):
    """Return the path of the chosen sample X-ray, or None if unknown/missing."""
    path = SAMPLES.get(choice)
    return path if path and os.path.isfile(path) else None


def load_text(choice):
    """Return the contents of the chosen sample report, or "" if unavailable."""
    path = SAMPLE_TEXTS.get(choice)
    if not path or not path.endswith(".txt"):
        return ""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        print(f"Could not read sample text {path}: {e}")
        return ""


# ====================================
# Runner
# ====================================
def run_all(image, labs):
    """Run the imaging, lab and coordinator agents; return all five UI outputs."""
    txt, prob_text, heatmap = imaging_agent(image) if image else ("No image.", None, None)
    lab = lab_agent(labs)
    coord = coordinator(txt, lab)
    return txt, prob_text, heatmap, lab, coord


# ====================================
# Gradio UI
# ====================================
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🏥 AI Diagnostics Agent: Early Cancer Discovery (Demo)")
    gr.Markdown("Upload a chest X-ray or choose a demo sample. Paste tumor marker labs or MRI/CT reports.\n\n⚠️ Research demo only. Not for clinical use.")

    with gr.Row():
        with gr.Column():
            sample_dropdown = gr.Dropdown(choices=list(SAMPLES.keys()), value="Normal X-ray", label="Select Sample X-ray")
            # Pre-populate so the input matches the dropdown's default selection.
            img_in = gr.Image(type="filepath", value=load_sample("Normal X-ray"), label="Chest X-ray (PNG/JPG)")
            imaging_out = gr.Textbox(label="Imaging Agent Output")
            imaging_raw = gr.Textbox(label="Probabilities (%)", lines=6)
            imaging_heatmap = gr.Image(label="Heatmap Overlay")
        with gr.Column():
            text_dropdown = gr.Dropdown(choices=list(SAMPLE_TEXTS.keys()), value="Lab Results", label="Select Sample Report")
            lab_in = gr.Textbox(lines=6, value=load_text("Lab Results"), label="Lab / Report Input")
            lab_out = gr.Textbox(label="Lab Agent Output")

    run_btn = gr.Button("Run Agents")
    coord_out = gr.Textbox(label="Coordinator Summary", lines=10)

    sample_dropdown.change(load_sample, inputs=sample_dropdown, outputs=img_in)
    text_dropdown.change(load_text, inputs=text_dropdown, outputs=lab_in)
    run_btn.click(
        run_all,
        inputs=[img_in, lab_in],
        outputs=[imaging_out, imaging_raw, imaging_heatmap, lab_out, coord_out],
    )

if __name__ == "__main__":
    demo.launch()