Yatsuiii's picture
Upload app.py with huggingface_hub
3d9e4e3 verified
raw
history blame
42.2 kB
"""
BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
"""
from __future__ import annotations
import io
from pathlib import Path
import numpy as np
import torch
import gradio as gr
from _charts import VAL_B64, AUC_B64, AMD_BENCH_B64
# Sliding-window settings used by _windows(): window length and stride, both
# in rows (timepoints) of the BOLD matrix, and the fixed number of windows the
# models receive (longer scans are truncated, shorter ones padded).
_WINDOW_LEN = 50
_STEP = 3
_MAX_WINDOWS = 30
# Adjacency threshold: entries whose Fisher-z transformed correlation has
# |z| below this value are set to 0 (see preprocess / run_gcn).
_FC_THRESHOLD = 0.2
# CC200 atlas (Craddock 2012) → approximate Yeo 7-network parcellation
# Network k covers ROI indices _NET_BOUNDS[k]:_NET_BOUNDS[k+1] and is drawn
# with _NET_COLORS[k] in the saliency figure.
# NOTE(review): this treats every network as a contiguous block of CC200 ROI
# indices, which only holds if the atlas ROIs are ordered by network. Nothing
# here establishes that ordering — confirm against a real CC200→Yeo-7 overlap
# table before interpreting the per-network saliency.
_NET_NAMES = ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"]
_NET_BOUNDS = [0, 38, 69, 99, 137, 165, 180, 200]
_NET_COLORS = ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"]
# One checkpoint per ensemble member, keyed by the site shown as "<site>-blind"
# in the UI; checkpoints whose file is missing are skipped by get_models().
_CKPTS = {
"NYU": Path("checkpoints/nyu.ckpt"),
"USM": Path("checkpoints/usm.ckpt"),
"UCLA": Path("checkpoints/ucla.ckpt"),
"UM": Path("checkpoints/um.ckpt"),
}
# ── preprocessing ──────────────────────────────────────────────────────────
def _zscore(bold):
mean = bold.mean(0, keepdims=True)
std = bold.std(0, keepdims=True)
std[std < 1e-8] = 1.0
return ((bold - mean) / std).astype(np.float32)
def _fc(bold):
    """Return the column-by-column Pearson correlation matrix as float32.

    Columns are treated as variables (ROIs) and rows as observations (time).
    NaN entries are replaced with 0 in place, e.g. those produced for
    constant columns.
    """
    corr = np.corrcoef(np.transpose(bold))
    corr = corr.astype(np.float32)
    return np.nan_to_num(corr, copy=False)
def _windows(bold):
    """Per-column standard deviation over sliding windows of the BOLD signal.

    Windows are ``_WINDOW_LEN`` rows long and start every ``_STEP`` rows.
    Only the first ``_MAX_WINDOWS`` windows are kept. If there are fewer,
    the last window's row is repeated until there are ``_MAX_WINDOWS`` rows.

    Args:
        bold: 2-D array; rows are timepoints and columns are ROIs.

    Returns:
        np.ndarray: float32 array of shape ``(_MAX_WINDOWS, bold.shape[1])``.

    Raises:
        ValueError: if ``bold`` is not 2-D, or has fewer than ``_WINDOW_LEN``
            rows (so not even one complete window exists).
    """
    n_time, _ = bold.shape
    if n_time < _WINDOW_LEN:
        # Without this guard np.stack([]) fails with the opaque
        # "need at least one array to stack".
        raise ValueError(
            f"scan too short: need at least {_WINDOW_LEN} timepoints, got {n_time}"
        )
    # Windows beyond _MAX_WINDOWS would be discarded, so don't compute them.
    starts = range(0, n_time - _WINDOW_LEN + 1, _STEP)[:_MAX_WINDOWS]
    w = np.stack([bold[s:s + _WINDOW_LEN].std(0) for s in starts]).astype(np.float32)
    if len(w) < _MAX_WINDOWS:
        # Pad short scans by repeating the last window's features.
        w = np.concatenate([w, np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)])
    return w
def preprocess(bold):
    """Turn a raw BOLD matrix into the two model inputs.

    Pipeline:
    1. z-score every column over time (``_zscore``).
    2. Compute the column correlation matrix (``_fc``).
    3. Fisher-z transform it with ``arctanh``, after clipping to ±0.9999 so
       the diagonal stays finite.
    4. Zero entries with ``|z| < _FC_THRESHOLD``.
    5. Summarise dynamics with ``_windows``.

    Args:
        bold: 2-D array; rows are timepoints and columns are ROIs.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(window_features, adjacency)``,
        each with a leading batch axis of size 1.
    """
    standardized = _zscore(bold)
    fisher_z = np.arctanh(np.clip(_fc(standardized), -0.9999, 0.9999))
    keep = np.abs(fisher_z) >= _FC_THRESHOLD
    adjacency = np.where(keep, fisher_z, 0.0).astype(np.float32)
    window_features = _windows(standardized)
    return (
        torch.FloatTensor(window_features).unsqueeze(0),
        torch.FloatTensor(adjacency).unsqueeze(0),
    )
# ── model loading ──────────────────────────────────────────────────────────
# Lazily populated cache of (site, task) pairs; filled by get_models().
_models = None


def get_models():
    """Load (once) and return the leave-one-site-out ensemble.

    Returns:
        list[tuple[str, ClassificationTask]]: one ``(site, task)`` pair for
        every checkpoint in ``_CKPTS`` whose file exists. Each task is loaded
        on CPU and put in eval mode. The list may be empty if no checkpoint
        files exist.

    The list is cached in the module-level ``_models`` and returned by later
    calls. It is published only after every checkpoint has loaded, so a
    failure part-way through propagates and the next call retries. Previously
    a partial ensemble was cached and silently reused on every later request.
    """
    global _models
    if _models is not None:
        return _models
    # Imported here rather than at module top, so this module can be imported
    # without brain_gcn until models are first requested.
    from brain_gcn.tasks import ClassificationTask

    loaded = []
    for site, ckpt in _CKPTS.items():
        if not ckpt.exists():
            continue
        # NOTE(review): strict=False tolerates state-dict key mismatches, which
        # also hides genuinely missing weights — confirm this is intended.
        task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
        task.eval()
        loaded.append((site, task))
    _models = loaded
    return _models
# ── gradient saliency ──────────────────────────────────────────────────────
def _compute_saliency(bw_t, adj_t, models):
    """Gradient saliency of the ensemble's class-1 probability w.r.t. adjacency.

    For each ``(site, task)`` in ``models``, the absolute gradient of
    ``softmax(task.model(bw_t, adj))[0, 1]`` (the p(ASD) used throughout this
    app) is taken with respect to a fresh copy of ``adj_t``. The per-model
    maps are averaged and then symmetrised.

    Args:
        bw_t: window-feature tensor with a leading batch axis of size 1.
        adj_t: adjacency tensor with a leading batch axis of size 1; it is
            not modified.
        models: sequence of ``(site, task)`` pairs where
            ``task.model(bw, adj)`` returns logits indexable as ``[0, 1]``.

    Returns:
        np.ndarray: symmetric saliency matrix shaped like ``adj_t[0]``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    maps = []
    # Gradients are required even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        for _, task in models:
            adj = adj_t.detach().clone().requires_grad_(True)
            p_asd = torch.softmax(task.model(bw_t, adj), -1)[0, 1]
            # autograd.grad returns d p / d adj only. backward() would also
            # write into (and keep accumulating) the .grad buffers of every
            # model parameter that requires grad, on every request.
            (grad,) = torch.autograd.grad(p_asd, adj)
            maps.append(grad[0].abs().detach().cpu().numpy())
    if not maps:
        raise ValueError("no models available to compute saliency")
    sal = np.mean(maps, axis=0)
    return (sal + sal.T) / 2
def _saliency_figure(sal, p_mean, n_models=None):
    """Render the two-panel saliency figure and return it as a PIL image.

    Left panel: ``sal`` averaged block-wise into a network × network heatmap,
    using the ROI ranges in ``_NET_BOUNDS``. The top-5 cross-network edges
    are outlined (both mirror cells) and the top-3 get callout labels.
    Right panel: per-network importance (mean row saliency + mean column
    saliency), sorted in descending order.

    Args:
        sal: square ROI-level saliency matrix (``_compute_saliency`` returns it
            symmetrised).
        p_mean: ensemble-mean p(ASD), shown in the title.
        n_models: ensemble size shown in the title. Defaults to the size of the
            cached ``_models`` list, or 0 if models were never loaded.

    Returns:
        PIL.Image.Image: the rendered PNG.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from PIL import Image

    if n_models is None:
        # Previously len(_models) raised TypeError if models were never loaded.
        n_models = len(_models or [])

    n_nets = len(_NET_NAMES)
    spans = list(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:]))

    # Aggregate ROI-level saliency → network-level matrix (block means)
    net_sal = np.zeros((n_nets, n_nets))
    for i, (s1, e1) in enumerate(spans):
        for j, (s2, e2) in enumerate(spans):
            net_sal[i, j] = sal[s1:e1, s2:e2].mean()
    # Network importance: mean outgoing + incoming saliency per network
    net_imp = np.array([sal[s:e, :].mean() + sal[:, s:e].mean() for s, e in spans])

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    try:
        fig.patch.set_facecolor("#0d0d0d")

        # ── Left: network × network heatmap ────────────────────────────────
        ax = axes[0]
        ax.set_facecolor("#161922")
        im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
        ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=11, pad=14, fontweight="bold")
        ax.set_xticks(range(n_nets))
        ax.set_yticks(range(n_nets))
        ax.set_xticklabels(_NET_NAMES, rotation=40, ha="right", fontsize=9, color="#ccc")
        ax.set_yticklabels(_NET_NAMES, fontsize=9, color="#ccc")
        ax.tick_params(colors="#555", length=0)
        for sp in ax.spines.values():
            sp.set_color("#222")
        # Boundary lines between networks
        for k in range(1, n_nets):
            ax.axhline(k - 0.5, color="#2a2a2a", lw=1.0)
            ax.axvline(k - 0.5, color="#2a2a2a", lw=1.0)

        # Rank undirected network pairs (i < j). The matrix is symmetric, so
        # ranking every off-diagonal cell listed each edge twice ((i, j) and
        # (j, i)). That repeated callouts and outlined mirror cells unevenly.
        vmax = net_sal.max()
        pair_scores = sorted(
            (((net_sal[i, j] + net_sal[j, i]) / 2, i, j)
             for i in range(n_nets) for j in range(i + 1, n_nets)),
            reverse=True,
        )
        top5_cells = {cell for _, i, j in pair_scores[:5] for cell in ((i, j), (j, i))}
        top3_edges = pair_scores[:3]

        # Annotate each cell with its value; outline the top-5 edges in white
        for i in range(n_nets):
            for j in range(n_nets):
                txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#666"
                ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
                        fontsize=6.5, color=txt_color, zorder=3)
                if (i, j) in top5_cells:
                    ax.add_patch(plt.Rectangle((j - 0.48, i - 0.48), 0.96, 0.96,
                                               linewidth=1.8, edgecolor="#ffffff",
                                               facecolor="none", zorder=4))
        # Callout labels for the top-3 cross-network edges
        for rank, (_, i, j) in enumerate(top3_edges):
            label = f"#{rank + 1} {_NET_NAMES[i]} ↔ {_NET_NAMES[j]}"
            ax.annotate(label,
                        xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
                        fontsize=6, color="#fb923c", fontweight="600",
                        arrowprops=dict(arrowstyle="-", color="#fb923c",
                                        lw=0.7, connectionstyle="arc3,rad=0.1"),
                        ha="left", va="center", zorder=5)
        cb = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
        plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
        cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=7.5)

        # ── Right: network importance bar chart ────────────────────────────
        ax2 = axes[1]
        ax2.set_facecolor("#161922")
        ax2.tick_params(colors="#555", labelsize=9)
        order = net_imp.argsort()[::-1]
        bars = ax2.barh(range(n_nets), net_imp[order],
                        color=[_NET_COLORS[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
        ax2.set_yticks(range(n_nets))
        ax2.set_yticklabels([_NET_NAMES[i] for i in order], fontsize=9.5, color="#ddd")
        ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
        ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
        ax2.invert_yaxis()
        for sp in ["top", "right"]:
            ax2.spines[sp].set_visible(False)
        for sp in ["bottom", "left"]:
            ax2.spines[sp].set_color("#222")
        # Value labels on bars
        x_max = net_imp.max()
        for bar, val in zip(bars, net_imp[order]):
            ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
                     f"{val:.4f}", va="center", color="#555", fontsize=7.5)

        fig.suptitle(
            f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {n_models}-model LOSO ensemble · CC200 → Yeo-7 networks",
            color="#444", fontsize=8.5, y=1.02,
        )
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=140, bbox_inches="tight", facecolor="#0e1015")
    finally:
        # Always release the figure. pyplot keeps open figures alive, and the
        # caller swallows exceptions from this function, so a failure here
        # would otherwise leak one figure per request.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
# ── inference ──────────────────────────────────────────────────────────────
def _load_npz(path):
    """Build model inputs from a pre-processed ``.npz`` upload.

    The archive must contain ``mean_fc`` (a 200×200 correlation matrix) and
    ``bold_windows`` (a non-empty W×200 window-feature array). The FC is
    Fisher-z transformed and thresholded as in ``preprocess``. The windows are
    truncated, or padded by repeating the last row, to ``_MAX_WINDOWS``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(bw_t, adj_t)``, each with a
        leading batch axis of size 1.

    Raises:
        KeyError: if either array is missing.
        ValueError: if an array has an unexpected shape or would need pickle.
    """
    # Uploads are untrusted. allow_pickle=False refuses object arrays, whose
    # unpickling can execute arbitrary code; numeric arrays never need it.
    # The context manager closes the underlying zip file.
    with np.load(path, allow_pickle=False) as d:
        fc = d["mean_fc"].astype(np.float32)
        bw = d["bold_windows"].astype(np.float32)
    if fc.shape != (200, 200):
        raise ValueError(f"mean_fc must be 200×200, got {fc.shape}")
    if bw.ndim != 2 or bw.shape[1] != 200 or len(bw) == 0:
        raise ValueError(f"bold_windows must be W×200 with W ≥ 1, got {bw.shape}")
    fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
    adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
    if len(bw) >= _MAX_WINDOWS:
        bw = bw[:_MAX_WINDOWS]
    else:
        bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)])
    return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)


def run_gcn(file_path):
    """Classify one uploaded subject and build the result panels.

    Accepts either a ``.npz`` with precomputed ``mean_fc`` / ``bold_windows``,
    or a whitespace-delimited text file of BOLD time series (T rows × 200
    columns). Every loaded ensemble member is run, and the UI outputs are
    rendered.

    Returns:
        tuple: ``(verdict_html, ensemble_html, report_html, saliency_image)``.
        On a load/validation failure, or when no checkpoints are available,
        the first element is an error message, the other text slots are empty
        and the image is ``None``. The image is also ``None`` if rendering
        the saliency figure fails.
    """
    if file_path is None:
        return "", "", "", None
    path = Path(file_path)
    try:
        # Case-insensitive, so "scan.NPZ" is not mis-parsed as a text file.
        if path.suffix.lower() == ".npz":
            bw_t, adj_t = _load_npz(path)
        else:
            bold = np.loadtxt(path, dtype=np.float32)
            if bold.ndim != 2 or bold.shape[1] != 200:
                return f"Error: expected (T×200), got {bold.shape}", "", "", None
            bw_t, adj_t = preprocess(bold)
    except Exception as e:  # UI boundary: surface any load failure to the user
        return f"Error loading file: {e}", "", "", None

    models = get_models()
    n_models = len(models)
    if n_models == 0:
        # Otherwise np.mean([]) yields NaN and the UI reports "Inconclusive".
        return "Error: no model checkpoints found — cannot run inference.", "", "", None

    per_model = []
    with torch.no_grad():
        for site, task in models:
            p = torch.softmax(task(bw_t, adj_t), -1)[0, 1].item()
            per_model.append((site, p))
    p_mean = float(np.mean([p for _, p in per_model]))
    # Number of ensemble members voting ASD (p > 0.5).
    consensus = sum(1 for _, p in per_model if p > 0.5)
    conf = max(p_mean, 1 - p_mean) * 100

    try:
        sal_img = _saliency_figure(_compute_saliency(bw_t, adj_t, models), p_mean)
    except Exception:
        # Best-effort: the classification is still shown without the figure,
        # but the failure is now logged instead of silently discarded.
        import logging
        logging.getLogger(__name__).exception("saliency rendering failed")
        sal_img = None

    # ── Per-outcome text (thresholds: > 0.6 ASD, < 0.4 TC, else inconclusive) ──
    if p_mean > 0.6:
        col, label = "#ef4444", "ASD Indicated"
        detail = f"{consensus}/{n_models} site-blind models agree"
        findings = ["Reduced DMN coherence (mPFC ↔ PCC)",
                    "Atypical salience network lateralization",
                    "Decreased long-range frontotemporal connectivity"]
        imp = f"ASD-consistent connectivity profile ({conf:.1f}% confidence)."
        cons = f"{consensus}/{n_models} site-blind models agree — not attributable to scanner artifacts."
        icd = "F84.0 (Childhood Autism) / F84.1 (Atypical Autism)"
        refs = [
            ("Rudie et al. 2012", "Reduced functional integration and segregation of distributed neural systems underlying social and emotional information processing in autism spectrum disorders"),
            ("Monk et al. 2009", "Abnormalities of intrinsic functional connectivity in autism spectrum disorders"),
            ("Washington et al. 2014", "Dysmaturation of the default mode network in autism"),
        ]
    elif p_mean < 0.4:
        col, label = "#22c55e", "Typical Control"
        detail = f"{n_models - consensus}/{n_models} site-blind models agree"
        findings = ["DMN coherence within normal range",
                    "Intact salience network organization",
                    "Long-range cortico-cortical connectivity intact"]
        imp = f"Connectivity within typical range ({conf:.1f}% confidence)."
        cons = f"{n_models - consensus}/{n_models} site-blind models confirm typical profile."
        icd = "Z03.89 (No diagnosis — screening negative)"
        refs = [
            ("Buckner et al. 2008", "The brain's default network — anatomy, function, and relevance to disease"),
            ("Fox et al. 2005", "The human brain is intrinsically organized into dynamic anticorrelated functional networks"),
        ]
    else:
        col, label = "#f59e0b", "Inconclusive"
        detail = "Clinical review required"
        findings = ["Mixed connectivity near ASD–TC boundary",
                    "Significant model disagreement across sites",
                    "Borderline p(ASD) requires clinical judgment"]
        imp = "Indeterminate. Full evaluation recommended."
        # The old "Only {consensus}/4 models agree" could read "Only 4/4 models
        # agree" when every model sat just above 0.5; state the actual vote.
        cons = f"{consensus}/{n_models} models lean ASD — specialist input required."
        icd = "Z03.89 (Inconclusive — further evaluation required)"
        refs = [
            ("Ecker et al. 2010", "Describing the brain in autism in five dimensions — magnetic resonance imaging-assisted diagnosis"),
            ("Tyszka et al. 2014", "Largely typical patterns of resting-state functional connectivity in high-functioning adults with autism"),
        ]

    # ── Verdict ──
    verdict = f"""<div style="background:#161922;border:1px solid #252a35;border-left:3px solid {col};padding:22px 26px;border-radius:8px;margin-top:14px">
<div style="font-size:0.65rem;color:#8b95a7;letter-spacing:2px;text-transform:uppercase;margin-bottom:6px;font-weight:500">Classification Result</div>
<div style="font-size:1.8rem;font-weight:600;color:{col};letter-spacing:-0.5px;line-height:1.1">{label}</div>
<div style="display:flex;gap:36px;margin-top:18px;flex-wrap:wrap">
<div><div style="font-size:1.3rem;font-weight:600;color:#f4f4f5;font-variant-numeric:tabular-nums">{conf:.1f}%</div><div style="color:#5e6675;font-size:0.7rem;margin-top:2px">Confidence</div></div>
<div><div style="font-size:1.3rem;font-weight:600;color:#f4f4f5;font-variant-numeric:tabular-nums">{p_mean:.3f}</div><div style="color:#5e6675;font-size:0.7rem;margin-top:2px">p(ASD)</div></div>
<div><div style="font-size:0.92rem;color:#cbd5e1;padding-top:8px">{detail}</div><div style="color:#5e6675;font-size:0.7rem;margin-top:2px">Ensemble vote</div></div>
</div></div>"""

    # ── Ensemble ──
    row_html = []
    for site, p in per_model:
        lbl = "ASD" if p > 0.5 else "TC"
        clr = "#ef4444" if p > 0.5 else "#22c55e"
        row_html.append(f"""<tr>
<td style="padding:9px 0;color:#cbd5e1;font-weight:500;font-size:0.86rem;width:110px">{site}-blind</td>
<td style="padding:9px 14px;width:220px"><div style="background:#252a35;border-radius:2px;height:5px;width:200px;overflow:hidden">
<div style="background:{clr};height:5px;width:{int(p*100)}%"></div></div></td>
<td style="padding:9px 14px;color:{clr};font-weight:600;font-size:0.85rem;width:50px">{lbl}</td>
<td style="padding:9px 0;color:#8b95a7;font-size:0.84rem;font-variant-numeric:tabular-nums">p = {p:.3f}</td></tr>""")
    rows = "".join(row_html)
    ensemble = f"""<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 24px;margin-top:10px">
<div style="font-size:0.65rem;color:#8b95a7;letter-spacing:2px;text-transform:uppercase;margin-bottom:12px;font-weight:500">Leave-One-Site-Out Ensemble</div>
<table style="width:100%;border-collapse:collapse">{rows}</table>
<div style="margin-top:12px;padding-top:10px;border-top:1px solid #252a35;color:#5e6675;font-size:0.76rem">
LOSO AUC = 0.7872 · 529 held-out subjects · 4 institutions
</div></div>"""

    # ── Report ──
    # NOTE(review): the report below is assembled from fixed templates chosen
    # by p_mean alone. The findings are not derived from the saliency map, and
    # no language model is called here, although the header attributes the
    # text to Qwen2.5-7B — confirm this labelling is intended.
    fi = "".join(f"<li style='margin:5px 0;color:#cbd5e1;line-height:1.55'>{f}</li>" for f in findings)
    refs_html = "".join(
        f"<div style='margin:4px 0;font-size:0.76rem'><span style='color:#fb923c;font-weight:600'>{r[0]}</span> "
        f"<span style='color:#5e6675'>— {r[1]}</span></div>"
        for r in refs
    )
    report = f"""<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 24px;margin-top:10px">
<div style="font-size:0.65rem;color:#8b95a7;letter-spacing:2px;text-transform:uppercase;margin-bottom:16px;font-weight:500">Clinical Referral Summary · Generated by Qwen2.5-7B LoRA · AMD Instinct MI300X</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:16px;margin-bottom:16px">
<div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">ICD-10 Classification</div>
<div style="color:#cbd5e1;font-size:0.84rem;line-height:1.4">{icd}</div></div>
<div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">Ensemble Confidence</div>
<div style="color:#cbd5e1;font-size:0.84rem">{conf:.1f}% · p(ASD) = {p_mean:.3f} · {n_models}-model LOSO</div></div>
</div>
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Impression</div>
<div style="color:#f4f4f5;font-size:0.92rem;margin-bottom:14px;line-height:1.55">{imp}</div>
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Connectivity Findings</div>
<ul style="margin:0 0 14px 0;padding-left:18px;font-size:0.88rem">{fi}</ul>
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Cross-Site Consistency</div>
<div style="color:#cbd5e1;font-size:0.86rem;margin-bottom:14px;line-height:1.55">{cons}</div>
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:6px;font-weight:500">Supporting Literature</div>
<div style="margin-bottom:14px">{refs_html}</div>
<div style="border-top:1px solid #252a35;padding-top:10px;color:#5e6675;font-size:0.74rem;line-height:1.5">
AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
    return verdict, ensemble, report, sal_img
# ── Static HTML sections ───────────────────────────────────────────────────
# Static page header HTML: title, short pitch paragraph and headline metric
# badges. All numbers in it are literal text, not computed in this module.
HEADER = """
<div style="padding:32px 0 24px;border-bottom:1px solid #252a35;margin-bottom:18px">
<div style="display:flex;align-items:baseline;gap:14px;flex-wrap:wrap">
<div style="font-size:2.1rem;font-weight:700;color:#f4f4f5;letter-spacing:-0.8px;line-height:1">
BrainConnect<span style="color:#ef4444">-ASD</span>
</div>
<div style="color:#5e6675;font-size:0.7rem;letter-spacing:1.8px;text-transform:uppercase">
Resting-state fMRI · Site-Invariant Classification
</div>
</div>
<div style="color:#cbd5e1;font-size:0.92rem;margin-top:14px;max-width:780px;line-height:1.65">
1 in 44 children is diagnosed with ASD — diagnosis takes years and no biomarker exists.
We trained a scanner-site-invariant GCN on 1,102 subjects across 17 institutions and validated on
<span style="color:#ef4444;font-weight:600">529 subjects from sites the model never saw</span>.
Result: <span style="color:#ef4444;font-weight:600">AUC 0.7872</span> — not on held-out splits of the same scanner, but across entirely different hospitals.
Fine-tuned <span style="color:#fb923c;font-weight:600">Qwen2.5-7B on AMD MI300X</span> translates raw connectivity into structured clinical language.
</div>
<div style="display:flex;gap:28px;margin-top:20px;flex-wrap:wrap">
<div style="display:flex;align-items:baseline;gap:8px">
<span style="font-size:1.4rem;font-weight:700;color:#ef4444;font-variant-numeric:tabular-nums">0.7872</span>
<span style="color:#5e6675;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">LOSO AUC</span>
</div>
<div style="display:flex;align-items:baseline;gap:8px">
<span style="font-size:1.4rem;font-weight:700;color:#f4f4f5;font-variant-numeric:tabular-nums">529</span>
<span style="color:#5e6675;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Held-out subjects</span>
</div>
<div style="display:flex;align-items:baseline;gap:8px">
<span style="font-size:1.4rem;font-weight:700;color:#f4f4f5;font-variant-numeric:tabular-nums">17</span>
<span style="color:#5e6675;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Scanner sites</span>
</div>
<div style="display:flex;align-items:baseline;gap:8px">
<span style="font-size:1.4rem;font-weight:700;color:#fb923c">MI300X</span>
<span style="color:#5e6675;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">AMD hardware</span>
</div>
</div>
</div>
"""
def _val_row(site, sid, truth, pred, p, result_color, result_text):
    """Render one ``<tr>`` of the validation table.

    ``truth`` is coloured red for "ASD" and green otherwise. ``pred`` is red
    for "ASD", green for "TC" and amber for anything else (e.g. "INCONCL.").
    ``p`` and ``result_text`` are inserted verbatim; ``result_color`` styles
    the last cell.
    """
    label_colors = {"ASD": "#ef4444", "TC": "#22c55e"}
    truth_clr = label_colors.get(truth, "#22c55e")
    pred_clr = label_colors.get(pred, "#f59e0b")
    cells = [
        f'<td style="padding:9px 14px;color:#cbd5e1">{site}</td>',
        f'<td style="padding:9px 14px;color:#5e6675;font-size:0.8rem;font-variant-numeric:tabular-nums">{sid}</td>',
        f'<td style="padding:9px 14px;text-align:center;color:{truth_clr};font-weight:600">{truth}</td>',
        f'<td style="padding:9px 14px;text-align:center;color:{pred_clr};font-weight:600">{pred}</td>',
        f'<td style="padding:9px 14px;text-align:center;color:#8b95a7;font-variant-numeric:tabular-nums">{p}</td>',
        f'<td style="padding:9px 14px;text-align:center;color:{result_color};font-size:0.85rem">{result_text}</td>',
    ]
    return '<tr style="border-top:1px solid #252a35">\n' + "\n".join(cells) + "</tr>"
# Pre-rendered <tbody> rows for the VALIDATION table. Each tuple is
# (site, subject id, true label, predicted label, p(ASD) text,
#  result colour, result mark), passed positionally to _val_row.
_VAL_ROWS = "".join(
    _val_row(*sample)
    for sample in (
        ("Caltech", "0051456", "ASD", "ASD", "0.742", "#22c55e", "✓"),
        ("Caltech", "0051457", "TC", "TC", "0.183", "#22c55e", "✓"),
        ("CMU", "0050642", "ASD", "INCONCL.", "0.521", "#f59e0b", "review"),
        ("CMU", "0050646", "TC", "TC", "0.312", "#22c55e", "✓"),
        ("Stanford", "0051160", "ASD", "ASD", "0.831", "#22c55e", "✓"),
        ("Stanford", "0051161", "TC", "TC", "0.127", "#22c55e", "✓"),
        ("Trinity", "0050232", "ASD", "INCONCL.", "0.487", "#f59e0b", "review"),
        ("Trinity", "0050233", "TC", "TC", "0.241", "#22c55e", "✓"),
        ("Yale", "0050551", "ASD", "ASD", "0.689", "#22c55e", "✓"),
        ("Yale", "0050552", "TC", "TC", "0.156", "#22c55e", "✓"),
    )
)
# "Validation" tab HTML. The f-string embeds the two chart PNGs from _charts
# (VAL_B64, AUC_B64) as base64 data URIs and the pre-rendered _VAL_ROWS table
# body. Every count, metric and baseline AUC in it is literal text.
# NOTE(review): the text describes the inconclusive band as 0.4 < p < 0.6,
# while run_gcn treats 0.4 <= p <= 0.6 as inconclusive (it uses > 0.6 and
# < 0.4 for definite labels) — align the wording if exact bounds matter.
VALIDATION = f"""
<div>
<div style="display:flex;gap:36px;margin-bottom:22px;flex-wrap:wrap">
<div>
<div style="font-size:1.9rem;font-weight:700;color:#22c55e;line-height:1;font-variant-numeric:tabular-nums">8<span style="font-size:0.95rem;color:#5e6675;font-weight:500"> / 10</span></div>
<div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Definitive correct</div>
</div>
<div>
<div style="font-size:1.9rem;font-weight:700;color:#f59e0b;line-height:1;font-variant-numeric:tabular-nums">2<span style="font-size:0.95rem;color:#5e6675;font-weight:500"> / 10</span></div>
<div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Flagged inconclusive</div>
</div>
<div>
<div style="font-size:1.9rem;font-weight:700;color:#ef4444;line-height:1;font-variant-numeric:tabular-nums">0<span style="font-size:0.95rem;color:#5e6675;font-weight:500"> / 10</span></div>
<div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Confident wrong</div>
</div>
<div>
<div style="font-size:1.9rem;font-weight:700;color:#f4f4f5;line-height:1;font-variant-numeric:tabular-nums">5</div>
<div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Unseen sites</div>
</div>
</div>
<img src="data:image/png;base64,{VAL_B64}" style="width:100%;border-radius:6px;margin-bottom:10px;border:1px solid #252a35"/>
<img src="data:image/png;base64,{AUC_B64}" style="width:100%;border-radius:6px;margin-bottom:18px;border:1px solid #252a35"/>
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden">
<table style="width:100%;border-collapse:collapse;font-size:0.86rem">
<thead><tr>
<th style="padding:11px 14px;color:#8b95a7;font-weight:500;text-align:left;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Site</th>
<th style="padding:11px 14px;color:#8b95a7;font-weight:500;text-align:left;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Subject</th>
<th style="padding:11px 14px;color:#8b95a7;font-weight:500;text-align:center;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Truth</th>
<th style="padding:11px 14px;color:#8b95a7;font-weight:500;text-align:center;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Predicted</th>
<th style="padding:11px 14px;color:#8b95a7;font-weight:500;text-align:center;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">p(ASD)</th>
<th style="padding:11px 14px;color:#8b95a7;font-weight:500;text-align:center;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px">Result</th>
</tr></thead>
<tbody>{_VAL_ROWS}</tbody>
</table>
</div>
<div style="margin-top:12px;color:#8b95a7;font-size:0.8rem;line-height:1.6">
Inconclusive predictions (0.4 &lt; p &lt; 0.6) surface borderline cases for clinical review rather than forcing a wrong label.
<span style="color:#cbd5e1">Zero confident misclassifications across 5 unseen sites.</span>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:14px;margin-top:22px">
<!-- Confusion matrix (on definitive predictions only) -->
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:14px;font-weight:500">Confusion Matrix · Definitive Predictions</div>
<div style="display:grid;grid-template-columns:auto 1fr 1fr;gap:2px;font-size:0.82rem;text-align:center">
<div></div>
<div style="color:#8b95a7;font-size:0.7rem;padding:6px;text-transform:uppercase;letter-spacing:0.8px">Pred ASD</div>
<div style="color:#8b95a7;font-size:0.7rem;padding:6px;text-transform:uppercase;letter-spacing:0.8px">Pred TC</div>
<div style="color:#8b95a7;font-size:0.7rem;padding:6px 8px;text-transform:uppercase;letter-spacing:0.8px;text-align:left">True ASD</div>
<div style="background:#1a2e1a;border:1px solid #2a4a2a;border-radius:5px;padding:14px 8px;color:#22c55e;font-weight:700;font-size:1.1rem">3<div style="font-size:0.68rem;color:#5e6675;font-weight:400;margin-top:2px">TP</div></div>
<div style="background:#2a2015;border:1px solid #3a2a10;border-radius:5px;padding:14px 8px;color:#5e6675;font-size:1.1rem">0<div style="font-size:0.68rem;color:#5e6675;font-weight:400;margin-top:2px">FN</div></div>
<div style="color:#8b95a7;font-size:0.7rem;padding:6px 8px;text-transform:uppercase;letter-spacing:0.8px;text-align:left">True TC</div>
<div style="background:#2a2015;border:1px solid #3a2a10;border-radius:5px;padding:14px 8px;color:#5e6675;font-size:1.1rem">0<div style="font-size:0.68rem;color:#5e6675;font-weight:400;margin-top:2px">FP</div></div>
<div style="background:#1a2e1a;border:1px solid #2a4a2a;border-radius:5px;padding:14px 8px;color:#22c55e;font-weight:700;font-size:1.1rem">5<div style="font-size:0.68rem;color:#5e6675;font-weight:400;margin-top:2px">TN</div></div>
</div>
<div style="margin-top:12px;display:flex;gap:16px;font-size:0.78rem">
<div><span style="color:#cbd5e1;font-weight:600">100%</span> <span style="color:#5e6675">Sensitivity</span></div>
<div><span style="color:#cbd5e1;font-weight:600">100%</span> <span style="color:#5e6675">Specificity</span></div>
<div><span style="color:#f59e0b;font-weight:600">2</span> <span style="color:#5e6675">correctly deferred</span></div>
</div>
</div>
<!-- ABIDE baselines comparison -->
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:14px;font-weight:500">vs. Published ABIDE Baselines</div>
<table style="width:100%;border-collapse:collapse;font-size:0.82rem">
<tr><td style="padding:7px 0;color:#8b95a7;border-bottom:1px solid #1e2330">SVM + FC (Plitt 2015)</td><td style="padding:7px 0;text-align:right;color:#cbd5e1;border-bottom:1px solid #1e2330;font-variant-numeric:tabular-nums">0.71</td></tr>
<tr><td style="padding:7px 0;color:#8b95a7;border-bottom:1px solid #1e2330">BrainNetCNN (Kawahara 2017)</td><td style="padding:7px 0;text-align:right;color:#cbd5e1;border-bottom:1px solid #1e2330;font-variant-numeric:tabular-nums">0.74</td></tr>
<tr><td style="padding:7px 0;color:#8b95a7;border-bottom:1px solid #1e2330">GCN + FC (Ktena 2018)</td><td style="padding:7px 0;text-align:right;color:#cbd5e1;border-bottom:1px solid #1e2330;font-variant-numeric:tabular-nums">0.70</td></tr>
<tr><td style="padding:7px 0;color:#8b95a7;border-bottom:1px solid #1e2330">ABIDE site-specific SVM</td><td style="padding:7px 0;text-align:right;color:#cbd5e1;border-bottom:1px solid #1e2330;font-variant-numeric:tabular-nums">0.76</td></tr>
<tr><td style="padding:7px 0;color:#f4f4f5;font-weight:600">BrainConnect-ASD (LOSO)</td><td style="padding:7px 0;text-align:right;color:#ef4444;font-weight:700;font-variant-numeric:tabular-nums">0.7872</td></tr>
</table>
<div style="margin-top:10px;color:#5e6675;font-size:0.74rem;line-height:1.5">
All prior results use <i>same-site</i> train/test splits. Ours is cross-site — a fundamentally harder evaluation.
</div>
</div>
</div>
</div>
"""
# "Architecture" tab HTML (plain string, no interpolation): three summary
# cards followed by a details table. All figures in it are literal text.
ARCHITECTURE = """
<div>
<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(240px,1fr));gap:12px;margin-bottom:18px">
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#fb923c;font-weight:600;font-size:0.78rem;margin-bottom:8px;text-transform:uppercase;letter-spacing:1px">Brain Mode Decomposition</div>
<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.6">
K=16 learnable directions in ROI space. <code style="color:#fb923c;background:#0e1015;padding:1px 5px;border-radius:3px;font-size:0.8rem">M_kl = v_k · FC · v_l</code>
compresses 19,900 FC features → 152 dims while preserving network structure.
</div>
</div>
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#fb923c;font-weight:600;font-size:0.78rem;margin-bottom:8px;text-transform:uppercase;letter-spacing:1px">Gradient Reversal Layer</div>
<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.6">
Adversarial site deconfounding (Ganin 2016). Encoder minimizes ASD loss while <i>maximizing</i> site confusion — forcing site-invariant representations. α annealed 0→1.
</div>
</div>
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#fb923c;font-weight:600;font-size:0.78rem;margin-bottom:8px;text-transform:uppercase;letter-spacing:1px">LOSO Ensemble</div>
<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.6">
4 models × 1 held-out site each. No model ever saw the test subject's scanner. Cross-model agreement = site-independent finding.
</div>
</div>
</div>
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden">
<table style="width:100%;border-collapse:collapse;font-size:0.86rem">
<tr><td style="padding:10px 16px;color:#8b95a7;width:160px;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Dataset</td><td style="padding:10px 16px;color:#cbd5e1">ABIDE I — 1,102 subjects · 17 acquisition sites</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Parcellation</td><td style="padding:10px 16px;color:#cbd5e1">CC200 (Craddock 2012) — 200 functional ROIs</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Architecture</td><td style="padding:10px 16px;color:#cbd5e1">AdversarialBrainModeNetwork · K=16 · hidden_dim=64</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Regularization</td><td style="padding:10px 16px;color:#cbd5e1">GRL adversarial + orthogonality loss on brain modes</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Validation</td><td style="padding:10px 16px;color:#cbd5e1">LOSO AUC = <span style="color:#ef4444;font-weight:600">0.7872</span> across 529 held-out subjects</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Interpretability</td><td style="padding:10px 16px;color:#cbd5e1">Real-time gradient saliency on FC adjacency matrix</td></tr>
</table>
</div>
</div>
"""
AMD = f"""
<div>
<div style="display:flex;gap:32px;margin-bottom:22px;flex-wrap:wrap">
<div><div style="font-size:1.7rem;font-weight:700;color:#fb923c;line-height:1;font-variant-numeric:tabular-nums">192<span style="font-size:0.8rem;color:#5e6675;font-weight:500"> GB</span></div><div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">HBM3 unified</div></div>
<div><div style="font-size:1.7rem;font-weight:700;color:#fb923c;line-height:1">bf16</div><div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Full precision</div></div>
<div><div style="font-size:1.7rem;font-weight:700;color:#f4f4f5;line-height:1">7B</div><div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Qwen2.5 params</div></div>
<div><div style="font-size:1.7rem;font-weight:700;color:#f4f4f5;line-height:1;font-variant-numeric:tabular-nums">2,000</div><div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">Domain examples</div></div>
<div><div style="font-size:1.7rem;font-weight:700;color:#f4f4f5;line-height:1">r=16</div><div style="color:#8b95a7;font-size:0.7rem;margin-top:5px;text-transform:uppercase;letter-spacing:1px">LoRA rank</div></div>
</div>
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden;margin-bottom:14px">
<table style="width:100%;border-collapse:collapse;font-size:0.86rem">
<tr><td style="padding:10px 16px;color:#8b95a7;width:160px;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Base model</td><td style="padding:10px 16px;color:#cbd5e1">Qwen/Qwen2.5-7B-Instruct <span style="color:#5e6675">· AMD partner model</span></td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Method</td><td style="padding:10px 16px;color:#cbd5e1">LoRA r=16, α=32 · all projection layers (q, k, v, o, gate, up, down)</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Hardware</td><td style="padding:10px 16px;color:#cbd5e1">AMD Instinct MI300X · ROCm · bf16 — no quantization</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Training data</td><td style="padding:10px 16px;color:#cbd5e1">2,000 GCN→clinical report pairs · ASD-grounded · 3 epochs</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Task</td><td style="padding:10px 16px;color:#cbd5e1">Structured clinical interpretation of LOSO GCN ensemble outputs</td></tr>
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.78rem;text-transform:uppercase;letter-spacing:0.5px">Output</td><td style="padding:10px 16px;color:#cbd5e1">DMN / salience / cerebellar-cortical findings grounded in ASD literature</td></tr>
</table>
</div>
<img src="data:image/png;base64,{AMD_BENCH_B64}" style="width:100%;border-radius:6px;margin-bottom:14px;border:1px solid #252a35"/>
<div style="display:grid;grid-template-columns:repeat(auto-fit,minmax(280px,1fr));gap:12px">
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#fb923c;font-weight:600;font-size:0.78rem;margin-bottom:8px;text-transform:uppercase;letter-spacing:1px">Why Qwen2.5-7B?</div>
<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.6">AMD partner model. Fine-tuning on MI300X with an AMD-aligned model demonstrates the complete AMD AI stack. 192 GB HBM3 enables full bf16 fine-tuning impossible on consumer hardware.</div>
</div>
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
<div style="color:#fb923c;font-weight:600;font-size:0.78rem;margin-bottom:8px;text-transform:uppercase;letter-spacing:1px">Why domain fine-tuning?</div>
<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.6">Base Qwen generates generic text. Fine-tuned Qwen understands what "3/4 site-blind models agree" means clinically and grounds reports in ASD neuroscience (DMN, salience, cerebellar-cortical coupling).</div>
</div>
</div>
</div>
"""
# ── UI ─────────────────────────────────────────────────────────────────────
# Custom stylesheet passed to gr.Blocks(css=css) below. It layers a dark
# palette on top of gr.themes.Base(). Most rules carry !important so they
# take precedence over the theme's own styles.
# NOTE(review): `label.svelte-1b6s6s` targets a build-generated Svelte class
# hash. It will likely stop matching after a Gradio upgrade; confirm it
# against the pinned Gradio version.
# `footer { display: none }` hides the page's <footer> element (presumably
# Gradio's built-in footer). The app renders its own footer with gr.HTML.
css = """
body, .gradio-container, .gr-app { background: #0e1015 !important; }
.gradio-container { max-width: 1180px !important; margin: auto; padding: 0 28px; }
.gradio-container * { font-family: -apple-system, BlinkMacSystemFont, "Inter", "Segoe UI", sans-serif; }
.tab-nav { border-bottom: 1px solid #252a35 !important; margin-bottom: 14px !important; gap: 2px !important; }
.tab-nav button { color: #8b95a7 !important; font-size: 0.84rem !important; font-weight: 500 !important; padding: 10px 16px !important; background: transparent !important; border: none !important; }
.tab-nav button:hover { color: #cbd5e1 !important; }
.tab-nav button.selected { color: #f4f4f5 !important; border-bottom: 2px solid #ef4444 !important; background: transparent !important; }
.gr-block, .gr-form, .gr-box { background: transparent !important; border: none !important; }
.gr-file, .gr-file-preview { background: #161922 !important; border: 1px dashed #2a3040 !important; border-radius: 8px !important; }
label.svelte-1b6s6s, .gr-input-label { color: #8b95a7 !important; font-size: 0.78rem !important; font-weight: 500 !important; text-transform: uppercase; letter-spacing: 0.8px; }
button.primary, .gr-button-primary { background: #ef4444 !important; border: none !important; color: white !important; font-weight: 500 !important; }
button.secondary, .gr-button-secondary { background: #161922 !important; border: 1px solid #252a35 !important; color: #cbd5e1 !important; }
footer { display: none !important; }
.gr-image, .gr-image-container { background: #0e1015 !important; border: 1px solid #252a35 !important; border-radius: 8px !important; }
"""
def _demo_runner(sample_path):
    """Build a zero-argument click handler that runs the GCN pipeline on a bundled sample.

    Returns a ``lambda`` (not a nested ``def``) so Gradio sees the same
    parameterless ``<lambda>`` callable it would for an inline lambda.
    """
    return lambda: run_gcn(sample_path)


with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
    gr.HTML(HEADER)
    with gr.Tabs():
        with gr.Tab("Analysis"):
            # Input area: file upload plus three one-click held-out-site examples.
            file_input = gr.File(label="Upload CC200 fMRI (.1D or .npz)", type="filepath")
            gr.HTML("<div style='color:#8b95a7;font-size:0.7rem;text-transform:uppercase;letter-spacing:1.2px;margin:14px 0 8px;font-weight:500'>Or try a real ABIDE subject from a held-out site</div>")
            with gr.Row():
                btn_asd = gr.Button("ASD · Stanford 0051160", size="sm")
                btn_tc = gr.Button("TC · Yale 0050552", size="sm")
                btn_brd = gr.Button("Borderline · Trinity 0050232", size="sm")

            # Result panels, created in display order.
            verdict_html = gr.HTML()
            ens_html = gr.HTML()
            gr.HTML("<div style='margin-top:18px;font-size:0.65rem;color:#8b95a7;letter-spacing:2px;text-transform:uppercase;margin-bottom:6px;font-weight:500'>Gradient Saliency · which brain networks drove this prediction</div>")
            sal_img = gr.Image(label="", type="pil", show_label=False)
            rep_html = gr.HTML()

            # Every trigger fills the same four panels. Gradio maps run_gcn's
            # return values onto them positionally, so this order must match
            # what run_gcn returns.
            analysis_outputs = [verdict_html, ens_html, rep_html, sal_img]
            file_input.change(fn=run_gcn, inputs=file_input, outputs=analysis_outputs)

            # Sample paths are relative to the working directory, like _CKPTS.
            demo_buttons = (
                (btn_asd, "demo_subjects/sample_asd_stanford.1D"),
                (btn_tc, "demo_subjects/sample_tc_yale.1D"),
                (btn_brd, "demo_subjects/sample_borderline_trinity.1D"),
            )
            for button, sample_path in demo_buttons:
                button.click(fn=_demo_runner(sample_path), outputs=analysis_outputs)

        with gr.Tab("Validation"):
            gr.HTML(VALIDATION)
        with gr.Tab("Architecture"):
            gr.HTML(ARCHITECTURE)
        with gr.Tab("AMD MI300X"):
            gr.HTML(AMD)

    # Page footer shown under all tabs (Gradio's own footer is hidden via css).
    gr.HTML("""
<div style="text-align:center;padding:24px 0 12px;color:#5e6675;font-size:0.74rem;border-top:1px solid #252a35;margin-top:18px">
Adversarial Brain-Mode GCN (K=16) · ABIDE I 1,102 subjects · Qwen2.5-7B LoRA on AMD Instinct MI300X ·
<a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#8b95a7;text-decoration:none">GitHub</a>
</div>""")
def _warm_start():
    """Fill get_models()'s module-level cache before serving.

    This runs at import time, so the first prediction request does not pay
    the checkpoint-loading cost. Progress is reported on stdout.
    """
    print("Preloading models...")
    get_models()
    print("Ready.")


_warm_start()

if __name__ == "__main__":
    demo.launch()