| """ |
| BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI. |
| """ |
from __future__ import annotations

import html
import io
import warnings
from pathlib import Path

import gradio as gr
import numpy as np
import torch
|
|
# Sliding-window summary of the BOLD signal (consumed by _windows).
_WINDOW_LEN = 50
_STEP = 3
_MAX_WINDOWS = 30

# |Fisher-z| below which a functional-connectivity edge is zeroed in the adjacency.
_FC_THRESHOLD = 0.2

# One checkpoint per held-out site; keys double as the "<site>-blind" labels in the UI.
_CKPTS: dict[str, Path] = {
    site: Path("checkpoints") / f"{site.lower()}.ckpt"
    for site in ("NYU", "USM", "UCLA", "UM")
}
|
|
| |
|
|
def _zscore(bold):
    """Standardise every column of ``bold`` to zero mean and unit std (float32).

    Columns whose std is below 1e-8 are divided by 1 instead, so a flat signal
    becomes all zeros rather than NaN/inf.
    """
    centre = bold.mean(axis=0, keepdims=True)
    spread = bold.std(axis=0, keepdims=True)
    spread[spread < 1e-8] = 1.0
    standardised = (bold - centre) / spread
    return standardised.astype(np.float32)
|
|
def _fc(bold):
    """Pearson correlation between the columns of ``bold`` as a float32 matrix.

    Entries that come out NaN (e.g. for a constant column) are replaced with 0.
    """
    corr = np.corrcoef(bold, rowvar=False).astype(np.float32)
    return np.nan_to_num(corr, copy=False)
|
|
def _windows(bold, window_len=None, step=None, max_windows=None):
    """Summarise a (T, N) signal as per-column std over sliding windows.

    Windows of ``window_len`` rows start every ``step`` rows. The result is
    truncated to ``max_windows`` rows, or padded up to it by repeating the last
    window, so it is always ``(max_windows, N)`` float32.

    Args:
        bold: 2-D array; rows are time points.
        window_len / step / max_windows: default to the module-level
            ``_WINDOW_LEN`` / ``_STEP`` / ``_MAX_WINDOWS`` when ``None``.

    Raises:
        ValueError: if ``bold`` is not 2-D, or has fewer rows than one window.
    """
    window_len = _WINDOW_LEN if window_len is None else window_len
    step = _STEP if step is None else step
    max_windows = _MAX_WINDOWS if max_windows is None else max_windows

    n_timepoints, _ = bold.shape  # unpacking also enforces a 2-D input
    if n_timepoints < window_len:
        # Without this, np.stack([]) fails with an unhelpful message.
        raise ValueError(
            f"need at least {window_len} timepoints for one window, got {n_timepoints}"
        )

    starts = range(0, n_timepoints - window_len + 1, step)
    w = np.stack([bold[s:s + window_len].std(0) for s in starts]).astype(np.float32)
    if len(w) >= max_windows:
        return w[:max_windows]
    pad = np.repeat(w[-1:], max_windows - len(w), axis=0)
    return np.concatenate([w, pad])
|
|
def preprocess(bold):
    """Turn a raw (T, N) BOLD matrix into the model's two input tensors.

    Returns:
        ``(bold_windows, adjacency)``, each a FloatTensor with a leading batch
        axis of 1. ``bold_windows`` comes from ``_windows`` on the z-scored
        signal. ``adjacency`` is the Fisher-z transformed FC matrix with every
        edge whose magnitude is below ``_FC_THRESHOLD`` set to 0.
    """
    standardised = _zscore(bold)

    # Fisher r-to-z; clipping keeps arctanh finite on the unit diagonal.
    fisher_z = np.arctanh(np.clip(_fc(standardised), -0.9999, 0.9999))
    keep = np.abs(fisher_z) >= _FC_THRESHOLD
    adjacency = np.where(keep, fisher_z, 0.0).astype(np.float32)

    window_feats = _windows(standardised)

    return (
        torch.FloatTensor(window_feats).unsqueeze(0),
        torch.FloatTensor(adjacency).unsqueeze(0),
    )
|
|
| |
|
|
# Cache of loaded (site, task) pairs; stays None until get_models() has
# finished loading at least once.
_models: list | None = None


def get_models():
    """Load the per-site checkpoints once and return them as (site, task) pairs.

    Sites whose checkpoint file does not exist are skipped with a warning, so
    the list may hold fewer than ``len(_CKPTS)`` entries (possibly none). The
    list is cached in ``_models``. It is only published after the whole loop
    has succeeded, so an exception while loading one checkpoint does not leave
    a half-filled ensemble cached for every later call.
    """
    global _models
    if _models is not None:
        return _models

    from brain_gcn.tasks import ClassificationTask

    loaded = []
    for site, ckpt in _CKPTS.items():
        if not ckpt.exists():
            warnings.warn(f"checkpoint for site {site} not found at {ckpt}; skipping")
            continue
        # NOTE(review): strict=False lets checkpoint/model key mismatches pass
        # without error — confirm this is intended.
        task = ClassificationTask.load_from_checkpoint(
            str(ckpt), map_location="cpu", strict=False
        )
        task.eval()
        loaded.append((site, task))

    _models = loaded
    return _models
|
|
| |
|
|
def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
    """Gradient of p(ASD) w.r.t. the adjacency matrix, averaged over the ensemble.

    For each ``(site, task)`` in ``models``, ``task.model(bw_t, adj)`` is
    differentiated (class index 1 of the softmax, batch item 0) with respect
    to ``adj``. The absolute gradients are averaged across models and
    symmetrised.

    Args:
        bw_t: window-feature tensor passed unchanged to each model.
        adj_t: adjacency tensor with a leading batch axis; it is not modified.
        models: non-empty sequence of ``(site, task)`` pairs.

    Returns:
        A symmetric numpy array shaped like ``adj_t[0]``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("saliency requires at least one model")

    maps = []
    for _, task in models:
        # Fresh leaf tensor per model, detached from whatever graph adj_t is in.
        adj = adj_t.detach().clone().requires_grad_(True)
        # enable_grad: still works when the caller is inside torch.no_grad().
        with torch.enable_grad():
            logits = task.model(bw_t, adj)
            p_asd = torch.softmax(logits, -1)[0, 1]
            # autograd.grad (unlike backward) does not accumulate .grad into
            # the model parameters on every request.
            (grad,) = torch.autograd.grad(p_asd, adj)
        maps.append(grad[0].abs().cpu().numpy())

    sal = np.mean(maps, axis=0)
    return (sal + sal.T) / 2
|
|
|
|
def _saliency_figure(sal: np.ndarray, p_mean: float, n_models: int | None = None):
    """Render a saliency matrix as a two-panel PNG and return it as a PIL image.

    Left panel: heatmap of the entries of ``sal`` at or above its 95th
    percentile. Right panel: the ROIs with the largest row sums of ``sal`` (up
    to 20). Bars are red/green/amber using the same 0.6 / 0.4 p(ASD) thresholds
    as the verdict card.

    Args:
        sal: square (N, N) saliency matrix, e.g. from ``_compute_saliency``.
        p_mean: ensemble-mean p(ASD), used for the bar colour and the title.
        n_models: ensemble size shown in the title. Defaults to the number of
            models cached in ``_models`` (0 if nothing has been loaded yet).
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from PIL import Image

    if n_models is None:
        n_models = 0 if _models is None else len(_models)

    # Keep only the strongest 5% of edges for the heatmap.
    thresh = np.percentile(sal, 95)
    sal_top = np.where(sal >= thresh, sal, 0.0)

    # ROI importance = row sum. Cap at the number of ROIs so bar and tick
    # counts always match the data.
    roi_imp = sal.sum(1)
    top_k = min(20, len(roi_imp))
    top_rois = roi_imp.argsort()[-top_k:][::-1]

    verdict_color = (
        "#e63946" if p_mean > 0.6 else
        "#2dc653" if p_mean < 0.4 else
        "#f4a261"
    )

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(14, 5.5))
    # Always close the figure. pyplot keeps open figures alive, and the caller
    # swallows plotting errors, so a leak here would grow with every request.
    try:
        fig.patch.set_facecolor("#0d0d0d")

        # Left panel: edge-level heatmap.
        ax.set_facecolor("#111")
        im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
        ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
        ax.set_xlabel("ROI index", color="#777", fontsize=9)
        ax.set_ylabel("ROI index", color="#777", fontsize=9)
        ax.tick_params(colors="#555", labelsize=8)
        cb = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
        plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666")
        for spine in ax.spines.values():
            spine.set_color("#333")

        # Right panel: ROI ranking.
        ax2.set_facecolor("#111")
        ax2.barh(
            range(top_k), roi_imp[top_rois],
            color=verdict_color, alpha=0.75, edgecolor="none",
        )
        ax2.set_yticks(range(top_k))
        ax2.set_yticklabels([f"ROI {i:03d}" for i in top_rois], fontsize=8, color="#ccc")
        ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
        ax2.set_title(f"Top-{top_k} ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
        ax2.tick_params(colors="#555", labelsize=8)
        ax2.invert_yaxis()
        for spine in ("top", "right"):
            ax2.spines[spine].set_visible(False)
        for spine in ("bottom", "left"):
            ax2.spines[spine].set_color("#333")

        fig.suptitle(
            f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {n_models} LOSO models",
            color="#888", fontsize=10, y=1.02,
        )

        # Operate on this figure explicitly, not pyplot's global "current figure".
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
    finally:
        plt.close(fig)

    buf.seek(0)
    with Image.open(buf) as img:
        return img.copy()
|
|
| |
|
|
| def run_gcn(file_path: str | None): |
| if file_path is None: |
| return "", "", "", None |
|
|
| path = Path(file_path) |
| try: |
| if path.suffix == ".npz": |
| d = np.load(path, allow_pickle=True) |
| fc = d["mean_fc"].astype(np.float32) |
| fc = np.arctanh(np.clip(fc, -0.9999, 0.9999)) |
| adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32) |
| bw = d["bold_windows"].astype(np.float32) |
| if len(bw) >= _MAX_WINDOWS: |
| bw = bw[:_MAX_WINDOWS] |
| else: |
| bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)]) |
| bw_t = torch.FloatTensor(bw).unsqueeze(0) |
| adj_t = torch.FloatTensor(adj).unsqueeze(0) |
| else: |
| bold = np.loadtxt(path, dtype=np.float32) |
| if bold.ndim != 2 or bold.shape[1] != 200: |
| return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None |
| bw_t, adj_t = preprocess(bold) |
| except Exception as e: |
| return f"⚠️ Error loading file: {e}", "", "", None |
|
|
| models = get_models() |
|
|
| |
| per_model = [] |
| with torch.no_grad(): |
| for site, task in models: |
| logits = task(bw_t, adj_t) |
| p = torch.softmax(logits, -1)[0, 1].item() |
| per_model.append((site, p)) |
|
|
| p_mean = float(np.mean([p for _, p in per_model])) |
| consensus = sum(1 for _, p in per_model if p > 0.5) |
| conf = max(p_mean, 1 - p_mean) * 100 |
|
|
| |
| try: |
| sal = _compute_saliency(bw_t, adj_t, models) |
| sal_img = _saliency_figure(sal, p_mean) |
| except Exception: |
| sal_img = None |
|
|
| |
| if p_mean > 0.6: |
| verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px"> |
| <div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div> |
| <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | <b style="color:white">{consensus}/4</b> site-blind models agree</div> |
| </div>""" |
| elif p_mean < 0.4: |
| verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #2dc653;padding:24px 28px;border-radius:12px;margin-bottom:8px"> |
| <div style="font-size:2rem;font-weight:800;color:#2dc653;letter-spacing:1px">TYPICAL CONTROL</div> |
| <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | <b style="color:white">{4-consensus}/4</b> site-blind models agree</div> |
| </div>""" |
| else: |
| verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #f4a261;padding:24px 28px;border-radius:12px;margin-bottom:8px"> |
| <div style="font-size:2rem;font-weight:800;color:#f4a261;letter-spacing:1px">INCONCLUSIVE</div> |
| <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | Model disagreement — clinical review required</div> |
| </div>""" |
|
|
| |
| rows = "" |
| for site, p in per_model: |
| lbl = "ASD" if p > 0.5 else "TC" |
| color = "#e63946" if p > 0.5 else "#2dc653" |
| bar_w = int(p * 100) |
| rows += f"""<tr> |
| <td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td> |
| <td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px"> |
| <div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td> |
| <td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td> |
| <td style="padding:8px 12px;color:#888">p={p:.3f}</td> |
| </tr>""" |
|
|
| ensemble = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px"> |
| <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div> |
| <table style="width:100%;border-collapse:collapse">{rows}</table> |
| <div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {consensus}/4 models agree · LOSO AUC = 0.7872 across 529 held-out subjects</div> |
| </div>""" |
|
|
| |
| if p_mean > 0.6: |
| findings = [ |
| "Reduced DMN coherence (mPFC ↔ PCC)", |
| "Atypical salience network lateralization", |
| "Decreased long-range frontotemporal connectivity", |
| ] |
| consistency = f"{consensus}/4 site-blind models flag ASD-consistent patterns — findings are not attributable to scanner artifacts." |
| impression = f"Connectivity profile consistent with ASD ({conf:.1f}% confidence)." |
| elif p_mean < 0.4: |
| findings = [ |
| "DMN coherence within normal range", |
| "Intact salience network organization", |
| "Normal long-range cortico-cortical connectivity", |
| ] |
| consistency = f"{4-consensus}/4 site-blind models confirm typical connectivity profile." |
| impression = f"Connectivity profile within typical range ({conf:.1f}% confidence)." |
| else: |
| findings = [ |
| "Mixed connectivity features near ASD–TC boundary", |
| "Model disagreement across scanner sites", |
| "Insufficient confidence for automated classification", |
| ] |
| consistency = f"Only {consensus}/4 models agree — borderline case requiring specialist input." |
| impression = "Inconclusive. Full neuropsychological evaluation recommended (ADOS-2, ADI-R)." |
|
|
| fi = "".join(f"<li style='margin:6px 0;color:#ccc'>{f}</li>" for f in findings) |
| report = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px"> |
| <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Connectivity Summary</div> |
| <div style="color:#eee;font-size:1rem;margin-bottom:16px"><b>Impression:</b> {impression}</div> |
| <div style="color:#aaa;font-size:0.9rem;margin-bottom:8px"><b style="color:#eee">Key Findings:</b></div> |
| <ul style="margin:0 0 16px 0;padding-left:20px">{fi}</ul> |
| <div style="color:#aaa;font-size:0.9rem;margin-bottom:16px"><b style="color:#eee">Cross-Site Consistency:</b> {consistency}</div> |
| <div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#666;font-size:0.8rem"> |
| ⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments.<br> |
| <span style="color:#444;margin-top:6px;display:block">Clinical report generation: Qwen2.5-7B fine-tuned on AMD Instinct MI300X (coming soon)</span> |
| </div></div>""" |
|
|
| return verdict, ensemble, report, sal_img |
|
|
|
|
| |
|
|
css = """
body { background: #0d0d0d; }
.gradio-container { max-width: 960px; margin: auto; }
"""


with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
    gr.HTML("""
<div style="text-align:center;padding:32px 0 16px">
<div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
<div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
<div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
</div>
</div>
""")

    file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")

    verdict_html = gr.HTML()
    ensemble_html = gr.HTML()

    gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
    saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")

    # The report sits below the saliency image. (A second, never-wired report
    # component used to sit in an empty gr.Row above; it has been removed.)
    report_html = gr.HTML()

    # run_gcn returns (verdict, ensemble table, report, saliency image) in this order.
    file_input.change(
        fn=run_gcn,
        inputs=file_input,
        outputs=[verdict_html, ensemble_html, report_html, saliency_img],
    )

    gr.HTML("""
<div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
Adversarial Brain-Mode GCN (k=16) · ABIDE I (1,102 subjects, 17 sites) ·
<a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
</div>
""")


# Load checkpoints at import time so the first request does not pay for it.
print("Preloading models...")
get_models()
print("Models ready.")


if __name__ == "__main__":
    demo.launch()
|
|