Yatsuiii commited on
Commit
46f07ab
·
verified ·
1 Parent(s): 3d9e4e3

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +59 -2
app.py CHANGED
@@ -88,10 +88,23 @@ def _compute_saliency(bw_t, adj_t, models):
88
  sal = np.mean(maps, axis=0)
89
  return (sal + sal.T) / 2
90
 
 
 
 
 
 
 
 
 
 
 
 
91
  def _saliency_figure(sal, p_mean):
92
  import matplotlib
93
  matplotlib.use("Agg")
94
  import matplotlib.pyplot as plt
 
 
95
  from PIL import Image
96
 
97
  n_nets = len(_NET_NAMES)
@@ -108,8 +121,13 @@ def _saliency_figure(sal, p_mean):
108
  for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
109
  ])
110
 
111
- fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
112
- fig.patch.set_facecolor("#0d0d0d")
 
 
 
 
 
113
 
114
  # ── Left: 7×7 network heatmap ──────────────────────────────────────────
115
  ax = axes[0]
@@ -192,6 +210,45 @@ def _saliency_figure(sal, p_mean):
192
  ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
193
  f"{val:.4f}", va="center", color="#555", fontsize=7.5)
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  fig.suptitle(
196
  f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
197
  color="#444", fontsize=8.5, y=1.02,
 
88
  sal = np.mean(maps, axis=0)
89
  return (sal + sal.T) / 2
90
 
91
+ # Approximate MNI centroids for each CC200 network (mm), used for 3D brain view
92
+ _NET_MNI = np.array([
93
+ [ -1, -52, 28], # DMN (PCC)
94
+ [ 2, 18, 30], # Salience (dACC)
95
+ [ 44, 36, 28], # Frontoparietal (DLPFC)
96
+ [ 0, -18, 62], # Sensorimotor (SMA/M1)
97
+ [ 0, -82, 8], # Visual (occipital)
98
+ [ 28, -58, 50], # Dorsal Attn (IPS)
99
+ [ 14, 4, 4], # Subcortical (thalamus)
100
+ ], dtype=np.float32)
101
+
102
  def _saliency_figure(sal, p_mean):
103
  import matplotlib
104
  matplotlib.use("Agg")
105
  import matplotlib.pyplot as plt
106
+ from mpl_toolkits.mplot3d import Axes3D # noqa: F401
107
+ from mpl_toolkits.mplot3d.art3d import Line3DCollection
108
  from PIL import Image
109
 
110
  n_nets = len(_NET_NAMES)
 
121
  for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
122
  ])
123
 
124
+ fig = plt.figure(figsize=(18, 5.5))
125
+ fig.patch.set_facecolor("#0e1015")
126
+ axes = [
127
+ fig.add_subplot(1, 3, 1),
128
+ fig.add_subplot(1, 3, 2),
129
+ fig.add_subplot(1, 3, 3, projection="3d"),
130
+ ]
131
 
132
  # ── Left: 7×7 network heatmap ──────────────────────────────────────────
133
  ax = axes[0]
 
210
  ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
211
  f"{val:.4f}", va="center", color="#555", fontsize=7.5)
212
 
213
+ # ── 3D Brain Surface — top connections ────────────────────────────────────
214
+ ax3 = axes[2]
215
+ ax3.set_facecolor("#0e1015")
216
+ ax3.grid(False)
217
+ ax3.set_axis_off()
218
+ ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=11, pad=4, fontweight="bold")
219
+
220
+ # Transparent brain ellipsoid wireframe (MNI space approx)
221
+ u = np.linspace(0, 2 * np.pi, 32)
222
+ v = np.linspace(0, np.pi, 20)
223
+ ex = 68 * np.outer(np.cos(u), np.sin(v))
224
+ ey = 85 * np.outer(np.sin(u), np.sin(v)) - 10
225
+ ez = 60 * np.outer(np.ones_like(u), np.cos(v)) + 28
226
+ ax3.plot_wireframe(ex, ey, ez, color="#252a35", linewidth=0.25, alpha=0.45, zorder=0)
227
+
228
+ # Network nodes — size proportional to importance
229
+ imp_norm = (net_imp - net_imp.min()) / (net_imp.max() - net_imp.min() + 1e-9)
230
+ for k, (name, color) in enumerate(zip(_NET_NAMES, _NET_COLORS)):
231
+ x, y, z = _NET_MNI[k]
232
+ size = 60 + imp_norm[k] * 260
233
+ ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
234
+ edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
235
+ ax3.text(x, y, z + 7, name, fontsize=5.5, color=color,
236
+ ha="center", va="bottom", fontweight="600", zorder=6)
237
+
238
+ # Draw top-5 inter-network connections as lines, thickness ∝ saliency
239
+ sal_vals = [s for s, _, _ in edge_scores[:5]]
240
+ sal_min, sal_max = min(sal_vals), max(sal_vals) + 1e-9
241
+ for rank, (score, ni, nj) in enumerate(edge_scores[:5]):
242
+ p1, p2 = _NET_MNI[ni], _NET_MNI[nj]
243
+ lw = 0.8 + 2.5 * (score - sal_min) / (sal_max - sal_min)
244
+ alph = 0.5 + 0.45 * (score - sal_min) / (sal_max - sal_min)
245
+ clr = "#fb923c" if rank == 0 else "#f4f4f5"
246
+ ax3.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]],
247
+ color=clr, linewidth=lw, alpha=alph, zorder=4)
248
+
249
+ ax3.view_init(elev=22, azim=-65)
250
+ ax3.set_box_aspect([1.2, 1.4, 1.0])
251
+
252
  fig.suptitle(
253
  f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
254
  color="#444", fontsize=8.5, y=1.02,