# Provenance: uploaded by Yatsuiii via huggingface_hub ("Upload app.py with huggingface_hub", commit 6526502, verified, 15 kB).
"""
BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
"""
from __future__ import annotations

import html
import io
import logging
from pathlib import Path

import numpy as np
import torch
import gradio as gr
_WINDOW_LEN = 50
_STEP = 3
_MAX_WINDOWS = 30
_FC_THRESHOLD = 0.2
_CKPTS = {
"NYU": Path("checkpoints/nyu.ckpt"),
"USM": Path("checkpoints/usm.ckpt"),
"UCLA": Path("checkpoints/ucla.ckpt"),
"UM": Path("checkpoints/um.ckpt"),
}
# ── preprocessing ──────────────────────────────────────────────────────────
def _zscore(bold):
mean = bold.mean(0, keepdims=True)
std = bold.std(0, keepdims=True)
std[std < 1e-8] = 1.0
return ((bold - mean) / std).astype(np.float32)
def _fc(bold):
fc = np.corrcoef(bold.T).astype(np.float32)
np.nan_to_num(fc, copy=False)
return fc
def _windows(bold):
    """Summarise *bold* as per-window, per-column standard deviations.

    Slides a ``_WINDOW_LEN``-sample window along axis 0 with stride ``_STEP``
    and takes the std of every column inside each window. The result always
    has exactly ``_MAX_WINDOWS`` rows: extra windows are dropped, and short
    scans are padded by repeating the last window.

    Returns:
        float32 array of shape ``(_MAX_WINDOWS, bold.shape[1])``.

    Raises:
        ValueError: if *bold* has fewer than ``_WINDOW_LEN`` rows, so not even
            one full window fits (previously an opaque ``np.stack`` error).
    """
    n_timepoints = bold.shape[0]
    if n_timepoints < _WINDOW_LEN:
        raise ValueError(
            f"scan too short: {n_timepoints} time points, "
            f"need at least {_WINDOW_LEN}"
        )
    # Only compute the windows that will be kept; the rest were discarded anyway.
    starts = range(0, n_timepoints - _WINDOW_LEN + 1, _STEP)[:_MAX_WINDOWS]
    w = np.stack([bold[s:s + _WINDOW_LEN].std(0) for s in starts]).astype(np.float32)
    missing = _MAX_WINDOWS - len(w)
    if missing > 0:
        # Edge-pad with copies of the final window.
        w = np.concatenate([w, np.repeat(w[-1:], missing, axis=0)])
    return w
def preprocess(bold):
    """Convert a raw (time x ROI) BOLD array into the model's two inputs.

    Returns ``(windows, adjacency)`` as float32 tensors with a leading batch
    dimension of 1:

    * windows: per-window std features of the z-scored signal (``_windows``).
    * adjacency: Fisher-z transformed correlation matrix of the z-scored
      signal, with entries whose magnitude is below ``_FC_THRESHOLD`` set to 0.
    """
    standardized = _zscore(bold)
    # Clip just inside (-1, 1) so arctanh stays finite for perfect correlations.
    fisher_z = np.arctanh(np.clip(_fc(standardized), -0.9999, 0.9999))
    keep = np.abs(fisher_z) >= _FC_THRESHOLD
    adjacency = np.where(keep, fisher_z, 0.0).astype(np.float32)
    windows = _windows(standardized)
    return (
        torch.FloatTensor(windows).unsqueeze(0),
        torch.FloatTensor(adjacency).unsqueeze(0),
    )
# ── model loading ──────────────────────────────────────────────────────────
_models: list | None = None  # cache of (site, task) pairs; None until first load


def get_models():
    """Load (once) and return the leave-one-site-out ensemble.

    Returns a list of ``(site, task)`` pairs, one per checkpoint in ``_CKPTS``
    that exists on disk. Missing checkpoints are skipped, so the list may be
    shorter than ``_CKPTS`` or even empty. Each task is switched to eval mode.
    The list is cached in the module-level ``_models`` and reused afterwards.
    """
    global _models
    if _models is not None:
        return _models
    # Deferred import: the project package is only needed here.
    from brain_gcn.tasks import ClassificationTask

    loaded = []
    for site, ckpt in _CKPTS.items():
        if not ckpt.exists():
            continue
        task = ClassificationTask.load_from_checkpoint(
            str(ckpt), map_location="cpu", strict=False
        )
        task.eval()
        loaded.append((site, task))
    # Publish only a fully built list. If a checkpoint raises while loading,
    # no partial ensemble gets cached and served on every later call.
    _models = loaded
    return _models
# ── gradient saliency ──────────────────────────────────────────────────────
def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
"""Gradient of p(ASD) w.r.t. adjacency matrix, averaged over ensemble."""
maps = []
for _, task in models:
adj = adj_t.clone().requires_grad_(True)
logits = task.model(bw_t, adj)
p = torch.softmax(logits, -1)[0, 1]
p.backward()
maps.append(adj.grad[0].abs().detach().numpy())
sal = np.mean(maps, axis=0) # (200, 200)
sal = (sal + sal.T) / 2 # symmetrize
return sal
def _saliency_figure(sal: np.ndarray, p_mean: float, n_models: int | None = None):
    """Render the saliency map and ROI ranking as a PIL image.

    Left panel: *sal* with every entry below its 95th percentile set to 0.
    Right panel: the (up to) 20 ROIs with the largest row sums of *sal*,
    coloured red if ``p_mean > 0.6``, green if ``p_mean < 0.4``, amber
    otherwise.

    Args:
        sal: square saliency matrix, e.g. from ``_compute_saliency``.
        p_mean: ensemble-mean p(ASD), shown in the title and used for colour.
        n_models: ensemble size shown in the title. Defaults to the number of
            models cached in the module-level ``_models`` (0 if none loaded).

    Builds a standalone ``Figure`` with an Agg canvas instead of using
    ``pyplot``, so no global figure state is shared between concurrent
    Gradio requests.
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure
    from PIL import Image

    if n_models is None:
        n_models = len(_models or [])

    thresh = np.percentile(sal, 95)
    sal_top = np.where(sal >= thresh, sal, 0.0)  # keep the strongest 5% of edges
    roi_imp = sal.sum(1)  # per-ROI total saliency (row sums)
    k = min(20, len(roi_imp))  # inputs with < 20 ROIs still render
    top = roi_imp.argsort()[::-1][:k]  # indices in descending importance
    verdict_color = (
        "#e63946" if p_mean > 0.6 else
        "#2dc653" if p_mean < 0.4 else
        "#f4a261"
    )

    fig = Figure(figsize=(14, 5.5))
    FigureCanvasAgg(fig)  # attach an Agg canvas for layout and PNG output
    fig.patch.set_facecolor("#0d0d0d")
    ax, ax2 = fig.subplots(1, 2)

    # ── Left: FC edge saliency heatmap ──
    ax.set_facecolor("#111")
    im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
    ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
    ax.set_xlabel("ROI index", color="#777", fontsize=9)
    ax.set_ylabel("ROI index", color="#777", fontsize=9)
    ax.tick_params(colors="#555", labelsize=8)
    cb = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
    for tick_label in cb.ax.yaxis.get_ticklabels():
        tick_label.set_color("#666")
    for spine in ax.spines.values():
        spine.set_color("#333")

    # ── Right: top-k ROI importance bar chart ──
    ax2.set_facecolor("#111")
    ax2.barh(
        range(k), roi_imp[top],
        color=verdict_color, alpha=0.75, edgecolor="none",
    )
    ax2.set_yticks(range(k))
    ax2.set_yticklabels([f"ROI {i:03d}" for i in top], fontsize=8, color="#ccc")
    ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
    ax2.set_title(f"Top-{k} ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
    ax2.tick_params(colors="#555", labelsize=8)
    ax2.invert_yaxis()
    for spine in ["top", "right"]:
        ax2.spines[spine].set_visible(False)
    for spine in ["bottom", "left"]:
        ax2.spines[spine].set_color("#333")

    fig.suptitle(
        f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {n_models} LOSO models",
        color="#888", fontsize=10, y=1.02,
    )
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
    buf.seek(0)
    # .copy() forces the PNG to be decoded now, detaching it from the buffer.
    return Image.open(buf).copy()
# ── inference ──────────────────────────────────────────────────────────────
def _load_npz(path: Path):
    """Load a precomputed ``.npz`` subject into ``(windows, adjacency)`` tensors.

    Reads the ``mean_fc`` and ``bold_windows`` arrays. The FC matrix gets the
    same Fisher-z + ``_FC_THRESHOLD`` treatment as ``preprocess``. The
    windows are truncated, or padded by repeating the last one, to
    ``_MAX_WINDOWS`` rows. Both tensors get a leading batch dimension of 1.

    ``allow_pickle=False`` because the file is a user upload: unpickling an
    untrusted object array can execute arbitrary code.

    Raises:
        ValueError: non-square ``mean_fc``, empty ``bold_windows``, or a
            pickled object array in the archive.
        KeyError: if either array is missing from the archive.
    """
    with np.load(path, allow_pickle=False) as data:
        fc = data["mean_fc"].astype(np.float32)
        bw = data["bold_windows"].astype(np.float32)
    if fc.ndim != 2 or fc.shape[0] != fc.shape[1]:
        raise ValueError(f"mean_fc must be a square matrix, got shape {fc.shape}")
    if len(bw) == 0:
        raise ValueError("bold_windows contains no windows")
    fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
    adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
    if len(bw) >= _MAX_WINDOWS:
        bw = bw[:_MAX_WINDOWS]
    else:
        bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)])
    return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)


def _ensemble_predict(models, bw_t, adj_t):
    """Return ``[(site, p), ...]``: each model's class-1 softmax probability."""
    with torch.no_grad():
        return [
            (site, torch.softmax(task(bw_t, adj_t), -1)[0, 1].item())
            for site, task in models
        ]


def _verdict_card(color: str, title: str, detail_html: str) -> str:
    """Render the coloured headline card shown at the top of the results."""
    return f"""<div style="background:#1a1a2e;border-left:6px solid {color};padding:24px 28px;border-radius:12px;margin-bottom:8px">
<div style="font-size:2rem;font-weight:800;color:{color};letter-spacing:1px">{title}</div>
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">{detail_html}</div>
</div>"""


def _ensemble_table(per_model, n_majority: int) -> str:
    """Render the per-site probability bars plus the cross-site consensus line."""
    rows = []
    for site, p in per_model:
        lbl = "ASD" if p > 0.5 else "TC"
        color = "#e63946" if p > 0.5 else "#2dc653"
        bar_w = int(p * 100)
        rows.append(f"""<tr>
<td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td>
<td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px">
<div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td>
<td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td>
<td style="padding:8px 12px;color:#888">p={p:.3f}</td>
</tr>""")
    rows_html = "".join(rows)
    n = len(per_model)
    return f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
<div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div>
<table style="width:100%;border-collapse:collapse">{rows_html}</table>
<div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {n_majority}/{n} models agree &nbsp;·&nbsp; LOSO AUC = 0.7872 across 529 held-out subjects</div>
</div>"""


def _clinical_report(band: str, conf: float, n_asd: int, n: int) -> str:
    """Render the clinical-summary card for a p(ASD) band ("asd"/"tc"/"mixed")."""
    # NOTE(review): the findings are fixed templates selected only by the
    # p(ASD) band; they are not computed from the uploaded subject's data.
    if band == "asd":
        findings = [
            "Reduced DMN coherence (mPFC ↔ PCC)",
            "Atypical salience network lateralization",
            "Decreased long-range frontotemporal connectivity",
        ]
        consistency = f"{n_asd}/{n} site-blind models flag ASD-consistent patterns — findings are not attributable to scanner artifacts."
        impression = f"Connectivity profile consistent with ASD ({conf:.1f}% confidence)."
    elif band == "tc":
        findings = [
            "DMN coherence within normal range",
            "Intact salience network organization",
            "Normal long-range cortico-cortical connectivity",
        ]
        consistency = f"{n - n_asd}/{n} site-blind models confirm typical connectivity profile."
        impression = f"Connectivity profile within typical range ({conf:.1f}% confidence)."
    else:
        findings = [
            "Mixed connectivity features near ASD–TC boundary",
            "Model disagreement across scanner sites",
            "Insufficient confidence for automated classification",
        ]
        consistency = f"{n_asd}/{n} models lean towards ASD — borderline case requiring specialist input."
        impression = "Inconclusive. Full neuropsychological evaluation recommended (ADOS-2, ADI-R)."
    fi = "".join(f"<li style='margin:6px 0;color:#ccc'>{f}</li>" for f in findings)
    return f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
<div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Connectivity Summary</div>
<div style="color:#eee;font-size:1rem;margin-bottom:16px"><b>Impression:</b> {impression}</div>
<div style="color:#aaa;font-size:0.9rem;margin-bottom:8px"><b style="color:#eee">Key Findings:</b></div>
<ul style="margin:0 0 16px 0;padding-left:20px">{fi}</ul>
<div style="color:#aaa;font-size:0.9rem;margin-bottom:16px"><b style="color:#eee">Cross-Site Consistency:</b> {consistency}</div>
<div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#666;font-size:0.8rem">
⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments.<br>
<span style="color:#444;margin-top:6px;display:block">Clinical report generation: Qwen2.5-7B fine-tuned on AMD Instinct MI300X (coming soon)</span>
</div></div>"""


def run_gcn(file_path: str | None):
    """Gradio callback: classify one uploaded CC200 scan with the LOSO ensemble.

    Files ending in ``.npz`` (any case) must hold ``mean_fc`` / ``bold_windows``
    arrays. Any other file is parsed with ``np.loadtxt`` as a (T x 200) BOLD
    matrix.

    Returns:
        ``(verdict_html, ensemble_html, report_html, saliency_image)``. On
        failure the first element is an error message and the others are
        empty. The image is ``None`` whenever saliency could not be rendered.
    """
    if file_path is None:
        return "", "", "", None
    path = Path(file_path)
    try:
        if path.suffix.lower() == ".npz":
            bw_t, adj_t = _load_npz(path)
        else:
            bold = np.loadtxt(path, dtype=np.float32)
            if bold.ndim != 2 or bold.shape[1] != 200:
                return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None
            bw_t, adj_t = preprocess(bold)
    except Exception as e:
        # Exception text can echo uploaded file contents; escape it because
        # the message is rendered as HTML.
        return f"⚠️ Error loading file: {html.escape(str(e))}", "", "", None

    models = get_models()
    if not models:
        return "⚠️ Error: no model checkpoints are available.", "", "", None
    try:
        per_model = _ensemble_predict(models, bw_t, adj_t)
    except Exception as e:
        return f"⚠️ Error running models: {html.escape(str(e))}", "", "", None

    n = len(per_model)  # may be < len(_CKPTS) when checkpoints are missing
    p_mean = float(np.mean([p for _, p in per_model]))
    n_asd = sum(1 for _, p in per_model if p > 0.5)  # models voting ASD
    n_majority = max(n_asd, n - n_asd)  # size of the larger voting bloc
    conf = max(p_mean, 1 - p_mean) * 100
    band = "asd" if p_mean > 0.6 else "tc" if p_mean < 0.4 else "mixed"

    # Saliency is a best-effort extra: log failures, still return the results.
    try:
        sal = _compute_saliency(bw_t, adj_t, models)
        sal_img = _saliency_figure(sal, p_mean)
    except Exception:
        logging.getLogger(__name__).exception("Saliency rendering failed")
        sal_img = None

    stats = (
        f'Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; '
        f'p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; '
    )
    if band == "asd":
        verdict = _verdict_card(
            "#e63946", "ASD INDICATED",
            stats + f'<b style="color:white">{n_asd}/{n}</b> site-blind models agree',
        )
    elif band == "tc":
        verdict = _verdict_card(
            "#2dc653", "TYPICAL CONTROL",
            stats + f'<b style="color:white">{n - n_asd}/{n}</b> site-blind models agree',
        )
    else:
        verdict = _verdict_card(
            "#f4a261", "INCONCLUSIVE",
            stats + "Model disagreement — clinical review required",
        )

    ensemble = _ensemble_table(per_model, n_majority)
    report = _clinical_report(band, conf, n_asd, n)
    return verdict, ensemble, report, sal_img
# ── UI ─────────────────────────────────────────────────────────────────────
css = """
body { background: #0d0d0d; }
.gradio-container { max-width: 960px; margin: auto; }
"""

with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
    gr.HTML("""
<div style="text-align:center;padding:32px 0 16px">
<div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
<div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
<div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
</div>
</div>
""")
    file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
    # Output components, in the order run_gcn returns them (the image is
    # placed on the page before the report).
    verdict_html = gr.HTML()
    ensemble_html = gr.HTML()
    gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
    saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")
    # The previous unwired, always-empty report component inside a gr.Row was
    # dead UI and has been removed; this is the one run_gcn fills.
    report_html = gr.HTML()
    file_input.change(
        fn=run_gcn,
        inputs=file_input,
        outputs=[verdict_html, ensemble_html, report_html, saliency_img],
    )
    gr.HTML("""
<div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
Adversarial Brain-Mode GCN (k=16) · ABIDE I (1,102 subjects, 17 sites) ·
<a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
</div>
""")
# Load the ensemble at import time so the first request does not pay for it.
print("Preloading models...")
_loaded = get_models()
if _loaded:
    print(f"Models ready ({len(_loaded)}/{len(_CKPTS)} checkpoints loaded).")
else:
    # Previously this still printed "Models ready." with zero models loaded.
    print(
        "WARNING: no model checkpoints found (looked for: "
        + ", ".join(str(p) for p in _CKPTS.values())
        + "); predictions cannot be made."
    )

if __name__ == "__main__":
    demo.launch()