LLM: switch to HF InferenceClient (merged model, always-on)
app.py
CHANGED
@@ -4,6 +4,7 @@ BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
 from __future__ import annotations
 
 import io
+import os
 from pathlib import Path
 
 import numpy as np
@@ -134,9 +135,11 @@ def preprocess(bold):
     bw = _windows(bold)
     return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
 
-# ── LLM (Qwen2.5-7B
+# ── LLM (Qwen2.5-7B fine-tuned on AMD MI300X, served via HF Inference API) ─
+
+_LLM_MODEL = "Yatsuiii/asd-interpreter-merged"
+_HF_TOKEN = os.environ.get("HF_TOKEN", "")
 
-_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
 _SYSTEM_PROMPT = (
     "You are a clinical AI assistant specializing in functional MRI brain "
     "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
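Note on the model swap above: the serverless HF Inference API serves full model repositories rather than bare LoRA adapter repos, which is presumably why Yatsuiii/asd-interpreter-lora was replaced by the merged repo Yatsuiii/asd-interpreter-merged. A minimal sketch of that offline merge step with peft, under the assumption that the base model is Qwen/Qwen2.5-7B-Instruct (the commit does not name the base):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    # Hypothetical offline merge step; the base model name is an assumption.
    base = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct", torch_dtype=torch.bfloat16
    )
    merged = PeftModel.from_pretrained(base, "Yatsuiii/asd-interpreter-lora").merge_and_unload()

    # Save standalone weights plus tokenizer, ready to push as the merged repo.
    merged.save_pretrained("asd-interpreter-merged")
    AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct").save_pretrained("asd-interpreter-merged")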
@@ -151,21 +154,6 @@ _SYSTEM_PROMPT = (
     "input, do not mention it. (4) Always clarify findings are AI-assisted and require "
     "full clinical assessment. You do not make a diagnosis."
 )
-_llm_cache = None
-
-def get_llm():
-    global _llm_cache
-    if _llm_cache is not None:
-        return _llm_cache
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
-    tok.pad_token = tok.eos_token
-    mdl = AutoModelForCausalLM.from_pretrained(
-        _LLM_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
-    )
-    mdl.eval()
-    _llm_cache = (mdl, tok)
-    return _llm_cache
 
 def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None) -> str:
     consensus = sum(1 for _, p in per_model if p > 0.5)
@@ -203,20 +191,14 @@ def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None
         f"and values listed above. Do not mention any network not in this report."
     )
     try:
-        mdl, tok = get_llm()
+        from huggingface_hub import InferenceClient
+        client = InferenceClient(model=_LLM_MODEL, token=_HF_TOKEN or None)
         messages = [
             {"role": "system", "content": _SYSTEM_PROMPT},
             {"role": "user", "content": user_msg},
         ]
-        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = tok(prompt, return_tensors="pt").to(mdl.device)
-        with torch.no_grad():
-            out = mdl.generate(
-                **inputs, max_new_tokens=512, temperature=0.1,
-                do_sample=True, pad_token_id=tok.eos_token_id,
-            )
-        generated = out[0][inputs["input_ids"].shape[1]:]
-        return tok.decode(generated, skip_special_tokens=True).strip()
+        response = client.chat_completion(messages=messages, max_tokens=512, temperature=0.1)
+        return response.choices[0].message.content.strip()
     except Exception as e:
         return f"[LLM unavailable: {e}]"
 
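The hunk above is the heart of the change: local transformers generation (get_llm plus mdl.generate) becomes a single remote chat_completion call. Here is the same call path in isolation, as a minimal sketch assuming huggingface_hub is installed, HF_TOKEN is set, and the merged model is reachable through the serverless Inference API (the prompt text and saliency values below are illustrative only):

    import os
    from huggingface_hub import InferenceClient

    client = InferenceClient(
        model="Yatsuiii/asd-interpreter-merged",
        token=os.environ.get("HF_TOKEN") or None,  # "" falls back to None: anonymous, rate-limited
    )
    response = client.chat_completion(
        messages=[
            {"role": "system", "content": "You are a clinical AI assistant ..."},
            {"role": "user", "content": "Summarize: mean ASD probability 0.62; DMN saliency 0.41."},
        ],
        max_tokens=512,
        temperature=0.1,
    )
    print(response.choices[0].message.content.strip())

Keeping the huggingface_hub import inside the try: block means a missing dependency degrades to the same "[LLM unavailable: ...]" fallback as a network or quota error.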
@@ -699,20 +681,8 @@ AI-assisted screening only · Not a clinical diagnosis · Findings must be integ
 
     # LLM clinical interpretation (only attempt if GPU is available)
     import os
-
-    if torch.cuda.is_available():
-        llm_text = _llm_report(p_mean, per_model, net_saliency=net_saliency)
-        llm_block = f'<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>'
-    else:
-        llm_block = """
-        <div style="color:#8b95a7;font-size:0.84rem;line-height:1.6">
-        Qwen2.5-7B LoRA interpreter is active — fine-tuned on AMD Instinct MI300X (192 GB HBM3, ROCm 7.0, bf16).
-        GPU inference is required to run it in real-time. The full model is available at
-        <span style="color:#fb923c">Yatsuiii/asd-interpreter-lora</span> on Hugging Face.
-        <br><br>
-        <span style="color:#5e6675">Clinical interpretation pipeline: GCN ensemble → per-network saliency extraction →
-        Qwen2.5-7B generates grounded clinical summary referencing only the actual saliency values.</span>
-        </div>"""
+    llm_text = _llm_report(p_mean, per_model, net_saliency=net_saliency)
+    llm_block = f'<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>'
     report += f"""
     <div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
     <div style="display:flex;align-items:center;gap:10px;margin-bottom:10px">
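With the GPU-gated else branch removed, the interpreter is now attempted on every report, and any Inference API failure surfaces through the existing "[LLM unavailable: ...]" string. If a fail-fast check were preferred, a startup probe could reuse the module-level names from this diff; this is a sketch of an optional addition, not part of the commit:

    from huggingface_hub import InferenceClient

    def probe_llm() -> bool:
        """Return True if the merged model answers via the Inference API."""
        try:
            InferenceClient(model=_LLM_MODEL, token=_HF_TOKEN or None).chat_completion(
                messages=[{"role": "user", "content": "ping"}], max_tokens=4
            )
            return True
        except Exception:
            return False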