""" BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI. """ from __future__ import annotations import io from pathlib import Path import numpy as np import torch import gradio as gr _WINDOW_LEN = 50 _STEP = 3 _MAX_WINDOWS = 30 _FC_THRESHOLD = 0.2 _CKPTS = { "NYU": Path("checkpoints/nyu.ckpt"), "USM": Path("checkpoints/usm.ckpt"), "UCLA": Path("checkpoints/ucla.ckpt"), "UM": Path("checkpoints/um.ckpt"), } # ── preprocessing ────────────────────────────────────────────────────────── def _zscore(bold): mean = bold.mean(0, keepdims=True) std = bold.std(0, keepdims=True) std[std < 1e-8] = 1.0 return ((bold - mean) / std).astype(np.float32) def _fc(bold): fc = np.corrcoef(bold.T).astype(np.float32) np.nan_to_num(fc, copy=False) return fc def _windows(bold): T, N = bold.shape starts = list(range(0, T - _WINDOW_LEN + 1, _STEP)) w = np.stack([bold[s:s+_WINDOW_LEN].std(0) for s in starts]).astype(np.float32) if len(w) >= _MAX_WINDOWS: return w[:_MAX_WINDOWS] return np.concatenate([w, np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)]) def preprocess(bold): bold = _zscore(bold) fc = _fc(bold) fc = np.arctanh(np.clip(fc, -0.9999, 0.9999)) adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32) bw = _windows(bold) return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0) # ── model loading ────────────────────────────────────────────────────────── _models: list | None = None def get_models(): global _models if _models is not None: return _models from brain_gcn.tasks import ClassificationTask _models = [] for site, ckpt in _CKPTS.items(): if not ckpt.exists(): continue task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False) task.eval() _models.append((site, task)) return _models # ── gradient saliency ────────────────────────────────────────────────────── def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray: """Gradient of p(ASD) w.r.t. adjacency matrix, averaged over ensemble.""" maps = [] for _, task in models: adj = adj_t.clone().requires_grad_(True) logits = task.model(bw_t, adj) p = torch.softmax(logits, -1)[0, 1] p.backward() maps.append(adj.grad[0].abs().detach().numpy()) sal = np.mean(maps, axis=0) # (200, 200) sal = (sal + sal.T) / 2 # symmetrize return sal def _saliency_figure(sal: np.ndarray, p_mean: float): import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from PIL import Image thresh = np.percentile(sal, 95) sal_top = np.where(sal >= thresh, sal, 0.0) roi_imp = sal.sum(1) top20 = roi_imp.argsort()[-20:][::-1] verdict_color = ( "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261" ) fig, axes = plt.subplots(1, 2, figsize=(14, 5.5)) fig.patch.set_facecolor("#0d0d0d") # ── Left: FC edge saliency heatmap ── ax = axes[0] ax.set_facecolor("#111") im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest") ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10) ax.set_xlabel("ROI index", color="#777", fontsize=9) ax.set_ylabel("ROI index", color="#777", fontsize=9) ax.tick_params(colors="#555", labelsize=8) cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) cb.ax.yaxis.set_tick_params(color="#555", labelsize=7) plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666") for spine in ax.spines.values(): spine.set_color("#333") # ── Right: top-20 ROI importance bar chart ── ax2 = axes[1] ax2.set_facecolor("#111") ax2.barh( range(20), roi_imp[top20], color=verdict_color, alpha=0.75, edgecolor="none", ) ax2.set_yticks(range(20)) ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc") ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9) ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10) ax2.tick_params(colors="#555", labelsize=8) ax2.invert_yaxis() for spine in ["top", "right"]: ax2.spines[spine].set_visible(False) for spine in ["bottom", "left"]: ax2.spines[spine].set_color("#333") fig.suptitle( f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {len(_models)} LOSO models", color="#888", fontsize=10, y=1.02, ) plt.tight_layout() buf = io.BytesIO() plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d") plt.close(fig) buf.seek(0) return Image.open(buf).copy() # ── inference ────────────────────────────────────────────────────────────── def run_gcn(file_path: str | None): if file_path is None: return "", "", "", None path = Path(file_path) try: if path.suffix == ".npz": d = np.load(path, allow_pickle=True) fc = d["mean_fc"].astype(np.float32) fc = np.arctanh(np.clip(fc, -0.9999, 0.9999)) adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32) bw = d["bold_windows"].astype(np.float32) if len(bw) >= _MAX_WINDOWS: bw = bw[:_MAX_WINDOWS] else: bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)]) bw_t = torch.FloatTensor(bw).unsqueeze(0) adj_t = torch.FloatTensor(adj).unsqueeze(0) else: bold = np.loadtxt(path, dtype=np.float32) if bold.ndim != 2 or bold.shape[1] != 200: return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None bw_t, adj_t = preprocess(bold) except Exception as e: return f"⚠️ Error loading file: {e}", "", "", None models = get_models() # ── Ensemble inference (no grad) ── per_model = [] with torch.no_grad(): for site, task in models: logits = task(bw_t, adj_t) p = torch.softmax(logits, -1)[0, 1].item() per_model.append((site, p)) p_mean = float(np.mean([p for _, p in per_model])) consensus = sum(1 for _, p in per_model if p > 0.5) conf = max(p_mean, 1 - p_mean) * 100 # ── Gradient saliency ── try: sal = _compute_saliency(bw_t, adj_t, models) sal_img = _saliency_figure(sal, p_mean) except Exception: sal_img = None # ── Verdict card ── if p_mean > 0.6: verdict = f"""