""" BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI. Full pipeline: Adversarial GCN + Qwen2.5-7B fine-tuned on AMD MI300X. """ from __future__ import annotations import io from pathlib import Path import numpy as np import torch import gradio as gr _WINDOW_LEN = 50 _STEP = 3 _MAX_WINDOWS = 30 _FC_THRESHOLD = 0.2 _CKPTS = { "NYU": Path("checkpoints/nyu.ckpt"), "USM": Path("checkpoints/usm.ckpt"), "UCLA": Path("checkpoints/ucla.ckpt"), "UM": Path("checkpoints/um.ckpt"), } _LLM_MODEL = "Yatsuiii/asd-interpreter-lora" SYSTEM_PROMPT = ( "You are a clinical AI assistant specializing in functional MRI brain " "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. " "You interpret outputs from a validated graph neural network (GCN) trained on " "the ABIDE I dataset and provide structured clinical summaries for neurologists " "and psychiatrists. Your reports are informative and evidence-based but always " "clarify that findings are AI-assisted and should be integrated with full " "clinical assessment. You do not make a diagnosis." ) # ── preprocessing ────────────────────────────────────────────────────────── def _zscore(bold): mean = bold.mean(0, keepdims=True) std = bold.std(0, keepdims=True) std[std < 1e-8] = 1.0 return ((bold - mean) / std).astype(np.float32) def _fc(bold): fc = np.corrcoef(bold.T).astype(np.float32) np.nan_to_num(fc, copy=False) return fc def _windows(bold): T, N = bold.shape starts = list(range(0, T - _WINDOW_LEN + 1, _STEP)) w = np.stack([bold[s:s+_WINDOW_LEN].std(0) for s in starts]).astype(np.float32) if len(w) >= _MAX_WINDOWS: return w[:_MAX_WINDOWS] return np.concatenate([w, np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)]) def preprocess(bold): bold = _zscore(bold) fc = _fc(bold) fc = np.arctanh(np.clip(fc, -0.9999, 0.9999)) adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32) bw = _windows(bold) return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0) # ── GCN model loading ────────────────────────────────────────────────────── _models: list | None = None def get_models(): global _models if _models is not None: return _models from brain_gcn.tasks import ClassificationTask _models = [] for site, ckpt in _CKPTS.items(): if not ckpt.exists(): continue task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False) task.eval() _models.append((site, task)) return _models # ── LLM loading ──────────────────────────────────────────────────────────── _llm = None def get_llm(): global _llm if _llm is not None: return _llm from transformers import AutoModelForCausalLM, AutoTokenizer print(f"Loading LLM: {_LLM_MODEL}") tok = AutoTokenizer.from_pretrained(_LLM_MODEL) tok.pad_token = tok.eos_token mdl = AutoModelForCausalLM.from_pretrained( _LLM_MODEL, torch_dtype=torch.bfloat16, device_map="auto", ) mdl.eval() _llm = (mdl, tok) return _llm def _llm_report(p_mean: float, per_model: list) -> str: consensus = sum(1 for _, p in per_model if p > 0.5) per_model_str = "\n".join( f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model ) conf_label = ( "HIGH" if p_mean >= 0.75 else "MODERATE" if p_mean >= 0.6 else "LOW / UNCERTAIN" if p_mean >= 0.4 else "MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)" ) user_msg = ( f"Brain Connectivity GCN Analysis Report\n" f"{'='*40}\n" f"p(ASD) : {p_mean:.3f}\n" f"Confidence Level : {conf_label}\n" f"Model Consensus : {consensus}/4 site-blind models predict ASD\n\n" f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n" f"Please provide a structured clinical interpretation of these findings." ) try: mdl, tok = get_llm() messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_msg}, ] text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device) with torch.no_grad(): out = mdl.generate( **inputs, max_new_tokens=512, temperature=0.3, do_sample=True, pad_token_id=tok.eos_token_id, ) generated = out[0][inputs["input_ids"].shape[1]:] return tok.decode(generated, skip_special_tokens=True).strip() except Exception as e: return f"LLM unavailable: {e}" # ── gradient saliency ────────────────────────────────────────────────────── def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray: maps = [] for _, task in models: adj = adj_t.clone().requires_grad_(True) logits = task.model(bw_t, adj) p = torch.softmax(logits, -1)[0, 1] p.backward() maps.append(adj.grad[0].abs().detach().numpy()) sal = np.mean(maps, axis=0) sal = (sal + sal.T) / 2 return sal def _saliency_figure(sal: np.ndarray, p_mean: float): import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from PIL import Image thresh = np.percentile(sal, 95) sal_top = np.where(sal >= thresh, sal, 0.0) roi_imp = sal.sum(1) top20 = roi_imp.argsort()[-20:][::-1] verdict_color = "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261" fig, axes = plt.subplots(1, 2, figsize=(14, 5.5)) fig.patch.set_facecolor("#0d0d0d") ax = axes[0] ax.set_facecolor("#111") im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest") ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10) ax.set_xlabel("ROI index", color="#777", fontsize=9) ax.set_ylabel("ROI index", color="#777", fontsize=9) ax.tick_params(colors="#555", labelsize=8) cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) cb.ax.yaxis.set_tick_params(color="#555", labelsize=7) plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666") for spine in ax.spines.values(): spine.set_color("#333") ax2 = axes[1] ax2.set_facecolor("#111") ax2.barh(range(20), roi_imp[top20], color=verdict_color, alpha=0.75, edgecolor="none") ax2.set_yticks(range(20)) ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc") ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9) ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10) ax2.tick_params(colors="#555", labelsize=8) ax2.invert_yaxis() for spine in ["top", "right"]: ax2.spines[spine].set_visible(False) for spine in ["bottom", "left"]: ax2.spines[spine].set_color("#333") fig.suptitle( f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {len(_models)} LOSO models", color="#888", fontsize=10, y=1.02, ) plt.tight_layout() buf = io.BytesIO() plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d") plt.close(fig) buf.seek(0) return Image.open(buf).copy() # ── inference ────────────────────────────────────────────────────────────── def run_gcn(file_path: str | None): if file_path is None: return "", "", "", None, "" path = Path(file_path) try: if path.suffix == ".npz": d = np.load(path, allow_pickle=True) fc = d["mean_fc"].astype(np.float32) fc = np.arctanh(np.clip(fc, -0.9999, 0.9999)) adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32) bw = d["bold_windows"].astype(np.float32) if len(bw) >= _MAX_WINDOWS: bw = bw[:_MAX_WINDOWS] else: bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)]) bw_t = torch.FloatTensor(bw).unsqueeze(0) adj_t = torch.FloatTensor(adj).unsqueeze(0) else: bold = np.loadtxt(path, dtype=np.float32) if bold.ndim != 2 or bold.shape[1] != 200: return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None, "" bw_t, adj_t = preprocess(bold) except Exception as e: return f"⚠️ Error loading file: {e}", "", "", None, "" models = get_models() per_model = [] with torch.no_grad(): for site, task in models: logits = task(bw_t, adj_t) p = torch.softmax(logits, -1)[0, 1].item() per_model.append((site, p)) p_mean = float(np.mean([p for _, p in per_model])) consensus = sum(1 for _, p in per_model if p > 0.5) conf = max(p_mean, 1 - p_mean) * 100 try: sal = _compute_saliency(bw_t, adj_t, models) sal_img = _saliency_figure(sal, p_mean) except Exception: sal_img = None # Verdict if p_mean > 0.6: verdict = f"""
ASD INDICATED
Confidence: {conf:.1f}%  |  p(ASD) = {p_mean:.3f}  |  {consensus}/4 site-blind models agree
""" elif p_mean < 0.4: verdict = f"""
TYPICAL CONTROL
Confidence: {conf:.1f}%  |  p(ASD) = {p_mean:.3f}  |  {4-consensus}/4 site-blind models agree
""" else: verdict = f"""
INCONCLUSIVE
Confidence: {conf:.1f}%  |  p(ASD) = {p_mean:.3f}  |  Model disagreement — clinical review required
""" # Ensemble breakdown rows = "" for site, p in per_model: lbl = "ASD" if p > 0.5 else "TC" color = "#e63946" if p > 0.5 else "#2dc653" bar_w = int(p * 100) rows += f""" {site}-blind
{lbl} p={p:.3f} """ ensemble = f"""
Leave-One-Site-Out Ensemble — each model never trained on that site's data
{rows}
Cross-site consensus: {consensus}/4 models agree  ·  LOSO AUC = 0.7872 across 529 held-out subjects
""" # LLM clinical report llm_text = _llm_report(p_mean, per_model) report = f"""
Clinical Report — Qwen2.5-7B fine-tuned on AMD Instinct MI300X
{llm_text}
⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments (ADOS-2, ADI-R).
""" return verdict, ensemble, report, sal_img # ── UI ───────────────────────────────────────────────────────────────────── css = """ body { background: #0d0d0d; } .gradio-container { max-width: 960px; margin: auto; } """ with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo: gr.HTML("""
BrainConnect-ASD
Scanner-site-invariant ASD detection from resting-state fMRI
LOSO AUC 0.7872 529 held-out subjects 4 independent institutions AMD Instinct MI300X
""") file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath") verdict_html = gr.HTML() ensemble_html = gr.HTML() gr.HTML("
Gradient Saliency — which brain connections drove this prediction
") saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil") report_html = gr.HTML() file_input.change( fn=run_gcn, inputs=file_input, outputs=[verdict_html, ensemble_html, report_html, saliency_img], ) gr.HTML("""
Adversarial Brain-Mode GCN (k=16) · Qwen2.5-7B LoRA (AMD MI300X) · ABIDE I · GitHub
""") print("Preloading GCN models...") get_models() print("Preloading LLM...") get_llm() print("All models ready.") if __name__ == "__main__": demo.launch()