Yatsuiii committed on
Commit b050a20 · verified · 1 Parent(s): c15db7e

LLM: switch to HF InferenceClient (merged model, always-on)

Files changed (1)
  1. app.py +11 -41
app.py CHANGED
@@ -4,6 +4,7 @@ BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
 from __future__ import annotations
 
 import io
+import os
 from pathlib import Path
 
 import numpy as np
@@ -134,9 +135,11 @@ def preprocess(bold):
     bw = _windows(bold)
     return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
 
-# ── LLM (Qwen2.5-7B LoRA fine-tuned on AMD MI300X) ────────────────────────
+# ── LLM (Qwen2.5-7B fine-tuned on AMD MI300X, served via HF Inference API) ─
+
+_LLM_MODEL = "Yatsuiii/asd-interpreter-merged"
+_HF_TOKEN = os.environ.get("HF_TOKEN", "")
 
-_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
 _SYSTEM_PROMPT = (
     "You are a clinical AI assistant specializing in functional MRI brain "
     "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
@@ -151,21 +154,6 @@ _SYSTEM_PROMPT = (
     "input, do not mention it. (4) Always clarify findings are AI-assisted and require "
     "full clinical assessment. You do not make a diagnosis."
 )
-_llm_cache = None
-
-def get_llm():
-    global _llm_cache
-    if _llm_cache is not None:
-        return _llm_cache
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
-    tok.pad_token = tok.eos_token
-    mdl = AutoModelForCausalLM.from_pretrained(
-        _LLM_MODEL, torch_dtype=torch.bfloat16, device_map="auto"
-    )
-    mdl.eval()
-    _llm_cache = (mdl, tok)
-    return _llm_cache
 
 def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None) -> str:
     consensus = sum(1 for _, p in per_model if p > 0.5)
@@ -203,20 +191,14 @@ def _llm_report(p_mean: float, per_model: list, net_saliency: dict | None = None
         f"and values listed above. Do not mention any network not in this report."
     )
     try:
-        mdl, tok = get_llm()
+        from huggingface_hub import InferenceClient
+        client = InferenceClient(model=_LLM_MODEL, token=_HF_TOKEN or None)
         messages = [
             {"role": "system", "content": _SYSTEM_PROMPT},
             {"role": "user", "content": user_msg},
         ]
-        text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
-        with torch.no_grad():
-            out = mdl.generate(
-                **inputs, max_new_tokens=512, temperature=0.1,
-                do_sample=True, pad_token_id=tok.eos_token_id,
-            )
-        generated = out[0][inputs["input_ids"].shape[1]:]
-        return tok.decode(generated, skip_special_tokens=True).strip()
+        response = client.chat_completion(messages=messages, max_tokens=512, temperature=0.1)
+        return response.choices[0].message.content.strip()
     except Exception as e:
         return f"[LLM unavailable: {e}]"
 
@@ -699,20 +681,8 @@ AI-assisted screening only · Not a clinical diagnosis · Findings must be integ
 
     # LLM clinical interpretation (only attempt if GPU is available)
    import os
-    _has_gpu = torch.cuda.is_available() or (hasattr(torch, "hip") and torch.hip.is_available() if hasattr(torch, "hip") else False)
-    if _has_gpu:
-        llm_text = _llm_report(p_mean, per_model, net_saliency=net_saliency)
-        llm_block = f'<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>'
-    else:
-        llm_block = """
-        <div style="color:#8b95a7;font-size:0.84rem;line-height:1.6">
-        Qwen2.5-7B LoRA interpreter is active — fine-tuned on AMD Instinct MI300X (192 GB HBM3, ROCm 7.0, bf16).
-        GPU inference is required to run it in real-time. The full model is available at
-        <span style="color:#fb923c">Yatsuiii/asd-interpreter-lora</span> on Hugging Face.
-        <br><br>
-        <span style="color:#5e6675">Clinical interpretation pipeline: GCN ensemble → per-network saliency extraction →
-        Qwen2.5-7B generates grounded clinical summary referencing only the actual saliency values.</span>
-        </div>"""
+    llm_text = _llm_report(p_mean, per_model, net_saliency=net_saliency)
+    llm_block = f'<div style="color:#cbd5e1;font-size:0.85rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>'
     report += f"""
 <div style="background:#0f1a1a;border:1px solid #1a3a3a;border-radius:8px;padding:18px 24px;margin-top:12px">
 <div style="display:flex;align-items:center;gap:10px;margin-bottom:10px">
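
For reference, the new code path reduces to a single huggingface_hub call. The sketch below gathers the pieces from the hunks above into one standalone snippet; it assumes a huggingface_hub version with chat-completion support is installed, that HF_TOKEN is available in the environment (e.g. as a Space secret), and that the merged model is reachable through the HF Inference API. The helper name generate_report and the placeholder prompt values are illustrative only and are not part of app.py; the constants mirror _LLM_MODEL and _HF_TOKEN from the diff.

import os

from huggingface_hub import InferenceClient

# Values taken from the diff above.
LLM_MODEL = "Yatsuiii/asd-interpreter-merged"
HF_TOKEN = os.environ.get("HF_TOKEN", "")


def generate_report(system_prompt: str, user_msg: str) -> str:
    # Mirrors the try/except in _llm_report: any API failure degrades to a placeholder
    # string instead of raising, so the surrounding report still renders.
    try:
        client = InferenceClient(model=LLM_MODEL, token=HF_TOKEN or None)
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=512,
            temperature=0.1,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        return f"[LLM unavailable: {e}]"


if __name__ == "__main__":
    # Placeholder inputs for illustration; app.py builds these from the GCN ensemble
    # probabilities and per-network saliency values.
    print(generate_report(
        "You are a clinical AI assistant. Findings are AI-assisted and not a diagnosis.",
        "Ensemble mean ASD probability: 0.62. Summarize the finding cautiously.",
    ))

With the GPU gate removed, _llm_report now runs on every request; a missing token or an unreachable Inference API endpoint falls through the same try/except and shows the "[LLM unavailable: ...]" placeholder rather than failing the report.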