Yatsuiii commited on
Commit
718e8c6
·
verified ·
1 Parent(s): 54ee5db

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +109 -103
app.py CHANGED
@@ -1,22 +1,14 @@
1
  """
2
  BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
3
-
4
- Ensemble of 4 adversarial GCNs trained with leave-one-site-out CV on ABIDE I.
5
- Each model held out a different scanner site (NYU / USM / UCLA / UM).
6
- LOSO mean AUC = 0.7872 across 529 unseen subjects from 4 institutions.
7
-
8
- Fine-tuned Qwen2.5-7B-Instruct clinical report generation runs on AMD MI300X.
9
  """
10
  from __future__ import annotations
11
 
12
- import sys
13
  from pathlib import Path
14
 
15
  import numpy as np
16
  import torch
17
  import gradio as gr
18
 
19
- # ── preprocessing constants ────────────────────────────────────────────────
20
  _WINDOW_LEN = 50
21
  _STEP = 3
22
  _MAX_WINDOWS = 30
@@ -29,7 +21,6 @@ _CKPTS = {
29
  "UM": Path("checkpoints/um.ckpt"),
30
  }
31
 
32
-
33
  # ── preprocessing ──────────────────────────────────────────────────────────
34
 
35
  def _zscore(bold):
@@ -59,8 +50,7 @@ def preprocess(bold):
59
  bw = _windows(bold)
60
  return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
61
 
62
-
63
- # ── model loading (cached) ─────────────────────────────────────────────────
64
 
65
  _models: list | None = None
66
 
@@ -78,13 +68,12 @@ def get_models():
78
  _models.append((site, task))
79
  return _models
80
 
81
-
82
  # ── inference ──────────────────────────────────────────────────────────────
83
 
84
  @torch.no_grad()
85
- def run_gcn(file_path: str | None) -> tuple[str, str]:
86
  if file_path is None:
87
- return "Upload a file to begin.", ""
88
 
89
  path = Path(file_path)
90
  try:
@@ -103,10 +92,10 @@ def run_gcn(file_path: str | None) -> tuple[str, str]:
103
  else:
104
  bold = np.loadtxt(path, dtype=np.float32)
105
  if bold.ndim != 2 or bold.shape[1] != 200:
106
- return f"Error: expected (T×200) array, got {bold.shape}", ""
107
  bw_t, adj_t = preprocess(bold)
108
  except Exception as e:
109
- return f"Error loading file: {e}", ""
110
 
111
  models = get_models()
112
  per_model = []
@@ -119,105 +108,122 @@ def run_gcn(file_path: str | None) -> tuple[str, str]:
119
  consensus = sum(1 for _, p in per_model if p > 0.5)
120
  conf = max(p_mean, 1 - p_mean) * 100
121
 
 
122
  if p_mean > 0.6:
123
- label = "ASD"
124
- status = "HIGH CONFIDENCE"
 
 
125
  elif p_mean < 0.4:
126
- label = "Typical Control"
127
- status = "HIGH CONFIDENCE"
 
 
128
  else:
129
- label = "Inconclusive"
130
- status = "LOW CONFIDENCE — CLINICAL REVIEW RECOMMENDED"
131
-
132
- gcn_out = f"Prediction : {label}\n"
133
- gcn_out += f"Status : {status}\n"
134
- gcn_out += f"Confidence : {conf:.1f}% (p_ASD = {p_mean:.3f})\n"
135
- gcn_out += f"Consensus : {consensus}/4 site models\n\n"
136
- gcn_out += "Per-model breakdown:\n"
137
- for site, p in per_model:
138
- lbl = "ASD" if p > 0.5 else "TC"
139
- gcn_out += f" {site:<6} {lbl:<3} p={p:.3f}\n"
140
-
141
- asd_features = [
142
- "Reduced DMN coherence (mPFC ↔ PCC)",
143
- "Atypical salience network lateralization",
144
- "Decreased long-range frontotemporal connectivity",
145
- "Hypoconnectivity in social brain circuit (TPJ, STS)",
146
- "Atypical cerebellar–cortical coupling",
147
- ]
148
- tc_features = [
149
- "DMN coherence within normal range",
150
- "Intact salience network organization",
151
- "Normal long-range cortico-cortical connectivity",
152
- "Typical social brain circuit integrity",
153
- "Cerebellar–cortical coupling within expected range",
154
- ]
155
-
156
- report = "## Clinical Connectivity Summary\n\n"
157
- report += f"**Overall**: {label} ({conf:.1f}% confidence, {consensus}/4 site consensus)\n\n"
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  if p_mean > 0.6:
160
- report += "**Key Findings**:\n"
161
- for f in asd_features[:3]:
162
- report += f"- {f}\n"
163
- report += "\n**Cross-Site Consistency**: ASD-consistent patterns detected across "
164
- report += f"{consensus}/4 independent scanner sites, indicating findings are not "
165
- report += "attributable to acquisition-site artifacts.\n\n"
 
166
  elif p_mean < 0.4:
167
- report += "**Key Findings**:\n"
168
- for f in tc_features[:3]:
169
- report += f"- {f}\n"
170
- report += "\n**Cross-Site Consistency**: Typical connectivity profile confirmed "
171
- report += f"by {4 - consensus}/4 independent site models.\n\n"
 
 
172
  else:
173
- report += "**Inconclusive — Clinical Review Required**\n\n"
174
- report += "Connectivity pattern falls near the ASD–Typical Control boundary. "
175
- report += f"Model disagreement ({consensus}/4 site models predict ASD) indicates "
176
- report += "insufficient confidence for an automated call.\n\n"
177
- report += "**Recommended action**: Refer for full neuropsychological evaluation "
178
- report += "(ADOS-2, ADI-R) and structural MRI review.\n\n"
179
-
180
- report += "*This report is AI-assisted and does not constitute a diagnosis. "
181
- report += "Full clinical assessment required.*\n\n"
182
- report += "---\n*Clinical report generation powered by Qwen2.5-7B fine-tuned on AMD MI300X (coming soon)*"
183
-
184
- return gcn_out, report
185
-
186
-
187
- # ── Gradio UI ──────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
- with gr.Blocks(title="BrainConnect-ASD") as demo:
190
- gr.Markdown("""
191
- # BrainConnect-ASD
192
- ### Scanner-site-invariant ASD detection from resting-state fMRI
 
 
 
 
 
 
 
 
 
193
 
194
- Ensemble of **4 adversarial GCNs** trained with leave-one-site-out cross-validation on ABIDE I.
195
- Each model was held out from a different scanner site — the ensemble generalizes to **unseen institutions**.
196
 
197
- **LOSO AUC = 0.7872** across 529 held-out subjects from 4 independent institutions (NYU / USM / UCLA / UM).
 
 
198
 
199
- Fine-tuned **Qwen2.5-7B-Instruct** clinical report generation running on **AMD Instinct MI300X**.
200
- """)
 
 
 
201
 
202
- with gr.Row():
203
- file_input = gr.File(
204
- label="Upload CC200 fMRI file (.1D or .npz)",
205
- type="filepath",
206
- )
207
-
208
- with gr.Row():
209
- gcn_out = gr.Textbox(label="GCN Prediction", lines=10)
210
- report_out = gr.Textbox(label="Clinical Report", lines=20)
211
-
212
- file_input.change(fn=run_gcn, inputs=file_input, outputs=[gcn_out, report_out])
213
-
214
- gr.Markdown("""
215
- ---
216
- **Model**: Adversarial Brain-Mode GCN (k=16 modes) with gradient reversal site deconfounding
217
- **Dataset**: ABIDE I (1,102 subjects, 17 acquisition sites)
218
- **Validation**: Leave-one-site-out across NYU (n=184), USM (n=101), UCLA (n=99), UM (n=145)
219
- **Hardware**: AMD Instinct MI300X via AMD Developer Cloud
220
- **Code**: [GitHub](https://github.com/Yatsuiii/Brain-Connectivity-GCN)
221
  """)
222
 
223
  print("Preloading models...")
 
1
  """
2
  BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
 
 
 
 
 
 
3
  """
4
  from __future__ import annotations
5
 
 
6
  from pathlib import Path
7
 
8
  import numpy as np
9
  import torch
10
  import gradio as gr
11
 
 
12
  _WINDOW_LEN = 50
13
  _STEP = 3
14
  _MAX_WINDOWS = 30
 
21
  "UM": Path("checkpoints/um.ckpt"),
22
  }
23
 
 
24
  # ── preprocessing ──────────────────────────────────────────────────────────
25
 
26
  def _zscore(bold):
 
50
  bw = _windows(bold)
51
  return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
52
 
53
+ # ── model loading ──────────────────────────────────────────────────────────
 
54
 
55
  _models: list | None = None
56
 
 
68
  _models.append((site, task))
69
  return _models
70
 
 
71
  # ── inference ──────────────────────────────────────────────────────────────
72
 
73
  @torch.no_grad()
74
+ def run_gcn(file_path: str | None) -> tuple[str, str, str]:
75
  if file_path is None:
76
+ return "", "", ""
77
 
78
  path = Path(file_path)
79
  try:
 
92
  else:
93
  bold = np.loadtxt(path, dtype=np.float32)
94
  if bold.ndim != 2 or bold.shape[1] != 200:
95
+ return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", ""
96
  bw_t, adj_t = preprocess(bold)
97
  except Exception as e:
98
+ return f"⚠️ Error loading file: {e}", "", ""
99
 
100
  models = get_models()
101
  per_model = []
 
108
  consensus = sum(1 for _, p in per_model if p > 0.5)
109
  conf = max(p_mean, 1 - p_mean) * 100
110
 
111
+ # ── Verdict ──
112
  if p_mean > 0.6:
113
+ verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
114
+ <div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
115
+ <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; <b style="color:white">{consensus}/4</b> site-blind models agree</div>
116
+ </div>"""
117
  elif p_mean < 0.4:
118
+ verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #2dc653;padding:24px 28px;border-radius:12px;margin-bottom:8px">
119
+ <div style="font-size:2rem;font-weight:800;color:#2dc653;letter-spacing:1px">TYPICAL CONTROL</div>
120
+ <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; <b style="color:white">{4-consensus}/4</b> site-blind models agree</div>
121
+ </div>"""
122
  else:
123
+ verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #f4a261;padding:24px 28px;border-radius:12px;margin-bottom:8px">
124
+ <div style="font-size:2rem;font-weight:800;color:#f4a261;letter-spacing:1px">INCONCLUSIVE</div>
125
+ <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; Model disagreement — clinical review required</div>
126
+ </div>"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ # ── Site ensemble breakdown ──
129
+ rows = ""
130
+ for site, p in per_model:
131
+ lbl = "ASD" if p > 0.5 else "TC"
132
+ color = "#e63946" if p > 0.5 else "#2dc653"
133
+ bar_w = int(p * 100)
134
+ rows += f"""<tr>
135
+ <td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td>
136
+ <td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px">
137
+ <div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td>
138
+ <td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td>
139
+ <td style="padding:8px 12px;color:#888">p={p:.3f}</td>
140
+ </tr>"""
141
+
142
+ ensemble = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
143
+ <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div>
144
+ <table style="width:100%;border-collapse:collapse">{rows}</table>
145
+ <div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {consensus}/4 models agree &nbsp;·&nbsp; LOSO AUC = 0.7872 across 529 held-out subjects</div>
146
+ </div>"""
147
+
148
+ # ── Clinical report ──
149
  if p_mean > 0.6:
150
+ findings = [
151
+ "Reduced DMN coherence (mPFC ↔ PCC)",
152
+ "Atypical salience network lateralization",
153
+ "Decreased long-range frontotemporal connectivity",
154
+ ]
155
+ consistency = f"{consensus}/4 site-blind models flag ASD-consistent patterns — findings are not attributable to scanner artifacts."
156
+ impression = f"Connectivity profile consistent with ASD ({conf:.1f}% confidence)."
157
  elif p_mean < 0.4:
158
+ findings = [
159
+ "DMN coherence within normal range",
160
+ "Intact salience network organization",
161
+ "Normal long-range cortico-cortical connectivity",
162
+ ]
163
+ consistency = f"{4-consensus}/4 site-blind models confirm typical connectivity profile."
164
+ impression = f"Connectivity profile within typical range ({conf:.1f}% confidence)."
165
  else:
166
+ findings = [
167
+ "Mixed connectivity features near ASD–TC boundary",
168
+ "Model disagreement across scanner sites",
169
+ "Insufficient confidence for automated classification",
170
+ ]
171
+ consistency = f"Only {consensus}/4 models agree borderline case requiring specialist input."
172
+ impression = "Inconclusive. Full neuropsychological evaluation recommended (ADOS-2, ADI-R)."
173
+
174
+ fi = "".join(f"<li style='margin:6px 0;color:#ccc'>{f}</li>" for f in findings)
175
+ report = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
176
+ <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Connectivity Summary</div>
177
+ <div style="color:#eee;font-size:1rem;margin-bottom:16px"><b>Impression:</b> {impression}</div>
178
+ <div style="color:#aaa;font-size:0.9rem;margin-bottom:8px"><b style="color:#eee">Key Findings:</b></div>
179
+ <ul style="margin:0 0 16px 0;padding-left:20px">{fi}</ul>
180
+ <div style="color:#aaa;font-size:0.9rem;margin-bottom:16px"><b style="color:#eee">Cross-Site Consistency:</b> {consistency}</div>
181
+ <div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#666;font-size:0.8rem">
182
+ ⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments.<br>
183
+ <span style="color:#444;margin-top:6px;display:block">Clinical report generation: Qwen2.5-7B fine-tuned on AMD Instinct MI300X (coming soon)</span>
184
+ </div></div>"""
185
+
186
+ return verdict, ensemble, report
187
+
188
+
189
+ # ── UI ─────────────────────────────────────────────────────────────────────
190
+
191
+ css = """
192
+ body { background: #0d0d0d; }
193
+ .gradio-container { max-width: 900px; margin: auto; }
194
+ """
195
 
196
+ with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
197
+ gr.HTML("""
198
+ <div style="text-align:center;padding:32px 0 16px">
199
+ <div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
200
+ <div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
201
+ <div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
202
+ <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
203
+ <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
204
+ <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
205
+ <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
206
+ </div>
207
+ </div>
208
+ """)
209
 
210
+ file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
 
211
 
212
+ verdict_html = gr.HTML()
213
+ ensemble_html = gr.HTML()
214
+ report_html = gr.HTML()
215
 
216
+ file_input.change(
217
+ fn=run_gcn,
218
+ inputs=file_input,
219
+ outputs=[verdict_html, ensemble_html, report_html],
220
+ )
221
 
222
+ gr.HTML("""
223
+ <div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
224
+ Adversarial Brain-Mode GCN (k=16) · ABIDE I (1,102 subjects, 17 sites) ·
225
+ <a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
226
+ </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  """)
228
 
229
  print("Preloading models...")