Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -17,6 +17,11 @@ _STEP = 3
|
|
| 17 |
_MAX_WINDOWS = 30
|
| 18 |
_FC_THRESHOLD = 0.2
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
_CKPTS = {
|
| 21 |
"NYU": Path("checkpoints/nyu.ckpt"),
|
| 22 |
"USM": Path("checkpoints/usm.ckpt"),
|
|
@@ -89,44 +94,86 @@ def _saliency_figure(sal, p_mean):
|
|
| 89 |
import matplotlib.pyplot as plt
|
| 90 |
from PIL import Image
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
| 99 |
fig.patch.set_facecolor("#0d0d0d")
|
| 100 |
|
|
|
|
| 101 |
ax = axes[0]
|
| 102 |
-
ax.set_facecolor("#111")
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
ax.
|
| 107 |
-
ax.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 109 |
cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
|
| 110 |
plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
|
|
|
|
| 111 |
|
|
|
|
| 112 |
ax2 = axes[1]
|
| 113 |
-
ax2.set_facecolor("#111")
|
| 114 |
-
ax2.
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
ax2.
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
ax2.invert_yaxis()
|
| 120 |
-
for sp in ["top", "right"]:
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
fig.suptitle(
|
| 124 |
-
f"Gradient Saliency · p(ASD)={p_mean:.3f} · {len(_models)}-model LOSO ensemble",
|
| 125 |
-
color="#
|
| 126 |
)
|
| 127 |
plt.tight_layout()
|
| 128 |
buf = io.BytesIO()
|
| 129 |
-
plt.savefig(buf, format="png", dpi=
|
| 130 |
plt.close(fig)
|
| 131 |
buf.seek(0)
|
| 132 |
return Image.open(buf).copy()
|
|
@@ -260,6 +307,14 @@ HEADER = """
|
|
| 260 |
<div style="color:#333;font-size:0.72rem;letter-spacing:4px;text-transform:uppercase;margin-top:10px">
|
| 261 |
Clinical AI · Resting-state fMRI · Scanner-Site-Invariant Classification
|
| 262 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
<div style="display:flex;gap:0;margin-top:28px;border:1px solid #1a1a1a;border-radius:12px;overflow:hidden;max-width:700px">
|
| 264 |
<div style="padding:20px 32px;flex:1;border-right:1px solid #1a1a1a;min-width:120px">
|
| 265 |
<div style="font-size:2.2rem;font-weight:900;color:#e63946;line-height:1">0.7872</div>
|
|
|
|
| 17 |
_MAX_WINDOWS = 30
|
| 18 |
_FC_THRESHOLD = 0.2
|
| 19 |
|
| 20 |
+
# CC200 atlas (Craddock 2012) → approximate Yeo 7-network parcellation
|
| 21 |
+
_NET_NAMES = ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"]
|
| 22 |
+
_NET_BOUNDS = [0, 38, 69, 99, 137, 165, 180, 200]
|
| 23 |
+
_NET_COLORS = ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"]
|
| 24 |
+
|
| 25 |
_CKPTS = {
|
| 26 |
"NYU": Path("checkpoints/nyu.ckpt"),
|
| 27 |
"USM": Path("checkpoints/usm.ckpt"),
|
|
|
|
| 94 |
import matplotlib.pyplot as plt
|
| 95 |
from PIL import Image
|
| 96 |
|
| 97 |
+
n_nets = len(_NET_NAMES)
|
| 98 |
+
|
| 99 |
+
# Aggregate 200×200 saliency → 7×7 network-level matrix
|
| 100 |
+
net_sal = np.zeros((n_nets, n_nets))
|
| 101 |
+
for i, (s1, e1) in enumerate(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])):
|
| 102 |
+
for j, (s2, e2) in enumerate(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])):
|
| 103 |
+
net_sal[i, j] = sal[s1:e1, s2:e2].mean()
|
| 104 |
+
|
| 105 |
+
# Network importance: mean outgoing + incoming saliency per network
|
| 106 |
+
net_imp = np.array([
|
| 107 |
+
sal[s:e, :].mean() + sal[:, s:e].mean()
|
| 108 |
+
for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
|
| 109 |
+
])
|
| 110 |
|
| 111 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
|
| 112 |
fig.patch.set_facecolor("#0d0d0d")
|
| 113 |
|
| 114 |
+
# ── Left: 7×7 network heatmap ──────────────────────────────────────────
|
| 115 |
ax = axes[0]
|
| 116 |
+
ax.set_facecolor("#111")
|
| 117 |
+
im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
|
| 118 |
+
ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=11, pad=14, fontweight="bold")
|
| 119 |
+
|
| 120 |
+
ax.set_xticks(range(n_nets))
|
| 121 |
+
ax.set_yticks(range(n_nets))
|
| 122 |
+
ax.set_xticklabels(_NET_NAMES, rotation=40, ha="right", fontsize=9, color="#ccc")
|
| 123 |
+
ax.set_yticklabels(_NET_NAMES, fontsize=9, color="#ccc")
|
| 124 |
+
ax.tick_params(colors="#555", length=0)
|
| 125 |
+
for sp in ax.spines.values():
|
| 126 |
+
sp.set_color("#222")
|
| 127 |
+
|
| 128 |
+
# Boundary lines between networks
|
| 129 |
+
for k in range(1, n_nets):
|
| 130 |
+
ax.axhline(k - 0.5, color="#2a2a2a", lw=1.0)
|
| 131 |
+
ax.axvline(k - 0.5, color="#2a2a2a", lw=1.0)
|
| 132 |
+
|
| 133 |
+
# Annotate each cell with its value
|
| 134 |
+
vmax = net_sal.max()
|
| 135 |
+
for i in range(n_nets):
|
| 136 |
+
for j in range(n_nets):
|
| 137 |
+
txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#555"
|
| 138 |
+
ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
|
| 139 |
+
fontsize=6.5, color=txt_color)
|
| 140 |
+
|
| 141 |
cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 142 |
cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
|
| 143 |
plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
|
| 144 |
+
cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=7.5)
|
| 145 |
|
| 146 |
+
# ── Right: network importance bar chart ────────────────────────────────
|
| 147 |
ax2 = axes[1]
|
| 148 |
+
ax2.set_facecolor("#111")
|
| 149 |
+
ax2.tick_params(colors="#555", labelsize=9)
|
| 150 |
+
|
| 151 |
+
order = net_imp.argsort()[::-1]
|
| 152 |
+
bars = ax2.barh(range(n_nets), net_imp[order],
|
| 153 |
+
color=[_NET_COLORS[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
|
| 154 |
+
ax2.set_yticks(range(n_nets))
|
| 155 |
+
ax2.set_yticklabels([_NET_NAMES[i] for i in order], fontsize=9.5, color="#ddd")
|
| 156 |
+
ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
|
| 157 |
+
ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
|
| 158 |
ax2.invert_yaxis()
|
| 159 |
+
for sp in ["top", "right"]:
|
| 160 |
+
ax2.spines[sp].set_visible(False)
|
| 161 |
+
for sp in ["bottom", "left"]:
|
| 162 |
+
ax2.spines[sp].set_color("#222")
|
| 163 |
+
|
| 164 |
+
# Value labels on bars
|
| 165 |
+
x_max = net_imp.max()
|
| 166 |
+
for bar, val in zip(bars, net_imp[order]):
|
| 167 |
+
ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
|
| 168 |
+
f"{val:.4f}", va="center", color="#555", fontsize=7.5)
|
| 169 |
|
| 170 |
fig.suptitle(
|
| 171 |
+
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
|
| 172 |
+
color="#444", fontsize=8.5, y=1.02,
|
| 173 |
)
|
| 174 |
plt.tight_layout()
|
| 175 |
buf = io.BytesIO()
|
| 176 |
+
plt.savefig(buf, format="png", dpi=140, bbox_inches="tight", facecolor="#0d0d0d")
|
| 177 |
plt.close(fig)
|
| 178 |
buf.seek(0)
|
| 179 |
return Image.open(buf).copy()
|
|
|
|
| 307 |
<div style="color:#333;font-size:0.72rem;letter-spacing:4px;text-transform:uppercase;margin-top:10px">
|
| 308 |
Clinical AI · Resting-state fMRI · Scanner-Site-Invariant Classification
|
| 309 |
</div>
|
| 310 |
+
<div style="color:#444;font-size:0.93rem;margin-top:18px;max-width:720px;line-height:1.75">
|
| 311 |
+
1 in 44 children is diagnosed with ASD. Today, diagnosis takes years of behavioral observation —
|
| 312 |
+
no biomarker exists. We trained a scanner-site-invariant GCN on 1,102 subjects across 17 institutions
|
| 313 |
+
and validated it on <span style="color:#e63946;font-weight:700">529 subjects the model never saw, from sites it was never trained on</span>.
|
| 314 |
+
The result: <span style="color:#e63946;font-weight:700">AUC 0.7872</span> — not on held-out splits of the same scanner, but
|
| 315 |
+
across entirely different hospitals. Fine-tuned <span style="color:#f4a261;font-weight:700">Qwen2.5-7B on AMD MI300X</span>
|
| 316 |
+
then translates raw connectivity patterns into structured clinical language a clinician can act on.
|
| 317 |
+
</div>
|
| 318 |
<div style="display:flex;gap:0;margin-top:28px;border:1px solid #1a1a1a;border-radius:12px;overflow:hidden;max-width:700px">
|
| 319 |
<div style="padding:20px 32px;flex:1;border-right:1px solid #1a1a1a;min-width:120px">
|
| 320 |
<div style="font-size:2.2rem;font-weight:900;color:#e63946;line-height:1">0.7872</div>
|