Delete app_with_llm.py with huggingface_hub
Browse files- app_with_llm.py +0 -363
app_with_llm.py
DELETED
|
@@ -1,363 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
|
| 3 |
-
Full pipeline: Adversarial GCN + Qwen2.5-7B fine-tuned on AMD MI300X.
|
| 4 |
-
"""
|
| 5 |
-
from __future__ import annotations
|
| 6 |
-
|
| 7 |
-
import io
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
import numpy as np
|
| 11 |
-
import torch
|
| 12 |
-
import gradio as gr
|
| 13 |
-
|
| 14 |
-
# Sliding-window / functional-connectivity preprocessing hyper-parameters.
_WINDOW_LEN = 50     # timepoints per BOLD window
_STEP = 3            # stride (in timepoints) between window starts
_MAX_WINDOWS = 30    # fixed number of windows per subject (truncate or pad)
_FC_THRESHOLD = 0.2  # |Fisher-z correlation| below this is zeroed in the adjacency

# One GCN checkpoint per leave-one-site-out (LOSO) fold.
# get_models() silently skips entries whose file is missing.
_CKPTS = {
    "NYU": Path("checkpoints/nyu.ckpt"),
    "USM": Path("checkpoints/usm.ckpt"),
    "UCLA": Path("checkpoints/ucla.ckpt"),
    "UM": Path("checkpoints/um.ckpt"),
}

# Hugging Face model id of the fine-tuned report-writing LLM.
_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"

# System prompt sent with every LLM request; frames the model as a
# decision-support assistant that must not issue a diagnosis.
SYSTEM_PROMPT = (
    "You are a clinical AI assistant specializing in functional MRI brain "
    "connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
    "You interpret outputs from a validated graph neural network (GCN) trained on "
    "the ABIDE I dataset and provide structured clinical summaries for neurologists "
    "and psychiatrists. Your reports are informative and evidence-based but always "
    "clarify that findings are AI-assisted and should be integrated with full "
    "clinical assessment. You do not make a diagnosis."
)
|
| 37 |
-
|
| 38 |
-
# ── preprocessing ──────────────────────────────────────────────────────────
|
| 39 |
-
|
| 40 |
-
def _zscore(bold):
|
| 41 |
-
mean = bold.mean(0, keepdims=True)
|
| 42 |
-
std = bold.std(0, keepdims=True)
|
| 43 |
-
std[std < 1e-8] = 1.0
|
| 44 |
-
return ((bold - mean) / std).astype(np.float32)
|
| 45 |
-
|
| 46 |
-
def _fc(bold):
|
| 47 |
-
fc = np.corrcoef(bold.T).astype(np.float32)
|
| 48 |
-
np.nan_to_num(fc, copy=False)
|
| 49 |
-
return fc
|
| 50 |
-
|
| 51 |
-
def _windows(bold):
    """Per-window ROI std-dev features, shaped exactly (_MAX_WINDOWS, N).

    Slides a length-``_WINDOW_LEN`` window with stride ``_STEP`` over the
    (T, N) series and takes each ROI's temporal standard deviation.  The
    result is truncated, or right-padded by repeating the last window, to
    ``_MAX_WINDOWS`` rows.

    Fix: a series shorter than ``_WINDOW_LEN`` previously yielded zero
    window starts and crashed ``np.stack`` on an empty list; it now falls
    back to a single window over the whole (short) series.
    """
    T, _ = bold.shape
    starts = list(range(0, T - _WINDOW_LEN + 1, _STEP))
    if not starts:  # T < _WINDOW_LEN: use the full series as one window
        starts = [0]
    w = np.stack([bold[s:s + _WINDOW_LEN].std(0) for s in starts]).astype(np.float32)
    if len(w) >= _MAX_WINDOWS:
        return w[:_MAX_WINDOWS]
    return np.concatenate([w, np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)])
|
| 58 |
-
|
| 59 |
-
def preprocess(bold):
    """Raw (T, N) BOLD → (window-feature tensor, thresholded adjacency).

    Both outputs carry a leading batch dimension of 1, ready for the GCN.
    """
    normed = _zscore(bold)
    # Fisher z-transform of the correlation matrix (clipped to keep arctanh finite).
    z_fc = np.arctanh(np.clip(_fc(normed), -0.9999, 0.9999))
    adj = np.where(np.abs(z_fc) >= _FC_THRESHOLD, z_fc, 0.0).astype(np.float32)
    feats = _windows(normed)
    return torch.FloatTensor(feats).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
|
| 66 |
-
|
| 67 |
-
# ── GCN model loading ──────────────────────────────────────────────────────

# Lazily-populated cache of (site_name, ClassificationTask) pairs; see get_models().
_models: list | None = None
| 70 |
-
|
| 71 |
-
def get_models():
    """Load every available LOSO checkpoint once; return cached (site, task) pairs.

    Checkpoints listed in ``_CKPTS`` that do not exist on disk are skipped,
    so the returned list may contain fewer than four models.
    """
    global _models
    if _models is None:
        from brain_gcn.tasks import ClassificationTask

        loaded = []
        for site, ckpt_path in _CKPTS.items():
            if not ckpt_path.exists():
                continue  # checkpoint not shipped with this deployment
            task = ClassificationTask.load_from_checkpoint(
                str(ckpt_path), map_location="cpu", strict=False
            )
            task.eval()
            loaded.append((site, task))
        _models = loaded
    return _models
|
| 84 |
-
|
| 85 |
-
# ── LLM loading ────────────────────────────────────────────────────────────

# Lazily-populated (model, tokenizer) pair; see get_llm().
_llm = None
| 88 |
-
|
| 89 |
-
def get_llm():
    """Load the report-writing causal LM + tokenizer once; return (model, tokenizer)."""
    global _llm
    if _llm is None:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        print(f"Loading LLM: {_LLM_MODEL}")
        tokenizer = AutoTokenizer.from_pretrained(_LLM_MODEL)
        tokenizer.pad_token = tokenizer.eos_token  # base model ships without a pad token
        model = AutoModelForCausalLM.from_pretrained(
            _LLM_MODEL,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model.eval()
        _llm = (model, tokenizer)
    return _llm
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def _llm_report(p_mean: float, per_model: list) -> str:
|
| 108 |
-
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 109 |
-
per_model_str = "\n".join(
|
| 110 |
-
f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
|
| 111 |
-
)
|
| 112 |
-
conf_label = (
|
| 113 |
-
"HIGH" if p_mean >= 0.75 else
|
| 114 |
-
"MODERATE" if p_mean >= 0.6 else
|
| 115 |
-
"LOW / UNCERTAIN" if p_mean >= 0.4 else
|
| 116 |
-
"MODERATE (TC)" if p_mean >= 0.25 else
|
| 117 |
-
"HIGH (TC)"
|
| 118 |
-
)
|
| 119 |
-
user_msg = (
|
| 120 |
-
f"Brain Connectivity GCN Analysis Report\n"
|
| 121 |
-
f"{'='*40}\n"
|
| 122 |
-
f"p(ASD) : {p_mean:.3f}\n"
|
| 123 |
-
f"Confidence Level : {conf_label}\n"
|
| 124 |
-
f"Model Consensus : {consensus}/4 site-blind models predict ASD\n\n"
|
| 125 |
-
f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
|
| 126 |
-
f"Please provide a structured clinical interpretation of these findings."
|
| 127 |
-
)
|
| 128 |
-
try:
|
| 129 |
-
mdl, tok = get_llm()
|
| 130 |
-
messages = [
|
| 131 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 132 |
-
{"role": "user", "content": user_msg},
|
| 133 |
-
]
|
| 134 |
-
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 135 |
-
inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
|
| 136 |
-
with torch.no_grad():
|
| 137 |
-
out = mdl.generate(
|
| 138 |
-
**inputs,
|
| 139 |
-
max_new_tokens=512,
|
| 140 |
-
temperature=0.3,
|
| 141 |
-
do_sample=True,
|
| 142 |
-
pad_token_id=tok.eos_token_id,
|
| 143 |
-
)
|
| 144 |
-
generated = out[0][inputs["input_ids"].shape[1]:]
|
| 145 |
-
return tok.decode(generated, skip_special_tokens=True).strip()
|
| 146 |
-
except Exception as e:
|
| 147 |
-
return f"LLM unavailable: {e}"
|
| 148 |
-
|
| 149 |
-
# ── gradient saliency ──────────────────────────────────────────────────────
|
| 150 |
-
|
| 151 |
-
def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
|
| 152 |
-
maps = []
|
| 153 |
-
for _, task in models:
|
| 154 |
-
adj = adj_t.clone().requires_grad_(True)
|
| 155 |
-
logits = task.model(bw_t, adj)
|
| 156 |
-
p = torch.softmax(logits, -1)[0, 1]
|
| 157 |
-
p.backward()
|
| 158 |
-
maps.append(adj.grad[0].abs().detach().numpy())
|
| 159 |
-
sal = np.mean(maps, axis=0)
|
| 160 |
-
sal = (sal + sal.T) / 2
|
| 161 |
-
return sal
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def _saliency_figure(sal: np.ndarray, p_mean: float):
    """Render the saliency matrix as a dark-themed two-panel PIL image.

    Left panel: heatmap of the top-5% FC edge saliencies.  Right panel:
    top-20 ROIs ranked by cumulative gradient magnitude, colored by the
    ensemble verdict.

    Fix: the suptitle previously called ``len(_models)`` on the module-level
    cache directly, raising ``TypeError`` whenever the cache was still
    ``None``; the count is now guarded.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend — no display on the server
    import matplotlib.pyplot as plt
    from PIL import Image

    thresh = np.percentile(sal, 95)
    sal_top = np.where(sal >= thresh, sal, 0.0)  # keep only the strongest 5% of edges
    roi_imp = sal.sum(1)                         # per-ROI cumulative influence
    top20 = roi_imp.argsort()[-20:][::-1]
    verdict_color = "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261"

    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
    fig.patch.set_facecolor("#0d0d0d")

    # Left: edge-saliency heatmap.
    ax = axes[0]
    ax.set_facecolor("#111")
    im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
    ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
    ax.set_xlabel("ROI index", color="#777", fontsize=9)
    ax.set_ylabel("ROI index", color="#777", fontsize=9)
    ax.tick_params(colors="#555", labelsize=8)
    cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
    plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666")
    for spine in ax.spines.values():
        spine.set_color("#333")

    # Right: top-20 ROI importance bars.
    ax2 = axes[1]
    ax2.set_facecolor("#111")
    ax2.barh(range(20), roi_imp[top20], color=verdict_color, alpha=0.75, edgecolor="none")
    ax2.set_yticks(range(20))
    ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc")
    ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
    ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
    ax2.tick_params(colors="#555", labelsize=8)
    ax2.invert_yaxis()
    for spine in ["top", "right"]:
        ax2.spines[spine].set_visible(False)
    for spine in ["bottom", "left"]:
        ax2.spines[spine].set_color("#333")

    n_loaded = len(_models or [])  # guard: model cache may be unpopulated
    fig.suptitle(
        f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {n_loaded} LOSO models",
        color="#888", fontsize=10, y=1.02,
    )
    plt.tight_layout()
    # Round-trip through a PNG buffer so Gradio gets a detached PIL image.
    buf = io.BytesIO()
    plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
|
| 216 |
-
|
| 217 |
-
# ── inference ──────────────────────────────────────────────────────────────
|
| 218 |
-
|
| 219 |
-
def run_gcn(file_path: str | None):
|
| 220 |
-
if file_path is None:
|
| 221 |
-
return "", "", "", None, ""
|
| 222 |
-
|
| 223 |
-
path = Path(file_path)
|
| 224 |
-
try:
|
| 225 |
-
if path.suffix == ".npz":
|
| 226 |
-
d = np.load(path, allow_pickle=True)
|
| 227 |
-
fc = d["mean_fc"].astype(np.float32)
|
| 228 |
-
fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
|
| 229 |
-
adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
|
| 230 |
-
bw = d["bold_windows"].astype(np.float32)
|
| 231 |
-
if len(bw) >= _MAX_WINDOWS:
|
| 232 |
-
bw = bw[:_MAX_WINDOWS]
|
| 233 |
-
else:
|
| 234 |
-
bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)])
|
| 235 |
-
bw_t = torch.FloatTensor(bw).unsqueeze(0)
|
| 236 |
-
adj_t = torch.FloatTensor(adj).unsqueeze(0)
|
| 237 |
-
else:
|
| 238 |
-
bold = np.loadtxt(path, dtype=np.float32)
|
| 239 |
-
if bold.ndim != 2 or bold.shape[1] != 200:
|
| 240 |
-
return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None, ""
|
| 241 |
-
bw_t, adj_t = preprocess(bold)
|
| 242 |
-
except Exception as e:
|
| 243 |
-
return f"⚠️ Error loading file: {e}", "", "", None, ""
|
| 244 |
-
|
| 245 |
-
models = get_models()
|
| 246 |
-
|
| 247 |
-
per_model = []
|
| 248 |
-
with torch.no_grad():
|
| 249 |
-
for site, task in models:
|
| 250 |
-
logits = task(bw_t, adj_t)
|
| 251 |
-
p = torch.softmax(logits, -1)[0, 1].item()
|
| 252 |
-
per_model.append((site, p))
|
| 253 |
-
|
| 254 |
-
p_mean = float(np.mean([p for _, p in per_model]))
|
| 255 |
-
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 256 |
-
conf = max(p_mean, 1 - p_mean) * 100
|
| 257 |
-
|
| 258 |
-
try:
|
| 259 |
-
sal = _compute_saliency(bw_t, adj_t, models)
|
| 260 |
-
sal_img = _saliency_figure(sal, p_mean)
|
| 261 |
-
except Exception:
|
| 262 |
-
sal_img = None
|
| 263 |
-
|
| 264 |
-
# Verdict
|
| 265 |
-
if p_mean > 0.6:
|
| 266 |
-
verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
|
| 267 |
-
<div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
|
| 268 |
-
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | <b style="color:white">{consensus}/4</b> site-blind models agree</div>
|
| 269 |
-
</div>"""
|
| 270 |
-
elif p_mean < 0.4:
|
| 271 |
-
verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #2dc653;padding:24px 28px;border-radius:12px;margin-bottom:8px">
|
| 272 |
-
<div style="font-size:2rem;font-weight:800;color:#2dc653;letter-spacing:1px">TYPICAL CONTROL</div>
|
| 273 |
-
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | <b style="color:white">{4-consensus}/4</b> site-blind models agree</div>
|
| 274 |
-
</div>"""
|
| 275 |
-
else:
|
| 276 |
-
verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #f4a261;padding:24px 28px;border-radius:12px;margin-bottom:8px">
|
| 277 |
-
<div style="font-size:2rem;font-weight:800;color:#f4a261;letter-spacing:1px">INCONCLUSIVE</div>
|
| 278 |
-
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | Model disagreement — clinical review required</div>
|
| 279 |
-
</div>"""
|
| 280 |
-
|
| 281 |
-
# Ensemble breakdown
|
| 282 |
-
rows = ""
|
| 283 |
-
for site, p in per_model:
|
| 284 |
-
lbl = "ASD" if p > 0.5 else "TC"
|
| 285 |
-
color = "#e63946" if p > 0.5 else "#2dc653"
|
| 286 |
-
bar_w = int(p * 100)
|
| 287 |
-
rows += f"""<tr>
|
| 288 |
-
<td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td>
|
| 289 |
-
<td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px">
|
| 290 |
-
<div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td>
|
| 291 |
-
<td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td>
|
| 292 |
-
<td style="padding:8px 12px;color:#888">p={p:.3f}</td>
|
| 293 |
-
</tr>"""
|
| 294 |
-
|
| 295 |
-
ensemble = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
|
| 296 |
-
<div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div>
|
| 297 |
-
<table style="width:100%;border-collapse:collapse">{rows}</table>
|
| 298 |
-
<div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {consensus}/4 models agree · LOSO AUC = 0.7872 across 529 held-out subjects</div>
|
| 299 |
-
</div>"""
|
| 300 |
-
|
| 301 |
-
# LLM clinical report
|
| 302 |
-
llm_text = _llm_report(p_mean, per_model)
|
| 303 |
-
report = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
|
| 304 |
-
<div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Report — Qwen2.5-7B fine-tuned on AMD Instinct MI300X</div>
|
| 305 |
-
<div style="color:#ddd;font-size:0.95rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>
|
| 306 |
-
<div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#555;font-size:0.78rem;margin-top:16px">
|
| 307 |
-
⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments (ADOS-2, ADI-R).
|
| 308 |
-
</div></div>"""
|
| 309 |
-
|
| 310 |
-
return verdict, ensemble, report, sal_img
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
# ── UI ─────────────────────────────────────────────────────────────────────

# Page-level CSS: dark background, centered container with a capped width.
css = """
body { background: #0d0d0d; }
.gradio-container { max-width: 960px; margin: auto; }
"""
|
| 319 |
-
|
| 320 |
-
# Gradio app layout: hero header → file upload → verdict / ensemble panels →
# saliency image → clinical report.  run_gcn fires on every file change.
with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
    # Hero header with the headline evaluation metrics.
    gr.HTML("""
<div style="text-align:center;padding:32px 0 16px">
<div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
<div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
<div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
</div>
</div>
""")

    file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
    verdict_html = gr.HTML()    # headline verdict card
    ensemble_html = gr.HTML()   # per-site-blind model breakdown table

    gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
    saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")

    report_html = gr.HTML()     # LLM-written clinical narrative

    # Re-run the full pipeline whenever a new file is uploaded.
    file_input.change(
        fn=run_gcn,
        inputs=file_input,
        outputs=[verdict_html, ensemble_html, report_html, saliency_img],
    )

    # Footer with provenance links.
    gr.HTML("""
<div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
Adversarial Brain-Mode GCN (k=16) · Qwen2.5-7B LoRA (AMD MI300X) · ABIDE I ·
<a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
</div>
""")
|
| 355 |
-
|
| 356 |
-
# Eagerly load all models at import time so the first request is fast.
print("Preloading GCN models...")
get_models()
print("Preloading LLM...")
get_llm()
print("All models ready.")

if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|