Wire Qwen2.5-7B LoRA (AMD MI300X) into analysis report
Browse files
app.py
CHANGED
|
@@ -118,6 +118,71 @@ def preprocess(bold):
|
|
| 118 |
bw = _windows(bold)
|
| 119 |
return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
# ── model loading ──────────────────────────────────────────────────────────
|
| 122 |
|
| 123 |
_model_cache: dict[str, list] = {}
|
|
@@ -514,6 +579,16 @@ LOSO AUC = 0.7260 · 1,102 held-out subjects · 20 acquisition sites
|
|
| 514 |
<div style="border-top:1px solid #252a35;padding-top:10px;color:#5e6675;font-size:0.74rem;line-height:1.5">
|
| 515 |
AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
|
| 516 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
return verdict, ensemble, report, sal_img
|
| 518 |
|
| 519 |
|
|
|
|
| 118 |
bw = _windows(bold)
|
| 119 |
return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
|
| 120 |
|
| 121 |
+
# ── LLM (Qwen2.5-7B LoRA fine-tuned on AMD MI300X) ────────────────────────
|
| 122 |
+
|
| 123 |
+
_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
|
| 124 |
+
_SYSTEM_PROMPT = (
|
| 125 |
+
"You are a clinical AI assistant specializing in functional MRI brain "
|
| 126 |
+
"connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
|
| 127 |
+
"You interpret outputs from a validated graph neural network (GCN) trained on "
|
| 128 |
+
"the ABIDE I dataset and provide structured clinical summaries for neurologists "
|
| 129 |
+
"and psychiatrists. Your reports are informative and evidence-based but always "
|
| 130 |
+
"clarify that findings are AI-assisted and should be integrated with full "
|
| 131 |
+
"clinical assessment. You do not make a diagnosis."
|
| 132 |
+
)
|
| 133 |
+
_llm_cache = None
|
| 134 |
+
|
| 135 |
+
def get_llm():
|
| 136 |
+
global _llm_cache
|
| 137 |
+
if _llm_cache is not None:
|
| 138 |
+
return _llm_cache
|
| 139 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 140 |
+
tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
|
| 141 |
+
tok.pad_token = tok.eos_token
|
| 142 |
+
mdl = AutoModelForCausalLM.from_pretrained(
|
| 143 |
+
_LLM_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
|
| 144 |
+
)
|
| 145 |
+
mdl.eval()
|
| 146 |
+
_llm_cache = (mdl, tok)
|
| 147 |
+
return _llm_cache
|
| 148 |
+
|
| 149 |
+
def _llm_report(p_mean: float, per_model: list) -> str:
|
| 150 |
+
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 151 |
+
per_model_str = "\n".join(
|
| 152 |
+
f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
|
| 153 |
+
)
|
| 154 |
+
conf_label = (
|
| 155 |
+
"HIGH" if p_mean >= 0.75 else
|
| 156 |
+
"MODERATE" if p_mean >= 0.6 else
|
| 157 |
+
"LOW / UNCERTAIN" if p_mean >= 0.4 else
|
| 158 |
+
"MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)"
|
| 159 |
+
)
|
| 160 |
+
user_msg = (
|
| 161 |
+
f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
|
| 162 |
+
f"p(ASD) : {p_mean:.3f}\n"
|
| 163 |
+
f"Confidence Level : {conf_label}\n"
|
| 164 |
+
f"Model Consensus : {consensus}/{len(per_model)} site-blind models predict ASD\n\n"
|
| 165 |
+
f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
|
| 166 |
+
f"Please provide a structured clinical interpretation of these findings."
|
| 167 |
+
)
|
| 168 |
+
try:
|
| 169 |
+
mdl, tok = get_llm()
|
| 170 |
+
messages = [
|
| 171 |
+
{"role": "system", "content": _SYSTEM_PROMPT},
|
| 172 |
+
{"role": "user", "content": user_msg},
|
| 173 |
+
]
|
| 174 |
+
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 175 |
+
inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
|
| 176 |
+
with torch.no_grad():
|
| 177 |
+
out = mdl.generate(
|
| 178 |
+
**inputs, max_new_tokens=512, temperature=0.3,
|
| 179 |
+
do_sample=True, pad_token_id=tok.eos_token_id,
|
| 180 |
+
)
|
| 181 |
+
generated = out[0][inputs["input_ids"].shape[1]:]
|
| 182 |
+
return tok.decode(generated, skip_special_tokens=True).strip()
|
| 183 |
+
except Exception as e:
|
| 184 |
+
return f"[LLM unavailable: {e}]"
|
| 185 |
+
|
| 186 |
# ── model loading ──────────────────────────────────────────────────────────
|
| 187 |
|
| 188 |
_model_cache: dict[str, list] = {}
|
|
|
|
| 579 |
<div style="border-top:1px solid #252a35;padding-top:10px;color:#5e6675;font-size:0.74rem;line-height:1.5">
|
| 580 |
AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
|
| 581 |
|
| 582 |
+
# LLM clinical interpretation
|
| 583 |
+
llm_text = _llm_report(p_mean, per_model)
|
| 584 |
+
report += f"""
|
| 585 |
+
<div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
|
| 586 |
+
<div style="color:#2dc653;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:10px;font-weight:600">
|
| 587 |
+
Qwen2.5-7B Clinical Interpretation · Fine-tuned on AMD MI300X
|
| 588 |
+
</div>
|
| 589 |
+
<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>
|
| 590 |
+
</div>"""
|
| 591 |
+
|
| 592 |
return verdict, ensemble, report, sal_img
|
| 593 |
|
| 594 |
|