Yatsuiii committed on
Commit
da0634d
·
verified ·
1 Parent(s): 7d63408

5x LLM improvements: saliency grounding, anti-hallucination system prompt, temp 0.1, n_subjects, per-network scores in prompt

Browse files
Files changed (1) hide show
  1. app.py +45 -14
app.py CHANGED
@@ -141,10 +141,15 @@ _SYSTEM_PROMPT = (
141
  "You are a clinical AI assistant specializing in functional MRI brain "
142
  "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
143
  "You interpret outputs from a validated graph neural network (GCN) trained on "
144
- "the ABIDE I dataset and provide structured clinical summaries for neurologists "
145
- "and psychiatrists. Your reports are informative and evidence-based but always "
146
- "clarify that findings are AI-assisted and should be integrated with full "
147
- "clinical assessment. You do not make a diagnosis."
 
 
 
 
 
148
  )
149
  _llm_cache = None
150
 
@@ -162,7 +167,7 @@ def get_llm():
162
  _llm_cache = (mdl, tok)
163
  return _llm_cache
164
 
165
- def _llm_report(p_mean: float, per_model: list) -> str:
166
  consensus = sum(1 for _, p in per_model if p > 0.5)
167
  per_model_str = "\n".join(
168
  f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
@@ -173,13 +178,29 @@ def _llm_report(p_mean: float, per_model: list) -> str:
173
  "LOW / UNCERTAIN" if p_mean >= 0.4 else
174
  "MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)"
175
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  user_msg = (
177
  f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
178
- f"p(ASD) : {p_mean:.3f}\n"
179
- f"Confidence Level : {conf_label}\n"
180
- f"Model Consensus : {consensus}/{len(per_model)} site-blind models predict ASD\n\n"
 
 
181
  f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
182
- f"Please provide a structured clinical interpretation of these findings."
 
183
  )
184
  try:
185
  mdl, tok = get_llm()
@@ -191,7 +212,7 @@ def _llm_report(p_mean: float, per_model: list) -> str:
191
  inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
192
  with torch.no_grad():
193
  out = mdl.generate(
194
- **inputs, max_new_tokens=512, temperature=0.3,
195
  do_sample=True, pad_token_id=tok.eos_token_id,
196
  )
197
  generated = out[0][inputs["input_ids"].shape[1]:]
@@ -473,11 +494,21 @@ def run_gcn(file_path):
473
  consensus = sum(1 for _, p in per_model if p > 0.5)
474
  conf = max(p_mean, 1 - p_mean) * 100
475
 
 
476
  try:
 
 
 
 
 
 
 
 
 
477
  sal_img = _saliency_figure(
478
- _compute_saliency(bw_t, adj_t, models), p_mean,
479
- net_names=atlas_cfg["net_names"],
480
- net_bounds=atlas_cfg["net_bounds"],
481
  net_colors=atlas_cfg["net_colors"],
482
  )
483
  except Exception:
@@ -596,7 +627,7 @@ LOSO AUC = 0.7260 · 1,102 held-out subjects · 20 acquisition sites
596
  AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
597
 
598
  # LLM clinical interpretation
599
- llm_text = _llm_report(p_mean, per_model)
600
  report += f"""
601
  <div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
602
  <div style="color:#2dc653;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:10px;font-weight:600">
 
141
  "You are a clinical AI assistant specializing in functional MRI brain "
142
  "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
143
  "You interpret outputs from a validated graph neural network (GCN) trained on "
144
+ "the ABIDE I dataset (1,102 subjects, 20 acquisition sites) and provide structured "
145
+ "clinical summaries for neurologists and psychiatrists. "
146
+ "CRITICAL RULES: (1) Only reference brain networks, connectivity patterns, and "
147
+ "statistics that are explicitly provided in the input report — do NOT invent or "
148
+ "hallucinate network names, connectivity findings, or numeric values. "
149
+ "(2) Base every observation directly on the per-network saliency scores and "
150
+ "ensemble probabilities given in the input. (3) If a network is not listed in the "
151
+ "input, do not mention it. (4) Always clarify findings are AI-assisted and require "
152
+ "full clinical assessment. You do not make a diagnosis."
153
  )
154
  _llm_cache = None
155
 
 
167
  _llm_cache = (mdl, tok)
168
  return _llm_cache
169
 
170
+ def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None) -> str:
171
  consensus = sum(1 for _, p in per_model if p > 0.5)
172
  per_model_str = "\n".join(
173
  f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
 
178
  "LOW / UNCERTAIN" if p_mean >= 0.4 else
179
  "MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)"
180
  )
181
+
182
+ sal_section = ""
183
+ if net_saliency:
184
+ sorted_nets = sorted(net_saliency.items(), key=lambda x: x[1], reverse=True)
185
+ sal_lines = "\n".join(
186
+ f" {name}: {score:.5f}" for name, score in sorted_nets
187
+ )
188
+ sal_section = (
189
+ f"\nPer-Network Gradient Saliency (ranked high→low, actual GCN values):\n"
190
+ f"{sal_lines}\n"
191
+ f"[ONLY reference these networks with these exact values — no others.]\n"
192
+ )
193
+
194
  user_msg = (
195
  f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
196
+ f"Dataset : ABIDE I · 1,102 subjects · 20 acquisition sites\n"
197
+ f"p(ASD) : {p_mean:.3f}\n"
198
+ f"Confidence Level : {conf_label}\n"
199
+ f"Model Consensus : {consensus}/{len(per_model)} site-blind models predict ASD\n"
200
+ f"{sal_section}\n"
201
  f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
202
+ f"Provide a structured clinical interpretation referencing ONLY the networks "
203
+ f"and values listed above. Do not mention any network not in this report."
204
  )
205
  try:
206
  mdl, tok = get_llm()
 
212
  inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
213
  with torch.no_grad():
214
  out = mdl.generate(
215
+ **inputs, max_new_tokens=512, temperature=0.1,
216
  do_sample=True, pad_token_id=tok.eos_token_id,
217
  )
218
  generated = out[0][inputs["input_ids"].shape[1]:]
 
494
  consensus = sum(1 for _, p in per_model if p > 0.5)
495
  conf = max(p_mean, 1 - p_mean) * 100
496
 
497
+ net_saliency = None
498
  try:
499
+ sal = _compute_saliency(bw_t, adj_t, models)
500
+ net_names = atlas_cfg["net_names"]
501
+ net_bounds = atlas_cfg["net_bounds"]
502
+ # aggregate ROI-level saliency to network-level importance scores
503
+ net_imp = np.array([
504
+ sal[s:e, :].mean() + sal[:, s:e].mean()
505
+ for s, e in zip(net_bounds[:-1], net_bounds[1:])
506
+ ])
507
+ net_saliency = dict(zip(net_names, net_imp.tolist()))
508
  sal_img = _saliency_figure(
509
+ sal, p_mean,
510
+ net_names=net_names,
511
+ net_bounds=net_bounds,
512
  net_colors=atlas_cfg["net_colors"],
513
  )
514
  except Exception:
 
627
  AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
628
 
629
  # LLM clinical interpretation
630
+ llm_text = _llm_report(p_mean, per_model, net_saliency=net_saliency)
631
  report += f"""
632
  <div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
633
  <div style="color:#2dc653;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:10px;font-weight:600">