5x LLM improvements: saliency grounding, anti-hallucination system prompt, temp 0.1, n_subjects, per-network scores in prompt
Browse files
app.py
CHANGED
|
@@ -141,10 +141,15 @@ _SYSTEM_PROMPT = (
|
|
| 141 |
"You are a clinical AI assistant specializing in functional MRI brain "
|
| 142 |
"connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
|
| 143 |
"You interpret outputs from a validated graph neural network (GCN) trained on "
|
| 144 |
-
"the ABIDE I dataset
|
| 145 |
-
"
|
| 146 |
-
"
|
| 147 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
)
|
| 149 |
_llm_cache = None
|
| 150 |
|
|
@@ -162,7 +167,7 @@ def get_llm():
|
|
| 162 |
_llm_cache = (mdl, tok)
|
| 163 |
return _llm_cache
|
| 164 |
|
| 165 |
-
def _llm_report(p_mean: float, per_model: list) -> str:
|
| 166 |
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 167 |
per_model_str = "\n".join(
|
| 168 |
f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
|
|
@@ -173,13 +178,29 @@ def _llm_report(p_mean: float, per_model: list) -> str:
|
|
| 173 |
"LOW / UNCERTAIN" if p_mean >= 0.4 else
|
| 174 |
"MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)"
|
| 175 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
user_msg = (
|
| 177 |
f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
|
| 178 |
-
f"
|
| 179 |
-
f"
|
| 180 |
-
f"
|
|
|
|
|
|
|
| 181 |
f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
|
| 182 |
-
f"
|
|
|
|
| 183 |
)
|
| 184 |
try:
|
| 185 |
mdl, tok = get_llm()
|
|
@@ -191,7 +212,7 @@ def _llm_report(p_mean: float, per_model: list) -> str:
|
|
| 191 |
inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
|
| 192 |
with torch.no_grad():
|
| 193 |
out = mdl.generate(
|
| 194 |
-
**inputs, max_new_tokens=512, temperature=0.
|
| 195 |
do_sample=True, pad_token_id=tok.eos_token_id,
|
| 196 |
)
|
| 197 |
generated = out[0][inputs["input_ids"].shape[1]:]
|
|
@@ -473,11 +494,21 @@ def run_gcn(file_path):
|
|
| 473 |
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 474 |
conf = max(p_mean, 1 - p_mean) * 100
|
| 475 |
|
|
|
|
| 476 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
sal_img = _saliency_figure(
|
| 478 |
-
|
| 479 |
-
net_names=
|
| 480 |
-
net_bounds=
|
| 481 |
net_colors=atlas_cfg["net_colors"],
|
| 482 |
)
|
| 483 |
except Exception:
|
|
@@ -596,7 +627,7 @@ LOSO AUC = 0.7260 · 1,102 held-out subjects · 20 acquisition sites
|
|
| 596 |
AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
|
| 597 |
|
| 598 |
# LLM clinical interpretation
|
| 599 |
-
llm_text = _llm_report(p_mean, per_model)
|
| 600 |
report += f"""
|
| 601 |
<div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
|
| 602 |
<div style="color:#2dc653;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:10px;font-weight:600">
|
|
|
|
| 141 |
"You are a clinical AI assistant specializing in functional MRI brain "
|
| 142 |
"connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
|
| 143 |
"You interpret outputs from a validated graph neural network (GCN) trained on "
|
| 144 |
+
"the ABIDE I dataset (1,102 subjects, 20 acquisition sites) and provide structured "
|
| 145 |
+
"clinical summaries for neurologists and psychiatrists. "
|
| 146 |
+
"CRITICAL RULES: (1) Only reference brain networks, connectivity patterns, and "
|
| 147 |
+
"statistics that are explicitly provided in the input report — do NOT invent or "
|
| 148 |
+
"hallucinate network names, connectivity findings, or numeric values. "
|
| 149 |
+
"(2) Base every observation directly on the per-network saliency scores and "
|
| 150 |
+
"ensemble probabilities given in the input. (3) If a network is not listed in the "
|
| 151 |
+
"input, do not mention it. (4) Always clarify findings are AI-assisted and require "
|
| 152 |
+
"full clinical assessment. You do not make a diagnosis."
|
| 153 |
)
|
| 154 |
_llm_cache = None
|
| 155 |
|
|
|
|
| 167 |
_llm_cache = (mdl, tok)
|
| 168 |
return _llm_cache
|
| 169 |
|
| 170 |
+
def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None) -> str:
|
| 171 |
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 172 |
per_model_str = "\n".join(
|
| 173 |
f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
|
|
|
|
| 178 |
"LOW / UNCERTAIN" if p_mean >= 0.4 else
|
| 179 |
"MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)"
|
| 180 |
)
|
| 181 |
+
|
| 182 |
+
sal_section = ""
|
| 183 |
+
if net_saliency:
|
| 184 |
+
sorted_nets = sorted(net_saliency.items(), key=lambda x: x[1], reverse=True)
|
| 185 |
+
sal_lines = "\n".join(
|
| 186 |
+
f" {name}: {score:.5f}" for name, score in sorted_nets
|
| 187 |
+
)
|
| 188 |
+
sal_section = (
|
| 189 |
+
f"\nPer-Network Gradient Saliency (ranked high→low, actual GCN values):\n"
|
| 190 |
+
f"{sal_lines}\n"
|
| 191 |
+
f"[ONLY reference these networks with these exact values — no others.]\n"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
user_msg = (
|
| 195 |
f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
|
| 196 |
+
f"Dataset : ABIDE I · 1,102 subjects · 20 acquisition sites\n"
|
| 197 |
+
f"p(ASD) : {p_mean:.3f}\n"
|
| 198 |
+
f"Confidence Level : {conf_label}\n"
|
| 199 |
+
f"Model Consensus : {consensus}/{len(per_model)} site-blind models predict ASD\n"
|
| 200 |
+
f"{sal_section}\n"
|
| 201 |
f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
|
| 202 |
+
f"Provide a structured clinical interpretation referencing ONLY the networks "
|
| 203 |
+
f"and values listed above. Do not mention any network not in this report."
|
| 204 |
)
|
| 205 |
try:
|
| 206 |
mdl, tok = get_llm()
|
|
|
|
| 212 |
inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
|
| 213 |
with torch.no_grad():
|
| 214 |
out = mdl.generate(
|
| 215 |
+
**inputs, max_new_tokens=512, temperature=0.1,
|
| 216 |
do_sample=True, pad_token_id=tok.eos_token_id,
|
| 217 |
)
|
| 218 |
generated = out[0][inputs["input_ids"].shape[1]:]
|
|
|
|
| 494 |
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 495 |
conf = max(p_mean, 1 - p_mean) * 100
|
| 496 |
|
| 497 |
+
net_saliency = None
|
| 498 |
try:
|
| 499 |
+
sal = _compute_saliency(bw_t, adj_t, models)
|
| 500 |
+
net_names = atlas_cfg["net_names"]
|
| 501 |
+
net_bounds = atlas_cfg["net_bounds"]
|
| 502 |
+
# aggregate ROI-level saliency to network-level importance scores
|
| 503 |
+
net_imp = np.array([
|
| 504 |
+
sal[s:e, :].mean() + sal[:, s:e].mean()
|
| 505 |
+
for s, e in zip(net_bounds[:-1], net_bounds[1:])
|
| 506 |
+
])
|
| 507 |
+
net_saliency = dict(zip(net_names, net_imp.tolist()))
|
| 508 |
sal_img = _saliency_figure(
|
| 509 |
+
sal, p_mean,
|
| 510 |
+
net_names=net_names,
|
| 511 |
+
net_bounds=net_bounds,
|
| 512 |
net_colors=atlas_cfg["net_colors"],
|
| 513 |
)
|
| 514 |
except Exception:
|
|
|
|
| 627 |
AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
|
| 628 |
|
| 629 |
# LLM clinical interpretation
|
| 630 |
+
llm_text = _llm_report(p_mean, per_model, net_saliency=net_saliency)
|
| 631 |
report += f"""
|
| 632 |
<div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
|
| 633 |
<div style="color:#2dc653;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:10px;font-weight:600">
|