Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -88,10 +88,23 @@ def _compute_saliency(bw_t, adj_t, models):
|
|
| 88 |
sal = np.mean(maps, axis=0)
|
| 89 |
return (sal + sal.T) / 2
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
def _saliency_figure(sal, p_mean):
|
| 92 |
import matplotlib
|
| 93 |
matplotlib.use("Agg")
|
| 94 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
| 95 |
from PIL import Image
|
| 96 |
|
| 97 |
n_nets = len(_NET_NAMES)
|
|
@@ -108,8 +121,13 @@ def _saliency_figure(sal, p_mean):
|
|
| 108 |
for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
|
| 109 |
])
|
| 110 |
|
| 111 |
-
fig
|
| 112 |
-
fig.patch.set_facecolor("#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
# ── Left: 7×7 network heatmap ──────────────────────────────────────────
|
| 115 |
ax = axes[0]
|
|
@@ -192,6 +210,45 @@ def _saliency_figure(sal, p_mean):
|
|
| 192 |
ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
|
| 193 |
f"{val:.4f}", va="center", color="#555", fontsize=7.5)
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
fig.suptitle(
|
| 196 |
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
|
| 197 |
color="#444", fontsize=8.5, y=1.02,
|
|
|
|
| 88 |
sal = np.mean(maps, axis=0)
|
| 89 |
return (sal + sal.T) / 2
|
| 90 |
|
| 91 |
+
# Approximate MNI centroids for each CC200 network (mm), used for 3D brain view
|
| 92 |
+
_NET_MNI = np.array([
|
| 93 |
+
[ -1, -52, 28], # DMN (PCC)
|
| 94 |
+
[ 2, 18, 30], # Salience (dACC)
|
| 95 |
+
[ 44, 36, 28], # Frontoparietal (DLPFC)
|
| 96 |
+
[ 0, -18, 62], # Sensorimotor (SMA/M1)
|
| 97 |
+
[ 0, -82, 8], # Visual (occipital)
|
| 98 |
+
[ 28, -58, 50], # Dorsal Attn (IPS)
|
| 99 |
+
[ 14, 4, 4], # Subcortical (thalamus)
|
| 100 |
+
], dtype=np.float32)
|
| 101 |
+
|
| 102 |
def _saliency_figure(sal, p_mean):
|
| 103 |
import matplotlib
|
| 104 |
matplotlib.use("Agg")
|
| 105 |
import matplotlib.pyplot as plt
|
| 106 |
+
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
|
| 107 |
+
from mpl_toolkits.mplot3d.art3d import Line3DCollection
|
| 108 |
from PIL import Image
|
| 109 |
|
| 110 |
n_nets = len(_NET_NAMES)
|
|
|
|
| 121 |
for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
|
| 122 |
])
|
| 123 |
|
| 124 |
+
fig = plt.figure(figsize=(18, 5.5))
|
| 125 |
+
fig.patch.set_facecolor("#0e1015")
|
| 126 |
+
axes = [
|
| 127 |
+
fig.add_subplot(1, 3, 1),
|
| 128 |
+
fig.add_subplot(1, 3, 2),
|
| 129 |
+
fig.add_subplot(1, 3, 3, projection="3d"),
|
| 130 |
+
]
|
| 131 |
|
| 132 |
# ── Left: 7×7 network heatmap ──────────────────────────────────────────
|
| 133 |
ax = axes[0]
|
|
|
|
| 210 |
ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
|
| 211 |
f"{val:.4f}", va="center", color="#555", fontsize=7.5)
|
| 212 |
|
| 213 |
+
# ── 3D Brain Surface — top connections ────────────────────────────────────
|
| 214 |
+
ax3 = axes[2]
|
| 215 |
+
ax3.set_facecolor("#0e1015")
|
| 216 |
+
ax3.grid(False)
|
| 217 |
+
ax3.set_axis_off()
|
| 218 |
+
ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=11, pad=4, fontweight="bold")
|
| 219 |
+
|
| 220 |
+
# Transparent brain ellipsoid wireframe (MNI space approx)
|
| 221 |
+
u = np.linspace(0, 2 * np.pi, 32)
|
| 222 |
+
v = np.linspace(0, np.pi, 20)
|
| 223 |
+
ex = 68 * np.outer(np.cos(u), np.sin(v))
|
| 224 |
+
ey = 85 * np.outer(np.sin(u), np.sin(v)) - 10
|
| 225 |
+
ez = 60 * np.outer(np.ones_like(u), np.cos(v)) + 28
|
| 226 |
+
ax3.plot_wireframe(ex, ey, ez, color="#252a35", linewidth=0.25, alpha=0.45, zorder=0)
|
| 227 |
+
|
| 228 |
+
# Network nodes — size proportional to importance
|
| 229 |
+
imp_norm = (net_imp - net_imp.min()) / (net_imp.max() - net_imp.min() + 1e-9)
|
| 230 |
+
for k, (name, color) in enumerate(zip(_NET_NAMES, _NET_COLORS)):
|
| 231 |
+
x, y, z = _NET_MNI[k]
|
| 232 |
+
size = 60 + imp_norm[k] * 260
|
| 233 |
+
ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
|
| 234 |
+
edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
|
| 235 |
+
ax3.text(x, y, z + 7, name, fontsize=5.5, color=color,
|
| 236 |
+
ha="center", va="bottom", fontweight="600", zorder=6)
|
| 237 |
+
|
| 238 |
+
# Draw top-5 inter-network connections as lines, thickness ∝ saliency
|
| 239 |
+
sal_vals = [s for s, _, _ in edge_scores[:5]]
|
| 240 |
+
sal_min, sal_max = min(sal_vals), max(sal_vals) + 1e-9
|
| 241 |
+
for rank, (score, ni, nj) in enumerate(edge_scores[:5]):
|
| 242 |
+
p1, p2 = _NET_MNI[ni], _NET_MNI[nj]
|
| 243 |
+
lw = 0.8 + 2.5 * (score - sal_min) / (sal_max - sal_min)
|
| 244 |
+
alph = 0.5 + 0.45 * (score - sal_min) / (sal_max - sal_min)
|
| 245 |
+
clr = "#fb923c" if rank == 0 else "#f4f4f5"
|
| 246 |
+
ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]],
|
| 247 |
+
color=clr, linewidth=lw, alpha=alph, zorder=4)
|
| 248 |
+
|
| 249 |
+
ax3.view_init(elev=22, azim=-65)
|
| 250 |
+
ax3.set_box_aspect([1.2, 1.4, 1.0])
|
| 251 |
+
|
| 252 |
fig.suptitle(
|
| 253 |
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
|
| 254 |
color="#444", fontsize=8.5, y=1.02,
|