Yatsuiii commited on
Commit
6526502
·
verified ·
1 Parent(s): 718e8c6

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +114 -16
app.py CHANGED
@@ -3,6 +3,7 @@ BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
3
  """
4
  from __future__ import annotations
5
 
 
6
  from pathlib import Path
7
 
8
  import numpy as np
@@ -68,12 +69,92 @@ def get_models():
68
  _models.append((site, task))
69
  return _models
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # ── inference ──────────────────────────────────────────────────────────────
72
 
73
- @torch.no_grad()
74
- def run_gcn(file_path: str | None) -> tuple[str, str, str]:
75
  if file_path is None:
76
- return "", "", ""
77
 
78
  path = Path(file_path)
79
  try:
@@ -92,23 +173,33 @@ def run_gcn(file_path: str | None) -> tuple[str, str, str]:
92
  else:
93
  bold = np.loadtxt(path, dtype=np.float32)
94
  if bold.ndim != 2 or bold.shape[1] != 200:
95
- return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", ""
96
  bw_t, adj_t = preprocess(bold)
97
  except Exception as e:
98
- return f"⚠️ Error loading file: {e}", "", ""
99
 
100
  models = get_models()
 
 
101
  per_model = []
102
- for site, task in models:
103
- logits = task(bw_t, adj_t)
104
- p = torch.softmax(logits, -1)[0, 1].item()
105
- per_model.append((site, p))
 
106
 
107
  p_mean = float(np.mean([p for _, p in per_model]))
108
  consensus = sum(1 for _, p in per_model if p > 0.5)
109
  conf = max(p_mean, 1 - p_mean) * 100
110
 
111
- # ── Verdict ──
 
 
 
 
 
 
 
112
  if p_mean > 0.6:
113
  verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
114
  <div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
@@ -183,14 +274,14 @@ def run_gcn(file_path: str | None) -> tuple[str, str, str]:
183
  <span style="color:#444;margin-top:6px;display:block">Clinical report generation: Qwen2.5-7B fine-tuned on AMD Instinct MI300X (coming soon)</span>
184
  </div></div>"""
185
 
186
- return verdict, ensemble, report
187
 
188
 
189
  # ── UI ─────────────────────────────────────────────────────────────────────
190
 
191
  css = """
192
  body { background: #0d0d0d; }
193
- .gradio-container { max-width: 900px; margin: auto; }
194
  """
195
 
196
  with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
@@ -209,14 +300,21 @@ with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as dem
209
 
210
  file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
211
 
212
- verdict_html = gr.HTML()
213
- ensemble_html = gr.HTML()
214
- report_html = gr.HTML()
 
 
 
 
 
 
 
215
 
216
  file_input.change(
217
  fn=run_gcn,
218
  inputs=file_input,
219
- outputs=[verdict_html, ensemble_html, report_html],
220
  )
221
 
222
  gr.HTML("""
 
3
  """
4
  from __future__ import annotations
5
 
6
+ import io
7
  from pathlib import Path
8
 
9
  import numpy as np
 
69
  _models.append((site, task))
70
  return _models
71
 
72
+ # ── gradient saliency ──────────────────────────────────────────────────────
73
+
74
+ def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
75
+ """Gradient of p(ASD) w.r.t. adjacency matrix, averaged over ensemble."""
76
+ maps = []
77
+ for _, task in models:
78
+ adj = adj_t.clone().requires_grad_(True)
79
+ logits = task.model(bw_t, adj)
80
+ p = torch.softmax(logits, -1)[0, 1]
81
+ p.backward()
82
+ maps.append(adj.grad[0].abs().detach().numpy())
83
+ sal = np.mean(maps, axis=0) # (200, 200)
84
+ sal = (sal + sal.T) / 2 # symmetrize
85
+ return sal
86
+
87
+
88
+ def _saliency_figure(sal: np.ndarray, p_mean: float):
89
+ import matplotlib
90
+ matplotlib.use("Agg")
91
+ import matplotlib.pyplot as plt
92
+ from PIL import Image
93
+
94
+ thresh = np.percentile(sal, 95)
95
+ sal_top = np.where(sal >= thresh, sal, 0.0)
96
+
97
+ roi_imp = sal.sum(1)
98
+ top20 = roi_imp.argsort()[-20:][::-1]
99
+
100
+ verdict_color = (
101
+ "#e63946" if p_mean > 0.6 else
102
+ "#2dc653" if p_mean < 0.4 else
103
+ "#f4a261"
104
+ )
105
+
106
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
107
+ fig.patch.set_facecolor("#0d0d0d")
108
+
109
+ # ── Left: FC edge saliency heatmap ──
110
+ ax = axes[0]
111
+ ax.set_facecolor("#111")
112
+ im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
113
+ ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
114
+ ax.set_xlabel("ROI index", color="#777", fontsize=9)
115
+ ax.set_ylabel("ROI index", color="#777", fontsize=9)
116
+ ax.tick_params(colors="#555", labelsize=8)
117
+ cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
118
+ cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
119
+ plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666")
120
+ for spine in ax.spines.values():
121
+ spine.set_color("#333")
122
+
123
+ # ── Right: top-20 ROI importance bar chart ──
124
+ ax2 = axes[1]
125
+ ax2.set_facecolor("#111")
126
+ ax2.barh(
127
+ range(20), roi_imp[top20],
128
+ color=verdict_color, alpha=0.75, edgecolor="none",
129
+ )
130
+ ax2.set_yticks(range(20))
131
+ ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc")
132
+ ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
133
+ ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
134
+ ax2.tick_params(colors="#555", labelsize=8)
135
+ ax2.invert_yaxis()
136
+ for spine in ["top", "right"]:
137
+ ax2.spines[spine].set_visible(False)
138
+ for spine in ["bottom", "left"]:
139
+ ax2.spines[spine].set_color("#333")
140
+
141
+ fig.suptitle(
142
+ f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {len(_models)} LOSO models",
143
+ color="#888", fontsize=10, y=1.02,
144
+ )
145
+
146
+ plt.tight_layout()
147
+ buf = io.BytesIO()
148
+ plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
149
+ plt.close(fig)
150
+ buf.seek(0)
151
+ return Image.open(buf).copy()
152
+
153
  # ── inference ──────────────────────────────────────────────────────────────
154
 
155
+ def run_gcn(file_path: str | None):
 
156
  if file_path is None:
157
+ return "", "", "", None
158
 
159
  path = Path(file_path)
160
  try:
 
173
  else:
174
  bold = np.loadtxt(path, dtype=np.float32)
175
  if bold.ndim != 2 or bold.shape[1] != 200:
176
+ return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None
177
  bw_t, adj_t = preprocess(bold)
178
  except Exception as e:
179
+ return f"⚠️ Error loading file: {e}", "", "", None
180
 
181
  models = get_models()
182
+
183
+ # ── Ensemble inference (no grad) ──
184
  per_model = []
185
+ with torch.no_grad():
186
+ for site, task in models:
187
+ logits = task(bw_t, adj_t)
188
+ p = torch.softmax(logits, -1)[0, 1].item()
189
+ per_model.append((site, p))
190
 
191
  p_mean = float(np.mean([p for _, p in per_model]))
192
  consensus = sum(1 for _, p in per_model if p > 0.5)
193
  conf = max(p_mean, 1 - p_mean) * 100
194
 
195
+ # ── Gradient saliency ──
196
+ try:
197
+ sal = _compute_saliency(bw_t, adj_t, models)
198
+ sal_img = _saliency_figure(sal, p_mean)
199
+ except Exception:
200
+ sal_img = None
201
+
202
+ # ── Verdict card ──
203
  if p_mean > 0.6:
204
  verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
205
  <div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
 
274
  <span style="color:#444;margin-top:6px;display:block">Clinical report generation: Qwen2.5-7B fine-tuned on AMD Instinct MI300X (coming soon)</span>
275
  </div></div>"""
276
 
277
+ return verdict, ensemble, report, sal_img
278
 
279
 
280
  # ── UI ─────────────────────────────────────────────────────────────────────
281
 
282
  css = """
283
  body { background: #0d0d0d; }
284
+ .gradio-container { max-width: 960px; margin: auto; }
285
  """
286
 
287
  with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
 
300
 
301
  file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
302
 
303
+ verdict_html = gr.HTML()
304
+ ensemble_html = gr.HTML()
305
+
306
+ with gr.Row():
307
+ report_html = gr.HTML()
308
+
309
+ gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
310
+ saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")
311
+
312
+ report_html2 = gr.HTML()
313
 
314
  file_input.change(
315
  fn=run_gcn,
316
  inputs=file_input,
317
+ outputs=[verdict_html, ensemble_html, report_html2, saliency_img],
318
  )
319
 
320
  gr.HTML("""