""" BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI. """ from __future__ import annotations import io from pathlib import Path import numpy as np import torch import gradio as gr from _charts import VAL_B64, AUC_B64, AMD_BENCH_B64 _WINDOW_LEN = 50 _STEP = 3 _MAX_WINDOWS = 30 _FC_THRESHOLD = 0.2 # ── Atlas configurations ──────────────────────────────────────────────────── # CC200 → Yeo 7-network parcellation (approximate ROI ordering) _ATLAS_CFG = { "cc200": { "n_rois": 200, "label": "CC200", "net_names": ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"], "net_bounds": [0, 38, 69, 99, 137, 165, 180, 200], "net_colors": ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"], "ckpts": { "CALTECH": Path("checkpoints/cc200_caltech.ckpt"), "CMU": Path("checkpoints/cc200_cmu.ckpt"), "KKI": Path("checkpoints/cc200_kki.ckpt"), "LEUVEN_1": Path("checkpoints/cc200_leuven_1.ckpt"), "LEUVEN_2": Path("checkpoints/cc200_leuven_2.ckpt"), "MAX_MUN": Path("checkpoints/cc200_max_mun.ckpt"), "NYU": Path("checkpoints/cc200_nyu.ckpt"), "OHSU": Path("checkpoints/cc200_ohsu.ckpt"), "OLIN": Path("checkpoints/cc200_olin.ckpt"), "PITT": Path("checkpoints/cc200_pitt.ckpt"), "SBL": Path("checkpoints/cc200_sbl.ckpt"), "SDSU": Path("checkpoints/cc200_sdsu.ckpt"), "STANFORD": Path("checkpoints/cc200_stanford.ckpt"), "TRINITY": Path("checkpoints/cc200_trinity.ckpt"), "UCLA_1": Path("checkpoints/cc200_ucla_1.ckpt"), "UCLA_2": Path("checkpoints/cc200_ucla_2.ckpt"), "UM_1": Path("checkpoints/cc200_um_1.ckpt"), "UM_2": Path("checkpoints/cc200_um_2.ckpt"), "USM": Path("checkpoints/cc200_usm.ckpt"), "YALE": Path("checkpoints/cc200_yale.ckpt"), }, }, "aal": { "n_rois": 116, "label": "AAL-116", # Approximate Yeo-7 parcellation for AAL-116 anatomical ordering: # Frontal/FPN (1-28), Sensorimotor (29-40), DMN parietal (41-60), # Temporal/DMN (61-76), Subcortical (77-90), Occipital/Visual (91-116) "net_names": 
["Frontoparietal", "Sensorimotor", "Dorsal Attn", "DMN", "Salience", "Subcortical", "Visual"], "net_bounds": [0, 20, 34, 50, 68, 80, 92, 116], "net_colors": ["#457b9d", "#2dc653", "#8b5cf6", "#e63946", "#f4a261", "#6b7280", "#a8dadc"], "ckpts": { "CALTECH": Path("checkpoints/aal_caltech.ckpt"), "CMU": Path("checkpoints/aal_cmu.ckpt"), "KKI": Path("checkpoints/aal_kki.ckpt"), "LEUVEN_1": Path("checkpoints/aal_leuven_1.ckpt"), "LEUVEN_2": Path("checkpoints/aal_leuven_2.ckpt"), "MAX_MUN": Path("checkpoints/aal_max_mun.ckpt"), "NYU": Path("checkpoints/aal_nyu.ckpt"), "OHSU": Path("checkpoints/aal_ohsu.ckpt"), "OLIN": Path("checkpoints/aal_olin.ckpt"), "PITT": Path("checkpoints/aal_pitt.ckpt"), "SBL": Path("checkpoints/aal_sbl.ckpt"), "SDSU": Path("checkpoints/aal_sdsu.ckpt"), "STANFORD": Path("checkpoints/aal_stanford.ckpt"), "TRINITY": Path("checkpoints/aal_trinity.ckpt"), "UCLA_1": Path("checkpoints/aal_ucla_1.ckpt"), "UCLA_2": Path("checkpoints/aal_ucla_2.ckpt"), "UM_1": Path("checkpoints/aal_um_1.ckpt"), "UM_2": Path("checkpoints/aal_um_2.ckpt"), "USM": Path("checkpoints/aal_usm.ckpt"), "YALE": Path("checkpoints/aal_yale.ckpt"), }, }, "ho": { "n_rois": 111, "label": "Harvard-Oxford", "net_names": ["Frontoparietal", "Sensorimotor", "DMN", "Salience", "Subcortical", "Visual", "Temporal"], "net_bounds": [0, 18, 30, 48, 68, 80, 96, 111], "net_colors": ["#457b9d", "#2dc653", "#e63946", "#f4a261", "#6b7280", "#a8dadc", "#8b5cf6"], "ckpts": { "NYU": Path("checkpoints/ho_nyu.ckpt"), "USM": Path("checkpoints/ho_usm.ckpt"), "UCLA": Path("checkpoints/ho_ucla.ckpt"), "UM": Path("checkpoints/ho_um.ckpt"), }, }, } # Resolve active atlas config by ROI count _ROI_TO_ATLAS = {cfg["n_rois"]: key for key, cfg in _ATLAS_CFG.items()} # Legacy aliases kept for backward compat _NET_NAMES = _ATLAS_CFG["cc200"]["net_names"] _NET_BOUNDS = _ATLAS_CFG["cc200"]["net_bounds"] _NET_COLORS = _ATLAS_CFG["cc200"]["net_colors"] _CKPTS = _ATLAS_CFG["cc200"]["ckpts"] # ── preprocessing 
# ──────────────────────────────────────────────────────────


def _zscore(bold):
    """Standardise every column of ``bold`` to zero mean / unit variance.

    Args:
        bold: 2-D array with time along axis 0 (statistics are taken over
            axis 0, so each column is normalised independently).

    Returns:
        ``float32`` array of the same shape.  Columns whose standard
        deviation is below 1e-8 are only mean-centred (their divisor is
        forced to 1) so constant signals do not produce inf/NaN.
    """
    mean = bold.mean(0, keepdims=True)
    std = bold.std(0, keepdims=True)
    std[std < 1e-8] = 1.0
    return ((bold - mean) / std).astype(np.float32)


def _fc(bold):
    """Return the column-by-column Pearson correlation matrix as ``float32``.

    NaN entries (produced by zero-variance columns) are replaced in place
    by 0 via ``np.nan_to_num``.
    """
    fc = np.corrcoef(bold.T).astype(np.float32)
    np.nan_to_num(fc, copy=False)
    return fc


def _windows(bold):
    """Summarise ``bold`` as per-window, per-column standard deviations.

    Windows of ``_WINDOW_LEN`` rows are taken every ``_STEP`` rows.  The
    result always has exactly ``_MAX_WINDOWS`` rows: extra windows are
    dropped, and a short series is padded by repeating its last window.

    Args:
        bold: 2-D array with time along axis 0.

    Returns:
        ``float32`` array of shape ``(_MAX_WINDOWS, n_columns)``.

    Raises:
        ValueError: if ``bold`` is not 2-D, or has fewer than
            ``_WINDOW_LEN`` rows (previously this surfaced as an opaque
            ``np.stack`` error about an empty sequence).
    """
    n_timepoints, _ = bold.shape
    if n_timepoints < _WINDOW_LEN:
        raise ValueError(
            f"BOLD series has only {n_timepoints} timepoints; at least "
            f"{_WINDOW_LEN} are needed to form one sliding window"
        )
    starts = range(0, n_timepoints - _WINDOW_LEN + 1, _STEP)
    w = np.stack([bold[s:s + _WINDOW_LEN].std(0) for s in starts]).astype(np.float32)
    if len(w) >= _MAX_WINDOWS:
        return w[:_MAX_WINDOWS]
    padding = np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)
    return np.concatenate([w, padding])


def preprocess(bold):
    """Turn a raw time × ROI series into the two model input tensors.

    Steps: z-score each column, compute the correlation matrix, apply the
    Fisher z-transform (correlations are clipped to ±0.9999 first so
    ``arctanh`` stays finite), zero every entry whose magnitude is below
    ``_FC_THRESHOLD``, and compute the windowed std features.

    Returns:
        ``(bw, adj)`` — float32 tensors of shape
        ``(1, _MAX_WINDOWS, n_rois)`` and ``(1, n_rois, n_rois)``.

    Raises:
        ValueError: propagated from ``_windows`` for too-short input.
    """
    bold = _zscore(bold)
    fc = np.arctanh(np.clip(_fc(bold), -0.9999, 0.9999))
    adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
    bw = _windows(bold)
    return torch.from_numpy(bw).unsqueeze(0), torch.from_numpy(adj).unsqueeze(0)


# ── LLM (Qwen2.5-7B LoRA fine-tuned on AMD MI300X) ────────────────────────
_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
_SYSTEM_PROMPT = (
    "You are a clinical AI assistant specializing in functional MRI brain "
    "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
    "You interpret outputs from a validated graph neural network (GCN) trained on "
    "the ABIDE I dataset (1,102 subjects, 20 acquisition sites) and provide structured "
    "clinical summaries for neurologists and psychiatrists. "
    "CRITICAL RULES: (1) Only reference brain networks, connectivity patterns, and "
    "statistics that are explicitly provided in the input report — do NOT invent or "
    "hallucinate network names, connectivity findings, or numeric values. "
    "(2) Base every observation directly on the per-network saliency scores and "
    "ensemble probabilities given in the input. (3) If a network is not listed in the "
    "input, do not mention it. (4) Always clarify findings are AI-assisted and require "
    "full clinical assessment. You do not make a diagnosis."
)

_llm_cache = None


def get_llm():
    """Load (once) and return the ``(model, tokenizer)`` pair for ``_LLM_MODEL``.

    The heavy ``transformers`` import is deferred to the first call; the
    loaded pair is memoised in the module-level ``_llm_cache``.
    """
    global _llm_cache
    if _llm_cache is not None:
        return _llm_cache
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
    tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        _LLM_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
    )
    mdl.eval()
    _llm_cache = (mdl, tok)
    return _llm_cache


def _confidence_label(p_mean):
    """Map an ensemble p(ASD) to the confidence wording used in the report.

    A NaN probability is reported as "LOW / UNCERTAIN"; without this guard
    every comparison is False and NaN would fall through to "HIGH (TC)".
    """
    if np.isnan(p_mean):
        return "LOW / UNCERTAIN"
    if p_mean >= 0.75:
        return "HIGH"
    if p_mean >= 0.6:
        return "MODERATE"
    if p_mean >= 0.4:
        return "LOW / UNCERTAIN"
    if p_mean >= 0.25:
        return "MODERATE (TC)"
    return "HIGH (TC)"


def _build_llm_prompt(p_mean, per_model, net_saliency=None):
    """Render the plain-text report that is sent to the LLM as the user turn.

    Args:
        p_mean: ensemble-mean p(ASD).
        per_model: iterable of ``(site, probability)`` pairs.
        net_saliency: optional ``{network name: score}``; listed high→low.
    """
    consensus = sum(1 for _, p in per_model if p > 0.5)
    per_model_str = "\n".join(
        f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
    )
    conf_label = _confidence_label(p_mean)

    sal_section = ""
    if net_saliency:
        sorted_nets = sorted(net_saliency.items(), key=lambda x: x[1], reverse=True)
        sal_lines = "\n".join(f" {name}: {score:.5f}" for name, score in sorted_nets)
        sal_section = (
            f"\nPer-Network Gradient Saliency (ranked high→low, actual GCN values):\n"
            f"{sal_lines}\n"
            f"[ONLY reference these networks with these exact values — no others.]\n"
        )

    return (
        f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
        f"Dataset : ABIDE I · 1,102 subjects · 20 acquisition sites\n"
        f"p(ASD) : {p_mean:.3f}\n"
        f"Confidence Level : {conf_label}\n"
        f"Model Consensus : {consensus}/{len(per_model)} site-blind models predict ASD\n"
        f"{sal_section}\n"
        f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
        f"Provide a structured clinical interpretation referencing ONLY the networks "
        f"and values listed above. Do not mention any network not in this report."
    )


def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None) -> str:
    """Ask the fine-tuned LLM for a narrative interpretation of one prediction.

    Args:
        p_mean: ensemble-mean p(ASD).
        per_model: list of ``(site, probability)`` pairs, one per model.
        net_saliency: optional per-network saliency scores to include.

    Returns:
        The generated text, or ``"[LLM unavailable: …]"`` if loading or
        generation fails.  The broad ``except`` is deliberate: the report is
        a best-effort extra and must never take the prediction down with it.
    """
    user_msg = _build_llm_prompt(p_mean, per_model, net_saliency)
    try:
        mdl, tok = get_llm()
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]
        text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
        with torch.no_grad():
            out = mdl.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.1,
                do_sample=True,
                pad_token_id=tok.eos_token_id,
            )
        # Drop the prompt tokens; keep only what the model generated.
        generated = out[0][inputs["input_ids"].shape[1]:]
        return tok.decode(generated, skip_special_tokens=True).strip()
    except Exception as e:
        return f"[LLM unavailable: {e}]"


# ── model loading ──────────────────────────────────────────────────────────
_model_cache: dict[str, list] = {}


def get_models(atlas: str = "cc200"):
    """Return the ``[(site, task), …]`` ensemble for ``atlas``, loading on demand.

    An unknown ``atlas`` key falls back to the CC200 configuration.
    Checkpoints that do not exist on disk are skipped, so the list may be
    shorter than the configured site list (or empty).

    Only a non-empty result is cached: caching an empty list would hide
    checkpoints that appear on disk after the first call.
    """
    if atlas in _model_cache:
        return _model_cache[atlas]
    from brain_gcn.tasks import ClassificationTask

    cfg = _ATLAS_CFG.get(atlas, _ATLAS_CFG["cc200"])
    models = []
    for site, ckpt in cfg["ckpts"].items():
        if not ckpt.exists():
            continue
        task = ClassificationTask.load_from_checkpoint(
            str(ckpt), map_location="cpu", strict=False
        )
        task.eval()
        models.append((site, task))
    if models:
        _model_cache[atlas] = models
    return models


# ── gradient saliency ──────────────────────────────────────────────────────
def _compute_saliency(bw_t, adj_t, models):
    """Ensemble-averaged |∂p(class 1)/∂adjacency|, symmetrised.

    Args:
        bw_t: windowed-feature tensor passed straight to each model.
        adj_t: adjacency tensor with a leading batch dimension of 1.
        models: iterable of ``(site, task)``; ``task.model(bw_t, adj)``
            must return logits indexable as ``[0, 1]`` after softmax.

    Returns:
        2-D numpy array ``(S + S.T) / 2`` where ``S`` is the mean over
        models of the absolute input gradient.

    Raises:
        ValueError: if ``models`` is empty (the mean would otherwise be a
            NaN scalar that fails later in an unrelated place).
    """
    if not models:
        raise ValueError("_compute_saliency needs at least one model")
    maps = []
    for _, task in models:
        adj = adj_t.detach().clone().requires_grad_(True)
        logits = task.model(bw_t, adj)
        p_asd = torch.softmax(logits, -1)[0, 1]
        # Differentiate w.r.t. the adjacency only: .backward() would also
        # accumulate gradients into every model parameter on each call.
        (grad,) = torch.autograd.grad(p_asd, adj)
        maps.append(grad[0].abs().detach().cpu().numpy())
    sal = np.mean(maps, axis=0)
    return (sal + sal.T) / 2


# Approximate MNI centroids for each CC200 network (mm), used for 3D brain view
# Rows follow the CC200 network order of ``_NET_NAMES``.
_NET_MNI = np.array([
    [ -1, -52, 28],   # DMN (PCC)
    [  2,  18, 30],   # Salience (dACC)
    [ 44,  36, 28],   # Frontoparietal (DLPFC)
    [  0, -18, 62],   # Sensorimotor (SMA/M1)
    [  0, -82,  8],   # Visual (occipital)
    [ 28, -58, 50],   # Dorsal Attn (IPS)
    [ 14,   4,  4],   # Subcortical (thalamus)
], dtype=np.float32)


def _aggregate_network_saliency(sal, bounds):
    """Collapse an ROI×ROI saliency matrix to network level.

    Args:
        sal: 2-D saliency matrix.
        bounds: ROI index boundaries; network ``k`` spans
            ``bounds[k]:bounds[k + 1]``.

    Returns:
        ``(net_sal, net_imp)`` — the matrix of block means, and per-network
        importance (mean of the network's rows plus mean of its columns).
    """
    spans = list(zip(bounds[:-1], bounds[1:]))
    net_sal = np.zeros((len(spans), len(spans)))
    for i, (s1, e1) in enumerate(spans):
        for j, (s2, e2) in enumerate(spans):
            net_sal[i, j] = sal[s1:e1, s2:e2].mean()
    net_imp = np.array([sal[s:e, :].mean() + sal[:, s:e].mean() for s, e in spans])
    return net_sal, net_imp


def _rank_network_pairs(net_sal):
    """Rank unordered network pairs ``(i < j)`` by saliency, highest first.

    Each pair is scored by the mean of its two mirrored cells.  Ranking
    ordered pairs instead would list every edge twice on a symmetric matrix,
    so a "top 5" would really contain only two or three distinct edges.
    """
    n = net_sal.shape[0]
    pairs = [
        (float(net_sal[i, j] + net_sal[j, i]) / 2.0, i, j)
        for i in range(n)
        for j in range(i + 1, n)
    ]
    pairs.sort(reverse=True)
    return pairs


def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=None,
                     n_models=None, atlas_label=None):
    """Render the three-panel saliency figure and return it as a PIL image.

    Panels: network-level saliency heatmap, per-network importance bars, and
    a schematic 3D view of the strongest inter-network connections.

    Args:
        sal: ROI×ROI saliency matrix.
        p_mean: ensemble-mean p(ASD), shown in the title.
        net_names, net_bounds, net_colors: network layout of the active
            atlas; each defaults to the CC200 legacy alias.
        n_models: ensemble size for the title.  When omitted, a module-level
            ``models`` is used if one exists (the name the title used to
            read unconditionally); otherwise the count is left out.
        atlas_label: atlas name for the title.  When omitted it is looked up
            from ``_ATLAS_CFG`` by ``net_bounds``, falling back to "CC200".

    Returns:
        ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        ValueError: if ``net_bounds`` does not have one more entry than
            ``net_names``.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    from PIL import Image

    _nn = net_names if net_names is not None else _NET_NAMES
    _nb = net_bounds if net_bounds is not None else _NET_BOUNDS
    _nc = net_colors if net_colors is not None else _NET_COLORS
    n_nets = len(_nn)
    if len(_nb) != n_nets + 1:
        raise ValueError(
            f"net_bounds must have {n_nets + 1} entries for {n_nets} networks, got {len(_nb)}"
        )

    if n_models is None:
        legacy_models = globals().get("models")
        if legacy_models is not None:
            n_models = len(legacy_models)
    if atlas_label is None:
        atlas_label = next(
            (cfg["label"] for cfg in _ATLAS_CFG.values()
             if list(cfg["net_bounds"]) == list(_nb)),
            "CC200",
        )

    # Aggregate N×N saliency → network-level matrix and importance vector
    net_sal, net_imp = _aggregate_network_saliency(sal, _nb)
    ranked_pairs = _rank_network_pairs(net_sal)
    top5_pairs = ranked_pairs[:5]
    top3_pairs = ranked_pairs[:3]

    # Centroids are tabulated in CC200 order; look them up by network name so
    # atlases that order their networks differently are drawn correctly.
    centroids = dict(zip(_NET_NAMES, _NET_MNI))

    fig = plt.figure(figsize=(18, 5.5))
    try:
        fig.patch.set_facecolor("#0e1015")
        axes = [
            fig.add_subplot(1, 3, 1),
            fig.add_subplot(1, 3, 2),
            fig.add_subplot(1, 3, 3, projection="3d"),
        ]

        # ── Left: network heatmap ──────────────────────────────────────────
        ax = axes[0]
        ax.set_facecolor("#161922")
        im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
        ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=11,
                     pad=14, fontweight="bold")
        ax.set_xticks(range(n_nets))
        ax.set_yticks(range(n_nets))
        ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=9, color="#ccc")
        ax.set_yticklabels(_nn, fontsize=9, color="#ccc")
        ax.tick_params(colors="#555", length=0)
        for sp in ax.spines.values():
            sp.set_color("#222")

        # Boundary lines between networks
        for k in range(1, n_nets):
            ax.axhline(k - 0.5, color="#2a2a2a", lw=1.0)
            ax.axvline(k - 0.5, color="#2a2a2a", lw=1.0)

        # Highlight both mirrored cells of each top-5 pair.
        highlighted = set()
        for _, i, j in top5_pairs:
            highlighted.update({(i, j), (j, i)})

        # Annotate each cell with its value; highlighted cells get a white border
        vmax = net_sal.max()
        for i in range(n_nets):
            for j in range(n_nets):
                txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#666"
                ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
                        fontsize=6.5, color=txt_color, zorder=3)
                if (i, j) in highlighted:
                    rect = plt.Rectangle((j - 0.48, i - 0.48), 0.96, 0.96,
                                         linewidth=1.8, edgecolor="#ffffff",
                                         facecolor="none", zorder=4)
                    ax.add_patch(rect)

        # Callout labels for top-3 cross-network edges
        for rank, (_, i, j) in enumerate(top3_pairs):
            label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
            ax.annotate(label, xy=(j, i),
                        xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
                        fontsize=6, color="#fb923c", fontweight="600",
                        arrowprops=dict(arrowstyle="-", color="#fb923c", lw=0.7,
                                        connectionstyle="arc3,rad=0.1"),
                        ha="left", va="center", zorder=5)

        cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
        plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
        cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=7.5)

        # ── Middle: network importance bar chart ───────────────────────────
        ax2 = axes[1]
        ax2.set_facecolor("#161922")
        ax2.tick_params(colors="#555", labelsize=9)
        order = net_imp.argsort()[::-1]
        bars = ax2.barh(range(n_nets), net_imp[order],
                        color=[_nc[i] for i in order],
                        alpha=0.88, edgecolor="none", height=0.65)
        ax2.set_yticks(range(n_nets))
        ax2.set_yticklabels([_nn[i] for i in order], fontsize=9.5, color="#ddd")
        ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
        ax2.set_title("Network Importance for This Prediction", color="#bbb",
                      fontsize=11, pad=14, fontweight="bold")
        ax2.invert_yaxis()
        for sp in ["top", "right"]:
            ax2.spines[sp].set_visible(False)
        for sp in ["bottom", "left"]:
            ax2.spines[sp].set_color("#222")

        # Value labels on bars
        x_max = net_imp.max()
        for bar, val in zip(bars, net_imp[order]):
            ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", color="#555", fontsize=7.5)

        # ── Right: 3D brain — top connections ──────────────────────────────
        ax3 = axes[2]
        ax3.set_facecolor("#0e1015")
        ax3.grid(False)
        ax3.set_axis_off()
        ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=11,
                      pad=4, fontweight="bold")

        # Transparent brain ellipsoid wireframe (MNI space approx)
        u = np.linspace(0, 2 * np.pi, 32)
        v = np.linspace(0, np.pi, 20)
        ex = 68 * np.outer(np.cos(u), np.sin(v))
        ey = 85 * np.outer(np.sin(u), np.sin(v)) - 10
        ez = 60 * np.outer(np.ones_like(u), np.cos(v)) + 28
        ax3.plot_wireframe(ex, ey, ez, color="#252a35", linewidth=0.25,
                           alpha=0.45, zorder=0)

        # Network nodes — size proportional to importance
        imp_norm = (net_imp - net_imp.min()) / (net_imp.max() - net_imp.min() + 1e-9)
        for k, (name, color) in enumerate(zip(_nn, _nc)):
            xyz = centroids.get(name)
            if xyz is None:
                continue  # no tabulated centroid for this network name
            x, y, z = xyz
            size = 60 + imp_norm[k] * 260
            ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
                        edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
            ax3.text(x, y, z + 7, name, fontsize=5.5, color=color,
                     ha="center", va="bottom", fontweight="600", zorder=6)

        # Draw top-5 inter-network connections as lines, thickness ∝ saliency
        if top5_pairs:
            sal_vals = [s for s, _, _ in top5_pairs]
            sal_min, sal_max = min(sal_vals), max(sal_vals) + 1e-9
            for rank, (score, ni, nj) in enumerate(top5_pairs):
                p1, p2 = centroids.get(_nn[ni]), centroids.get(_nn[nj])
                if p1 is None or p2 is None:
                    continue
                rel = (score - sal_min) / (sal_max - sal_min)
                lw = 0.8 + 2.5 * rel
                alph = 0.5 + 0.45 * rel
                clr = "#fb923c" if rank == 0 else "#f4f4f5"
                ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]],
                         color=clr, linewidth=lw, alpha=alph, zorder=4)

        ax3.view_init(elev=22, azim=-65)
        ax3.set_box_aspect([1.2, 1.4, 1.0])

        ensemble = (
            f"{n_models}-model LOSO ensemble" if n_models is not None else "LOSO ensemble"
        )
        fig.suptitle(
            f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {ensemble} · {atlas_label} → Yeo-7 networks",
            color="#444", fontsize=8.5, y=1.02,
        )
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=140, bbox_inches="tight", facecolor="#0e1015")
    finally:
        # Always release the figure, even when drawing fails part-way.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()


# ── inference ────────────────────────────────────────────────────────────── def run_gcn(file_path): if file_path is None: return "", "", "", None path = Path(file_path) atlas_key = "cc200" # default; overridden below for .1D files try: if path.suffix == ".npz": d = np.load(path, allow_pickle=True) fc = d["mean_fc"].astype(np.float32) fc = np.arctanh(np.clip(fc, -0.9999, 0.9999)) adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32) bw = d["bold_windows"].astype(np.float32) if len(bw) >= _MAX_WINDOWS: bw = bw[:_MAX_WINDOWS] else: bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)]) bw_t = torch.FloatTensor(bw).unsqueeze(0) adj_t = torch.FloatTensor(adj).unsqueeze(0) else: bold = np.loadtxt(path, dtype=np.float32) if bold.ndim != 2: return "
rois_cc200/, rois_aal/, or rois_ho/"
f"aws s3 cp s3://fcp-indi/.../rois_cc200/ . --no-sign-request --recursive"
f"| SVM + FC (Plitt 2015) | 0.71 |
| BrainNetCNN (Kawahara 2017) | 0.74 |
| GCN + FC (Ktena 2018) | 0.70 |
| ABIDE site-specific SVM | 0.76 |
| BrainConnect-ASD (LOSO) | 0.7260 |
M_kl = v_k · FC · v_l
.1D or .npz fMRI time-series file