Yatsuiii committed on
Commit
9466bf3
·
verified ·
1 Parent(s): a3ca41c

Wire Qwen2.5-7B LoRA (AMD MI300X) into analysis report

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py CHANGED
@@ -118,6 +118,71 @@ def preprocess(bold):
118
  bw = _windows(bold)
119
  return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # ── model loading ──────────────────────────────────────────────────────────
122
 
123
  _model_cache: dict[str, list] = {}
@@ -514,6 +579,16 @@ LOSO AUC = 0.7260 · 1,102 held-out subjects · 20 acquisition sites
514
  <div style="border-top:1px solid #252a35;padding-top:10px;color:#5e6675;font-size:0.74rem;line-height:1.5">
515
  AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
516
 
 
 
 
 
 
 
 
 
 
 
517
  return verdict, ensemble, report, sal_img
518
 
519
 
 
118
  bw = _windows(bold)
119
  return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
120
 
121
# ── LLM (Qwen2.5-7B LoRA fine-tuned on AMD MI300X) ────────────────────────

# Hugging Face Hub id of the LoRA fine-tuned interpreter checkpoint.
_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"

# System prompt framing the LLM as a non-diagnosing clinical assistant that
# interprets the GCN ensemble output for clinicians.
_SYSTEM_PROMPT = (
    "You are a clinical AI assistant specializing in functional MRI brain "
    "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
    "You interpret outputs from a validated graph neural network (GCN) trained on "
    "the ABIDE I dataset and provide structured clinical summaries for neurologists "
    "and psychiatrists. Your reports are informative and evidence-based but always "
    "clarify that findings are AI-assisted and should be integrated with full "
    "clinical assessment. You do not make a diagnosis."
)

# Populated lazily by get_llm() with a (model, tokenizer) tuple; None until
# the first request so app start-up does not pay the model-load cost.
_llm_cache = None
134
+
135
def get_llm():
    """Load and memoise the LoRA fine-tuned Qwen2.5-7B interpreter model.

    Returns:
        tuple: ``(model, tokenizer)``. Loaded from ``_LLM_MODEL`` on the first
        call and cached in the module-level ``_llm_cache`` so subsequent calls
        are free.
    """
    global _llm_cache
    if _llm_cache is not None:
        return _llm_cache
    # Lazy import: the app can start (and _llm_report can degrade gracefully)
    # even when transformers is not installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
    # Fix: only fall back to EOS when the checkpoint ships no pad token —
    # the previous code unconditionally overwrote any existing pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(
        _LLM_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
    )
    mdl.eval()  # inference only — disable dropout / training-mode layers
    _llm_cache = (mdl, tok)
    return _llm_cache
148
+
149
+ def _llm_report(p_mean: float, per_model: list) -> str:
150
+ consensus = sum(1 for _, p in per_model if p > 0.5)
151
+ per_model_str = "\n".join(
152
+ f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
153
+ )
154
+ conf_label = (
155
+ "HIGH" if p_mean >= 0.75 else
156
+ "MODERATE" if p_mean >= 0.6 else
157
+ "LOW / UNCERTAIN" if p_mean >= 0.4 else
158
+ "MODERATE (TC)" if p_mean >= 0.25 else "HIGH (TC)"
159
+ )
160
+ user_msg = (
161
+ f"Brain Connectivity GCN Analysis Report\n{'='*40}\n"
162
+ f"p(ASD) : {p_mean:.3f}\n"
163
+ f"Confidence Level : {conf_label}\n"
164
+ f"Model Consensus : {consensus}/{len(per_model)} site-blind models predict ASD\n\n"
165
+ f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
166
+ f"Please provide a structured clinical interpretation of these findings."
167
+ )
168
+ try:
169
+ mdl, tok = get_llm()
170
+ messages = [
171
+ {"role": "system", "content": _SYSTEM_PROMPT},
172
+ {"role": "user", "content": user_msg},
173
+ ]
174
+ text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
175
+ inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
176
+ with torch.no_grad():
177
+ out = mdl.generate(
178
+ **inputs, max_new_tokens=512, temperature=0.3,
179
+ do_sample=True, pad_token_id=tok.eos_token_id,
180
+ )
181
+ generated = out[0][inputs["input_ids"].shape[1]:]
182
+ return tok.decode(generated, skip_special_tokens=True).strip()
183
+ except Exception as e:
184
+ return f"[LLM unavailable: {e}]"
185
+
186
  # ── model loading ──────────────────────────────────────────────────────────
187
 
188
  _model_cache: dict[str, list] = {}
 
579
  <div style="border-top:1px solid #252a35;padding-top:10px;color:#5e6675;font-size:0.74rem;line-height:1.5">
580
  AI-assisted screening only · Not a clinical diagnosis · Findings must be integrated with ADOS-2, ADI-R, and full developmental history · Refer to licensed neuropsychologist for formal evaluation.</div></div>"""
581
 
582
+ # LLM clinical interpretation
583
+ llm_text = _llm_report(p_mean, per_model)
584
+ report += f"""
585
+ <div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
586
+ <div style="color:#2dc653;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:10px;font-weight:600">
587
+ Qwen2.5-7B Clinical Interpretation · Fine-tuned on AMD MI300X
588
+ </div>
589
+ <div style="color:#cbd5e1;font-size:0.85rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>
590
+ </div>"""
591
+
592
  return verdict, ensemble, report, sal_img
593
 
594