Yatsuiii commited on
Commit
adc564a
·
verified ·
1 Parent(s): 4eaa291

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +78 -23
app.py CHANGED
@@ -17,6 +17,11 @@ _STEP = 3
17
  _MAX_WINDOWS = 30
18
  _FC_THRESHOLD = 0.2
19
 
 
 
 
 
 
20
  _CKPTS = {
21
  "NYU": Path("checkpoints/nyu.ckpt"),
22
  "USM": Path("checkpoints/usm.ckpt"),
@@ -89,44 +94,86 @@ def _saliency_figure(sal, p_mean):
89
  import matplotlib.pyplot as plt
90
  from PIL import Image
91
 
92
- thresh = np.percentile(sal, 95)
93
- sal_top = np.where(sal >= thresh, sal, 0.0)
94
- roi_imp = sal.sum(1)
95
- top20 = roi_imp.argsort()[-20:][::-1]
96
- color = "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261"
 
 
 
 
 
 
 
 
97
 
98
- fig, axes = plt.subplots(1, 2, figsize=(14, 5))
99
  fig.patch.set_facecolor("#0d0d0d")
100
 
 
101
  ax = axes[0]
102
- ax.set_facecolor("#111"); ax.tick_params(colors="#555", labelsize=8)
103
- for sp in ax.spines.values(): sp.set_color("#222")
104
- im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
105
- ax.set_title("FC Edge Saliency (top 5% connections)", color="#bbb", fontsize=10, pad=10)
106
- ax.set_xlabel("ROI index", color="#555", fontsize=9)
107
- ax.set_ylabel("ROI index", color="#555", fontsize=9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
109
  cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
110
  plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
 
111
 
 
112
  ax2 = axes[1]
113
- ax2.set_facecolor("#111"); ax2.tick_params(colors="#555", labelsize=8)
114
- ax2.barh(range(20), roi_imp[top20], color=color, alpha=0.8, edgecolor="none")
115
- ax2.set_yticks(range(20))
116
- ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#aaa")
117
- ax2.set_xlabel("Cumulative gradient magnitude", color="#555", fontsize=9)
118
- ax2.set_title("Top-20 ROIs by Prediction Influence", color="#bbb", fontsize=10, pad=10)
 
 
 
 
119
  ax2.invert_yaxis()
120
- for sp in ["top", "right"]: ax2.spines[sp].set_visible(False)
121
- for sp in ["bottom", "left"]: ax2.spines[sp].set_color("#222")
 
 
 
 
 
 
 
 
122
 
123
  fig.suptitle(
124
- f"Gradient Saliency · p(ASD)={p_mean:.3f} · {len(_models)}-model LOSO ensemble",
125
- color="#555", fontsize=9, y=1.01,
126
  )
127
  plt.tight_layout()
128
  buf = io.BytesIO()
129
- plt.savefig(buf, format="png", dpi=130, bbox_inches="tight", facecolor="#0d0d0d")
130
  plt.close(fig)
131
  buf.seek(0)
132
  return Image.open(buf).copy()
@@ -260,6 +307,14 @@ HEADER = """
260
  <div style="color:#333;font-size:0.72rem;letter-spacing:4px;text-transform:uppercase;margin-top:10px">
261
  Clinical AI · Resting-state fMRI · Scanner-Site-Invariant Classification
262
  </div>
 
 
 
 
 
 
 
 
263
  <div style="display:flex;gap:0;margin-top:28px;border:1px solid #1a1a1a;border-radius:12px;overflow:hidden;max-width:700px">
264
  <div style="padding:20px 32px;flex:1;border-right:1px solid #1a1a1a;min-width:120px">
265
  <div style="font-size:2.2rem;font-weight:900;color:#e63946;line-height:1">0.7872</div>
 
17
  _MAX_WINDOWS = 30
18
  _FC_THRESHOLD = 0.2
19
 
20
+ # CC200 atlas (Craddock 2012) → approximate Yeo 7-network parcellation
21
+ _NET_NAMES = ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"]
22
+ _NET_BOUNDS = [0, 38, 69, 99, 137, 165, 180, 200]
23
+ _NET_COLORS = ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"]
24
+
25
  _CKPTS = {
26
  "NYU": Path("checkpoints/nyu.ckpt"),
27
  "USM": Path("checkpoints/usm.ckpt"),
 
94
  import matplotlib.pyplot as plt
95
  from PIL import Image
96
 
97
+ n_nets = len(_NET_NAMES)
98
+
99
+ # Aggregate 200×200 saliency → 7×7 network-level matrix
100
+ net_sal = np.zeros((n_nets, n_nets))
101
+ for i, (s1, e1) in enumerate(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])):
102
+ for j, (s2, e2) in enumerate(zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])):
103
+ net_sal[i, j] = sal[s1:e1, s2:e2].mean()
104
+
105
+ # Network importance: mean outgoing + incoming saliency per network
106
+ net_imp = np.array([
107
+ sal[s:e, :].mean() + sal[:, s:e].mean()
108
+ for s, e in zip(_NET_BOUNDS[:-1], _NET_BOUNDS[1:])
109
+ ])
110
 
111
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
112
  fig.patch.set_facecolor("#0d0d0d")
113
 
114
+ # ── Left: 7×7 network heatmap ──────────────────────────────────────────
115
  ax = axes[0]
116
+ ax.set_facecolor("#111")
117
+ im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
118
+ ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=11, pad=14, fontweight="bold")
119
+
120
+ ax.set_xticks(range(n_nets))
121
+ ax.set_yticks(range(n_nets))
122
+ ax.set_xticklabels(_NET_NAMES, rotation=40, ha="right", fontsize=9, color="#ccc")
123
+ ax.set_yticklabels(_NET_NAMES, fontsize=9, color="#ccc")
124
+ ax.tick_params(colors="#555", length=0)
125
+ for sp in ax.spines.values():
126
+ sp.set_color("#222")
127
+
128
+ # Boundary lines between networks
129
+ for k in range(1, n_nets):
130
+ ax.axhline(k - 0.5, color="#2a2a2a", lw=1.0)
131
+ ax.axvline(k - 0.5, color="#2a2a2a", lw=1.0)
132
+
133
+ # Annotate each cell with its value
134
+ vmax = net_sal.max()
135
+ for i in range(n_nets):
136
+ for j in range(n_nets):
137
+ txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#555"
138
+ ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
139
+ fontsize=6.5, color=txt_color)
140
+
141
  cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
142
  cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
143
  plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
144
+ cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=7.5)
145
 
146
+ # ── Right: network importance bar chart ────────────────────────────────
147
  ax2 = axes[1]
148
+ ax2.set_facecolor("#111")
149
+ ax2.tick_params(colors="#555", labelsize=9)
150
+
151
+ order = net_imp.argsort()[::-1]
152
+ bars = ax2.barh(range(n_nets), net_imp[order],
153
+ color=[_NET_COLORS[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
154
+ ax2.set_yticks(range(n_nets))
155
+ ax2.set_yticklabels([_NET_NAMES[i] for i in order], fontsize=9.5, color="#ddd")
156
+ ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
157
+ ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
158
  ax2.invert_yaxis()
159
+ for sp in ["top", "right"]:
160
+ ax2.spines[sp].set_visible(False)
161
+ for sp in ["bottom", "left"]:
162
+ ax2.spines[sp].set_color("#222")
163
+
164
+ # Value labels on bars
165
+ x_max = net_imp.max()
166
+ for bar, val in zip(bars, net_imp[order]):
167
+ ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
168
+ f"{val:.4f}", va="center", color="#555", fontsize=7.5)
169
 
170
  fig.suptitle(
171
+ f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(_models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
172
+ color="#444", fontsize=8.5, y=1.02,
173
  )
174
  plt.tight_layout()
175
  buf = io.BytesIO()
176
+ plt.savefig(buf, format="png", dpi=140, bbox_inches="tight", facecolor="#0d0d0d")
177
  plt.close(fig)
178
  buf.seek(0)
179
  return Image.open(buf).copy()
 
307
  <div style="color:#333;font-size:0.72rem;letter-spacing:4px;text-transform:uppercase;margin-top:10px">
308
  Clinical AI · Resting-state fMRI · Scanner-Site-Invariant Classification
309
  </div>
310
+ <div style="color:#444;font-size:0.93rem;margin-top:18px;max-width:720px;line-height:1.75">
311
+ 1 in 44 children is diagnosed with ASD. Today, diagnosis takes years of behavioral observation —
312
+ no biomarker exists. We trained a scanner-site-invariant GCN on 1,102 subjects across 17 institutions
313
+ and validated it on <span style="color:#e63946;font-weight:700">529 subjects the model never saw, from sites it was never trained on</span>.
314
+ The result: <span style="color:#e63946;font-weight:700">AUC 0.7872</span> — not on held-out splits of the same scanner, but
315
+ across entirely different hospitals. Fine-tuned <span style="color:#f4a261;font-weight:700">Qwen2.5-7B on AMD MI300X</span>
316
+ then translates raw connectivity patterns into structured clinical language a clinician can act on.
317
+ </div>
318
  <div style="display:flex;gap:0;margin-top:28px;border:1px solid #1a1a1a;border-radius:12px;overflow:hidden;max-width:700px">
319
  <div style="padding:20px 32px;flex:1;border-right:1px solid #1a1a1a;min-width:120px">
320
  <div style="font-size:2.2rem;font-weight:900;color:#e63946;line-height:1">0.7872</div>