Yatsuiii committed on
Commit
1f4f845
·
verified ·
1 Parent(s): ff6bc7a

Delete app_with_llm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app_with_llm.py +0 -363
app_with_llm.py DELETED
@@ -1,363 +0,0 @@
1
- """
2
- BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
3
- Full pipeline: Adversarial GCN + Qwen2.5-7B fine-tuned on AMD MI300X.
4
- """
5
- from __future__ import annotations
6
-
7
- import io
8
- from pathlib import Path
9
-
10
- import numpy as np
11
- import torch
12
- import gradio as gr
13
-
14
- _WINDOW_LEN = 50
15
- _STEP = 3
16
- _MAX_WINDOWS = 30
17
- _FC_THRESHOLD = 0.2
18
-
19
- _CKPTS = {
20
- "NYU": Path("checkpoints/nyu.ckpt"),
21
- "USM": Path("checkpoints/usm.ckpt"),
22
- "UCLA": Path("checkpoints/ucla.ckpt"),
23
- "UM": Path("checkpoints/um.ckpt"),
24
- }
25
-
26
- _LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
27
-
28
- SYSTEM_PROMPT = (
29
- "You are a clinical AI assistant specializing in functional MRI brain "
30
- "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
31
- "You interpret outputs from a validated graph neural network (GCN) trained on "
32
- "the ABIDE I dataset and provide structured clinical summaries for neurologists "
33
- "and psychiatrists. Your reports are informative and evidence-based but always "
34
- "clarify that findings are AI-assisted and should be integrated with full "
35
- "clinical assessment. You do not make a diagnosis."
36
- )
37
-
38
- # ── preprocessing ──────────────────────────────────────────────────────────
39
-
40
- def _zscore(bold):
41
- mean = bold.mean(0, keepdims=True)
42
- std = bold.std(0, keepdims=True)
43
- std[std < 1e-8] = 1.0
44
- return ((bold - mean) / std).astype(np.float32)
45
-
46
- def _fc(bold):
47
- fc = np.corrcoef(bold.T).astype(np.float32)
48
- np.nan_to_num(fc, copy=False)
49
- return fc
50
-
51
- def _windows(bold):
52
- T, N = bold.shape
53
- starts = list(range(0, T - _WINDOW_LEN + 1, _STEP))
54
- w = np.stack([bold[s:s+_WINDOW_LEN].std(0) for s in starts]).astype(np.float32)
55
- if len(w) >= _MAX_WINDOWS:
56
- return w[:_MAX_WINDOWS]
57
- return np.concatenate([w, np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)])
58
-
59
- def preprocess(bold):
60
- bold = _zscore(bold)
61
- fc = _fc(bold)
62
- fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
63
- adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
64
- bw = _windows(bold)
65
- return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
66
-
67
- # ── GCN model loading ──────────────────────────────────────────────────────
68
-
69
- _models: list | None = None
70
-
71
- def get_models():
72
- global _models
73
- if _models is not None:
74
- return _models
75
- from brain_gcn.tasks import ClassificationTask
76
- _models = []
77
- for site, ckpt in _CKPTS.items():
78
- if not ckpt.exists():
79
- continue
80
- task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
81
- task.eval()
82
- _models.append((site, task))
83
- return _models
84
-
85
- # ── LLM loading ────────────────────────────────────────────────────────────
86
-
87
- _llm = None
88
-
89
- def get_llm():
90
- global _llm
91
- if _llm is not None:
92
- return _llm
93
- from transformers import AutoModelForCausalLM, AutoTokenizer
94
- print(f"Loading LLM: {_LLM_MODEL}")
95
- tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
96
- tok.pad_token = tok.eos_token
97
- mdl = AutoModelForCausalLM.from_pretrained(
98
- _LLM_MODEL,
99
- torch_dtype=torch.bfloat16,
100
- device_map="auto",
101
- )
102
- mdl.eval()
103
- _llm = (mdl, tok)
104
- return _llm
105
-
106
-
107
- def _llm_report(p_mean: float, per_model: list) -> str:
108
- consensus = sum(1 for _, p in per_model if p > 0.5)
109
- per_model_str = "\n".join(
110
- f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
111
- )
112
- conf_label = (
113
- "HIGH" if p_mean >= 0.75 else
114
- "MODERATE" if p_mean >= 0.6 else
115
- "LOW / UNCERTAIN" if p_mean >= 0.4 else
116
- "MODERATE (TC)" if p_mean >= 0.25 else
117
- "HIGH (TC)"
118
- )
119
- user_msg = (
120
- f"Brain Connectivity GCN Analysis Report\n"
121
- f"{'='*40}\n"
122
- f"p(ASD) : {p_mean:.3f}\n"
123
- f"Confidence Level : {conf_label}\n"
124
- f"Model Consensus : {consensus}/4 site-blind models predict ASD\n\n"
125
- f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
126
- f"Please provide a structured clinical interpretation of these findings."
127
- )
128
- try:
129
- mdl, tok = get_llm()
130
- messages = [
131
- {"role": "system", "content": SYSTEM_PROMPT},
132
- {"role": "user", "content": user_msg},
133
- ]
134
- text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
135
- inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
136
- with torch.no_grad():
137
- out = mdl.generate(
138
- **inputs,
139
- max_new_tokens=512,
140
- temperature=0.3,
141
- do_sample=True,
142
- pad_token_id=tok.eos_token_id,
143
- )
144
- generated = out[0][inputs["input_ids"].shape[1]:]
145
- return tok.decode(generated, skip_special_tokens=True).strip()
146
- except Exception as e:
147
- return f"LLM unavailable: {e}"
148
-
149
- # ── gradient saliency ──────────────────────────────────────────────────────
150
-
151
- def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
152
- maps = []
153
- for _, task in models:
154
- adj = adj_t.clone().requires_grad_(True)
155
- logits = task.model(bw_t, adj)
156
- p = torch.softmax(logits, -1)[0, 1]
157
- p.backward()
158
- maps.append(adj.grad[0].abs().detach().numpy())
159
- sal = np.mean(maps, axis=0)
160
- sal = (sal + sal.T) / 2
161
- return sal
162
-
163
-
164
- def _saliency_figure(sal: np.ndarray, p_mean: float):
165
- import matplotlib
166
- matplotlib.use("Agg")
167
- import matplotlib.pyplot as plt
168
- from PIL import Image
169
-
170
- thresh = np.percentile(sal, 95)
171
- sal_top = np.where(sal >= thresh, sal, 0.0)
172
- roi_imp = sal.sum(1)
173
- top20 = roi_imp.argsort()[-20:][::-1]
174
- verdict_color = "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261"
175
-
176
- fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
177
- fig.patch.set_facecolor("#0d0d0d")
178
-
179
- ax = axes[0]
180
- ax.set_facecolor("#111")
181
- im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
182
- ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
183
- ax.set_xlabel("ROI index", color="#777", fontsize=9)
184
- ax.set_ylabel("ROI index", color="#777", fontsize=9)
185
- ax.tick_params(colors="#555", labelsize=8)
186
- cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
187
- cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
188
- plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666")
189
- for spine in ax.spines.values():
190
- spine.set_color("#333")
191
-
192
- ax2 = axes[1]
193
- ax2.set_facecolor("#111")
194
- ax2.barh(range(20), roi_imp[top20], color=verdict_color, alpha=0.75, edgecolor="none")
195
- ax2.set_yticks(range(20))
196
- ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc")
197
- ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
198
- ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
199
- ax2.tick_params(colors="#555", labelsize=8)
200
- ax2.invert_yaxis()
201
- for spine in ["top", "right"]:
202
- ax2.spines[spine].set_visible(False)
203
- for spine in ["bottom", "left"]:
204
- ax2.spines[spine].set_color("#333")
205
-
206
- fig.suptitle(
207
- f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {len(_models)} LOSO models",
208
- color="#888", fontsize=10, y=1.02,
209
- )
210
- plt.tight_layout()
211
- buf = io.BytesIO()
212
- plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
213
- plt.close(fig)
214
- buf.seek(0)
215
- return Image.open(buf).copy()
216
-
217
- # ── inference ──────────────────────────────────────────────────────────────
218
-
219
- def run_gcn(file_path: str | None):
220
- if file_path is None:
221
- return "", "", "", None, ""
222
-
223
- path = Path(file_path)
224
- try:
225
- if path.suffix == ".npz":
226
- d = np.load(path, allow_pickle=True)
227
- fc = d["mean_fc"].astype(np.float32)
228
- fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
229
- adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
230
- bw = d["bold_windows"].astype(np.float32)
231
- if len(bw) >= _MAX_WINDOWS:
232
- bw = bw[:_MAX_WINDOWS]
233
- else:
234
- bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)])
235
- bw_t = torch.FloatTensor(bw).unsqueeze(0)
236
- adj_t = torch.FloatTensor(adj).unsqueeze(0)
237
- else:
238
- bold = np.loadtxt(path, dtype=np.float32)
239
- if bold.ndim != 2 or bold.shape[1] != 200:
240
- return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None, ""
241
- bw_t, adj_t = preprocess(bold)
242
- except Exception as e:
243
- return f"⚠️ Error loading file: {e}", "", "", None, ""
244
-
245
- models = get_models()
246
-
247
- per_model = []
248
- with torch.no_grad():
249
- for site, task in models:
250
- logits = task(bw_t, adj_t)
251
- p = torch.softmax(logits, -1)[0, 1].item()
252
- per_model.append((site, p))
253
-
254
- p_mean = float(np.mean([p for _, p in per_model]))
255
- consensus = sum(1 for _, p in per_model if p > 0.5)
256
- conf = max(p_mean, 1 - p_mean) * 100
257
-
258
- try:
259
- sal = _compute_saliency(bw_t, adj_t, models)
260
- sal_img = _saliency_figure(sal, p_mean)
261
- except Exception:
262
- sal_img = None
263
-
264
- # Verdict
265
- if p_mean > 0.6:
266
- verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
267
- <div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
268
- <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; <b style="color:white">{consensus}/4</b> site-blind models agree</div>
269
- </div>"""
270
- elif p_mean < 0.4:
271
- verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #2dc653;padding:24px 28px;border-radius:12px;margin-bottom:8px">
272
- <div style="font-size:2rem;font-weight:800;color:#2dc653;letter-spacing:1px">TYPICAL CONTROL</div>
273
- <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; <b style="color:white">{4-consensus}/4</b> site-blind models agree</div>
274
- </div>"""
275
- else:
276
- verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #f4a261;padding:24px 28px;border-radius:12px;margin-bottom:8px">
277
- <div style="font-size:2rem;font-weight:800;color:#f4a261;letter-spacing:1px">INCONCLUSIVE</div>
278
- <div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> &nbsp;|&nbsp; p(ASD) = <b style="color:white">{p_mean:.3f}</b> &nbsp;|&nbsp; Model disagreement — clinical review required</div>
279
- </div>"""
280
-
281
- # Ensemble breakdown
282
- rows = ""
283
- for site, p in per_model:
284
- lbl = "ASD" if p > 0.5 else "TC"
285
- color = "#e63946" if p > 0.5 else "#2dc653"
286
- bar_w = int(p * 100)
287
- rows += f"""<tr>
288
- <td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td>
289
- <td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px">
290
- <div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td>
291
- <td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td>
292
- <td style="padding:8px 12px;color:#888">p={p:.3f}</td>
293
- </tr>"""
294
-
295
- ensemble = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
296
- <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div>
297
- <table style="width:100%;border-collapse:collapse">{rows}</table>
298
- <div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {consensus}/4 models agree &nbsp;·&nbsp; LOSO AUC = 0.7872 across 529 held-out subjects</div>
299
- </div>"""
300
-
301
- # LLM clinical report
302
- llm_text = _llm_report(p_mean, per_model)
303
- report = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
304
- <div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Report — Qwen2.5-7B fine-tuned on AMD Instinct MI300X</div>
305
- <div style="color:#ddd;font-size:0.95rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>
306
- <div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#555;font-size:0.78rem;margin-top:16px">
307
- ⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments (ADOS-2, ADI-R).
308
- </div></div>"""
309
-
310
- return verdict, ensemble, report, sal_img
311
-
312
-
313
- # ── UI ─────────────────────────────────────────────────────────────────────
314
-
315
- css = """
316
- body { background: #0d0d0d; }
317
- .gradio-container { max-width: 960px; margin: auto; }
318
- """
319
-
320
- with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
321
- gr.HTML("""
322
- <div style="text-align:center;padding:32px 0 16px">
323
- <div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
324
- <div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
325
- <div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
326
- <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
327
- <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
328
- <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
329
- <span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
330
- </div>
331
- </div>
332
- """)
333
-
334
- file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
335
- verdict_html = gr.HTML()
336
- ensemble_html = gr.HTML()
337
-
338
- gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
339
- saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")
340
-
341
- report_html = gr.HTML()
342
-
343
- file_input.change(
344
- fn=run_gcn,
345
- inputs=file_input,
346
- outputs=[verdict_html, ensemble_html, report_html, saliency_img],
347
- )
348
-
349
- gr.HTML("""
350
- <div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
351
- Adversarial Brain-Mode GCN (k=16) · Qwen2.5-7B LoRA (AMD MI300X) · ABIDE I ·
352
- <a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
353
- </div>
354
- """)
355
-
356
- print("Preloading GCN models...")
357
- get_models()
358
- print("Preloading LLM...")
359
- get_llm()
360
- print("All models ready.")
361
-
362
- if __name__ == "__main__":
363
- demo.launch()