Upload folder using huggingface_hub
Browse files- app.py +121 -39
- app_with_llm.py +363 -0
- checkpoints/aal_nyu.ckpt +3 -0
- checkpoints/aal_ucla.ckpt +3 -0
- checkpoints/aal_um.ckpt +3 -0
- checkpoints/aal_usm.ckpt +3 -0
app.py
CHANGED
|
@@ -17,18 +17,62 @@ _STEP = 3
|
|
| 17 |
_MAX_WINDOWS = 30
|
| 18 |
_FC_THRESHOLD = 0.2
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
}
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# ── preprocessing ──────────────────────────────────────────────────────────
|
| 33 |
|
| 34 |
def _zscore(bold):
|
|
@@ -60,21 +104,23 @@ def preprocess(bold):
|
|
| 60 |
|
| 61 |
# ── model loading ──────────────────────────────────────────────────────────
|
| 62 |
|
| 63 |
-
|
| 64 |
|
| 65 |
-
def get_models():
|
| 66 |
-
global
|
| 67 |
-
if
|
| 68 |
-
return
|
| 69 |
from brain_gcn.tasks import ClassificationTask
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
if not ckpt.exists():
|
| 73 |
continue
|
| 74 |
task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
|
| 75 |
task.eval()
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
# ── gradient saliency ──────────────────────────────────────────────────────
|
| 80 |
|
|
@@ -99,7 +145,7 @@ _NET_MNI = np.array([
|
|
| 99 |
[ 14, 4, 4], # Subcortical (thalamus)
|
| 100 |
], dtype=np.float32)
|
| 101 |
|
| 102 |
-
def _saliency_figure(sal, p_mean):
|
| 103 |
import matplotlib
|
| 104 |
matplotlib.use("Agg")
|
| 105 |
import matplotlib.pyplot as plt
|
|
@@ -107,18 +153,21 @@ def _saliency_figure(sal, p_mean):
|
|
| 107 |
from mpl_toolkits.mplot3d.art3d import Line3DCollection
|
| 108 |
from PIL import Image
|
| 109 |
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
# Aggregate
|
| 113 |
net_sal = np.zeros((n_nets, n_nets))
|
| 114 |
-
for i, (s1, e1) in enumerate(zip(
|
| 115 |
-
for j, (s2, e2) in enumerate(zip(
|
| 116 |
net_sal[i, j] = sal[s1:e1, s2:e2].mean()
|
| 117 |
|
| 118 |
# Network importance: mean outgoing + incoming saliency per network
|
| 119 |
net_imp = np.array([
|
| 120 |
sal[s:e, :].mean() + sal[:, s:e].mean()
|
| 121 |
-
for s, e in zip(
|
| 122 |
])
|
| 123 |
|
| 124 |
fig = plt.figure(figsize=(18, 5.5))
|
|
@@ -137,8 +186,8 @@ def _saliency_figure(sal, p_mean):
|
|
| 137 |
|
| 138 |
ax.set_xticks(range(n_nets))
|
| 139 |
ax.set_yticks(range(n_nets))
|
| 140 |
-
ax.set_xticklabels(
|
| 141 |
-
ax.set_yticklabels(
|
| 142 |
ax.tick_params(colors="#555", length=0)
|
| 143 |
for sp in ax.spines.values():
|
| 144 |
sp.set_color("#222")
|
|
@@ -173,7 +222,7 @@ def _saliency_figure(sal, p_mean):
|
|
| 173 |
|
| 174 |
# Callout labels for top-3 cross-network edges
|
| 175 |
for rank, (score, i, j) in enumerate(top3_edges):
|
| 176 |
-
label = f"#{rank+1} {
|
| 177 |
ax.annotate(label,
|
| 178 |
xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
|
| 179 |
fontsize=6, color="#fb923c", fontweight="600",
|
|
@@ -193,9 +242,9 @@ def _saliency_figure(sal, p_mean):
|
|
| 193 |
|
| 194 |
order = net_imp.argsort()[::-1]
|
| 195 |
bars = ax2.barh(range(n_nets), net_imp[order],
|
| 196 |
-
color=[
|
| 197 |
ax2.set_yticks(range(n_nets))
|
| 198 |
-
ax2.set_yticklabels([
|
| 199 |
ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
|
| 200 |
ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
|
| 201 |
ax2.invert_yaxis()
|
|
@@ -250,7 +299,7 @@ def _saliency_figure(sal, p_mean):
|
|
| 250 |
ax3.set_box_aspect([1.2, 1.4, 1.0])
|
| 251 |
|
| 252 |
fig.suptitle(
|
| 253 |
-
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(
|
| 254 |
color="#444", fontsize=8.5, y=1.02,
|
| 255 |
)
|
| 256 |
plt.tight_layout()
|
|
@@ -267,6 +316,7 @@ def run_gcn(file_path):
|
|
| 267 |
return "", "", "", None
|
| 268 |
|
| 269 |
path = Path(file_path)
|
|
|
|
| 270 |
try:
|
| 271 |
if path.suffix == ".npz":
|
| 272 |
d = np.load(path, allow_pickle=True)
|
|
@@ -282,13 +332,40 @@ def run_gcn(file_path):
|
|
| 282 |
adj_t = torch.FloatTensor(adj).unsqueeze(0)
|
| 283 |
else:
|
| 284 |
bold = np.loadtxt(path, dtype=np.float32)
|
| 285 |
-
if bold.ndim != 2
|
| 286 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
bw_t, adj_t = preprocess(bold)
|
| 288 |
except Exception as e:
|
| 289 |
return f"Error loading file: {e}", "", "", None
|
| 290 |
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
per_model = []
|
| 293 |
with torch.no_grad():
|
| 294 |
for site, task in models:
|
|
@@ -300,7 +377,12 @@ def run_gcn(file_path):
|
|
| 300 |
conf = max(p_mean, 1 - p_mean) * 100
|
| 301 |
|
| 302 |
try:
|
| 303 |
-
sal_img = _saliency_figure(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
except Exception:
|
| 305 |
sal_img = None
|
| 306 |
|
|
@@ -398,7 +480,7 @@ LOSO AUC = 0.7872 · 529 held-out subjects · 4 institutions
|
|
| 398 |
<div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">ICD-10 Classification</div>
|
| 399 |
<div style="color:#cbd5e1;font-size:0.84rem;line-height:1.4">{icd}</div></div>
|
| 400 |
<div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">Ensemble Confidence</div>
|
| 401 |
-
<div style="color:#cbd5e1;font-size:0.84rem">{conf:.1f}% · p(ASD) = {p_mean:.3f} · {len(
|
| 402 |
</div>
|
| 403 |
|
| 404 |
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Impression</div>
|
|
@@ -576,7 +658,7 @@ ARCHITECTURE = """
|
|
| 576 |
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:14px 16px;min-width:130px;text-align:center;flex-shrink:0">
|
| 577 |
<div style="color:#8b95a7;font-size:0.65rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:6px">Input</div>
|
| 578 |
<div style="color:#f4f4f5;font-weight:600;font-size:0.88rem">fMRI BOLD</div>
|
| 579 |
-
<div style="color:#5e6675;font-size:0.74rem;margin-top:3px">T ×
|
| 580 |
</div>
|
| 581 |
|
| 582 |
<div style="color:#252a35;font-size:1.4rem;padding:0 6px;flex-shrink:0">→</div>
|
|
@@ -637,7 +719,7 @@ ARCHITECTURE = """
|
|
| 637 |
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden">
|
| 638 |
<table style="width:100%;border-collapse:collapse;font-size:0.85rem">
|
| 639 |
<tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Dataset</td><td style="padding:10px 16px;color:#cbd5e1">ABIDE I · 1,102 subjects · 17 acquisition sites</td></tr>
|
| 640 |
-
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Parcellation</td><td style="padding:10px 16px;color:#cbd5e1">CC200 (
|
| 641 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Model</td><td style="padding:10px 16px;color:#cbd5e1">AdversarialBrainModeNetwork · K=16 modes · hidden_dim=64</td></tr>
|
| 642 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Validation</td><td style="padding:10px 16px;color:#cbd5e1">LOSO AUC = <span style="color:#ef4444;font-weight:600">0.7872</span> · 529 held-out subjects · 0 confident misclassifications</td></tr>
|
| 643 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Interpretability</td><td style="padding:10px 16px;color:#cbd5e1">Real-time gradient saliency · 7-network aggregation · 3D brain surface</td></tr>
|
|
@@ -715,7 +797,7 @@ with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as dem
|
|
| 715 |
|
| 716 |
with gr.Tabs():
|
| 717 |
with gr.Tab("Analysis"):
|
| 718 |
-
file_input = gr.File(label="Upload CC200
|
| 719 |
gr.HTML("<div style='color:#8b95a7;font-size:0.7rem;text-transform:uppercase;letter-spacing:1.2px;margin:14px 0 8px;font-weight:500'>Or try a real ABIDE subject from a held-out site</div>")
|
| 720 |
with gr.Row():
|
| 721 |
btn_asd = gr.Button("ASD · Stanford 0051160", size="sm")
|
|
|
|
| 17 |
_MAX_WINDOWS = 30
|
| 18 |
_FC_THRESHOLD = 0.2
|
| 19 |
|
| 20 |
+
# ── Atlas configurations ────────────────────────────────────────────────────
|
| 21 |
+
# CC200 → Yeo 7-network parcellation (approximate ROI ordering)
|
| 22 |
+
_ATLAS_CFG = {
|
| 23 |
+
"cc200": {
|
| 24 |
+
"n_rois": 200,
|
| 25 |
+
"label": "CC200",
|
| 26 |
+
"net_names": ["DMN", "Salience", "Frontoparietal", "Sensorimotor", "Visual", "Dorsal Attn", "Subcortical"],
|
| 27 |
+
"net_bounds": [0, 38, 69, 99, 137, 165, 180, 200],
|
| 28 |
+
"net_colors": ["#e63946", "#f4a261", "#457b9d", "#2dc653", "#a8dadc", "#8b5cf6", "#6b7280"],
|
| 29 |
+
"ckpts": {
|
| 30 |
+
"NYU": Path("checkpoints/nyu.ckpt"),
|
| 31 |
+
"USM": Path("checkpoints/usm.ckpt"),
|
| 32 |
+
"UCLA": Path("checkpoints/ucla.ckpt"),
|
| 33 |
+
"UM": Path("checkpoints/um.ckpt"),
|
| 34 |
+
},
|
| 35 |
+
},
|
| 36 |
+
"aal": {
|
| 37 |
+
"n_rois": 116,
|
| 38 |
+
"label": "AAL-116",
|
| 39 |
+
# Approximate Yeo-7 parcellation for AAL-116 anatomical ordering:
|
| 40 |
+
# Frontal/FPN (1-28), Sensorimotor (29-40), DMN parietal (41-60),
|
| 41 |
+
# Temporal/DMN (61-76), Subcortical (77-90), Occipital/Visual (91-116)
|
| 42 |
+
"net_names": ["Frontoparietal", "Sensorimotor", "Dorsal Attn", "DMN", "Salience", "Subcortical", "Visual"],
|
| 43 |
+
"net_bounds": [0, 20, 34, 50, 68, 80, 92, 116],
|
| 44 |
+
"net_colors": ["#457b9d", "#2dc653", "#8b5cf6", "#e63946", "#f4a261", "#6b7280", "#a8dadc"],
|
| 45 |
+
"ckpts": {
|
| 46 |
+
"NYU": Path("checkpoints/aal_nyu.ckpt"),
|
| 47 |
+
"USM": Path("checkpoints/aal_usm.ckpt"),
|
| 48 |
+
"UCLA": Path("checkpoints/aal_ucla.ckpt"),
|
| 49 |
+
"UM": Path("checkpoints/aal_um.ckpt"),
|
| 50 |
+
},
|
| 51 |
+
},
|
| 52 |
+
"ho": {
|
| 53 |
+
"n_rois": 111,
|
| 54 |
+
"label": "Harvard-Oxford",
|
| 55 |
+
"net_names": ["Frontoparietal", "Sensorimotor", "DMN", "Salience", "Subcortical", "Visual", "Temporal"],
|
| 56 |
+
"net_bounds": [0, 18, 30, 48, 68, 80, 96, 111],
|
| 57 |
+
"net_colors": ["#457b9d", "#2dc653", "#e63946", "#f4a261", "#6b7280", "#a8dadc", "#8b5cf6"],
|
| 58 |
+
"ckpts": {
|
| 59 |
+
"NYU": Path("checkpoints/ho_nyu.ckpt"),
|
| 60 |
+
"USM": Path("checkpoints/ho_usm.ckpt"),
|
| 61 |
+
"UCLA": Path("checkpoints/ho_ucla.ckpt"),
|
| 62 |
+
"UM": Path("checkpoints/ho_um.ckpt"),
|
| 63 |
+
},
|
| 64 |
+
},
|
| 65 |
}
|
| 66 |
|
| 67 |
+
# Resolve active atlas config by ROI count
|
| 68 |
+
_ROI_TO_ATLAS = {cfg["n_rois"]: key for key, cfg in _ATLAS_CFG.items()}
|
| 69 |
+
|
| 70 |
+
# Legacy aliases kept for backward compat
|
| 71 |
+
_NET_NAMES = _ATLAS_CFG["cc200"]["net_names"]
|
| 72 |
+
_NET_BOUNDS = _ATLAS_CFG["cc200"]["net_bounds"]
|
| 73 |
+
_NET_COLORS = _ATLAS_CFG["cc200"]["net_colors"]
|
| 74 |
+
_CKPTS = _ATLAS_CFG["cc200"]["ckpts"]
|
| 75 |
+
|
| 76 |
# ── preprocessing ──────────────────────────────────────────────────────────
|
| 77 |
|
| 78 |
def _zscore(bold):
|
|
|
|
| 104 |
|
| 105 |
# ── model loading ──────────────────────────────────────────────────────────
|
| 106 |
|
| 107 |
+
_model_cache: dict[str, list] = {}
|
| 108 |
|
| 109 |
+
def get_models(atlas: str = "cc200"):
|
| 110 |
+
global _model_cache
|
| 111 |
+
if atlas in _model_cache:
|
| 112 |
+
return _model_cache[atlas]
|
| 113 |
from brain_gcn.tasks import ClassificationTask
|
| 114 |
+
cfg = _ATLAS_CFG.get(atlas, _ATLAS_CFG["cc200"])
|
| 115 |
+
models = []
|
| 116 |
+
for site, ckpt in cfg["ckpts"].items():
|
| 117 |
if not ckpt.exists():
|
| 118 |
continue
|
| 119 |
task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
|
| 120 |
task.eval()
|
| 121 |
+
models.append((site, task))
|
| 122 |
+
_model_cache[atlas] = models
|
| 123 |
+
return models
|
| 124 |
|
| 125 |
# ── gradient saliency ──────────────────────────────────────────────────────
|
| 126 |
|
|
|
|
| 145 |
[ 14, 4, 4], # Subcortical (thalamus)
|
| 146 |
], dtype=np.float32)
|
| 147 |
|
| 148 |
+
def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=None):
|
| 149 |
import matplotlib
|
| 150 |
matplotlib.use("Agg")
|
| 151 |
import matplotlib.pyplot as plt
|
|
|
|
| 153 |
from mpl_toolkits.mplot3d.art3d import Line3DCollection
|
| 154 |
from PIL import Image
|
| 155 |
|
| 156 |
+
_nn = net_names if net_names is not None else _NET_NAMES
|
| 157 |
+
_nb = net_bounds if net_bounds is not None else _NET_BOUNDS
|
| 158 |
+
_nc = net_colors if net_colors is not None else _NET_COLORS
|
| 159 |
+
n_nets = len(_nn)
|
| 160 |
|
| 161 |
+
# Aggregate N×N saliency → 7×7 network-level matrix
|
| 162 |
net_sal = np.zeros((n_nets, n_nets))
|
| 163 |
+
for i, (s1, e1) in enumerate(zip(_nb[:-1], _nb[1:])):
|
| 164 |
+
for j, (s2, e2) in enumerate(zip(_nb[:-1], _nb[1:])):
|
| 165 |
net_sal[i, j] = sal[s1:e1, s2:e2].mean()
|
| 166 |
|
| 167 |
# Network importance: mean outgoing + incoming saliency per network
|
| 168 |
net_imp = np.array([
|
| 169 |
sal[s:e, :].mean() + sal[:, s:e].mean()
|
| 170 |
+
for s, e in zip(_nb[:-1], _nb[1:])
|
| 171 |
])
|
| 172 |
|
| 173 |
fig = plt.figure(figsize=(18, 5.5))
|
|
|
|
| 186 |
|
| 187 |
ax.set_xticks(range(n_nets))
|
| 188 |
ax.set_yticks(range(n_nets))
|
| 189 |
+
ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=9, color="#ccc")
|
| 190 |
+
ax.set_yticklabels(_nn, fontsize=9, color="#ccc")
|
| 191 |
ax.tick_params(colors="#555", length=0)
|
| 192 |
for sp in ax.spines.values():
|
| 193 |
sp.set_color("#222")
|
|
|
|
| 222 |
|
| 223 |
# Callout labels for top-3 cross-network edges
|
| 224 |
for rank, (score, i, j) in enumerate(top3_edges):
|
| 225 |
+
label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
|
| 226 |
ax.annotate(label,
|
| 227 |
xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
|
| 228 |
fontsize=6, color="#fb923c", fontweight="600",
|
|
|
|
| 242 |
|
| 243 |
order = net_imp.argsort()[::-1]
|
| 244 |
bars = ax2.barh(range(n_nets), net_imp[order],
|
| 245 |
+
color=[_nc[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
|
| 246 |
ax2.set_yticks(range(n_nets))
|
| 247 |
+
ax2.set_yticklabels([_nn[i] for i in order], fontsize=9.5, color="#ddd")
|
| 248 |
ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
|
| 249 |
ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
|
| 250 |
ax2.invert_yaxis()
|
|
|
|
| 299 |
ax3.set_box_aspect([1.2, 1.4, 1.0])
|
| 300 |
|
| 301 |
fig.suptitle(
|
| 302 |
+
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · {len(models)}-model LOSO ensemble · CC200 → Yeo-7 networks",
|
| 303 |
color="#444", fontsize=8.5, y=1.02,
|
| 304 |
)
|
| 305 |
plt.tight_layout()
|
|
|
|
| 316 |
return "", "", "", None
|
| 317 |
|
| 318 |
path = Path(file_path)
|
| 319 |
+
atlas_key = "cc200" # default; overridden below for .1D files
|
| 320 |
try:
|
| 321 |
if path.suffix == ".npz":
|
| 322 |
d = np.load(path, allow_pickle=True)
|
|
|
|
| 332 |
adj_t = torch.FloatTensor(adj).unsqueeze(0)
|
| 333 |
else:
|
| 334 |
bold = np.loadtxt(path, dtype=np.float32)
|
| 335 |
+
if bold.ndim != 2:
|
| 336 |
+
return "<div style='color:#ef4444;padding:12px'>Error: file must be a 2D T×ROIs matrix.</div>", "", "", None
|
| 337 |
+
n_rois = bold.shape[1]
|
| 338 |
+
atlas_key = _ROI_TO_ATLAS.get(n_rois)
|
| 339 |
+
if atlas_key is None:
|
| 340 |
+
supported = ", ".join(f"{cfg['label']} ({cfg['n_rois']} ROIs)" for cfg in _ATLAS_CFG.values())
|
| 341 |
+
return (
|
| 342 |
+
f"<div style='background:#1a1015;border-left:3px solid #ef4444;padding:16px 20px;border-radius:8px;margin-top:14px'>"
|
| 343 |
+
f"<div style='color:#ef4444;font-weight:600;margin-bottom:6px'>Unsupported atlas ({n_rois} ROIs)</div>"
|
| 344 |
+
f"<div style='color:#cbd5e1;font-size:0.88rem;line-height:1.6'>"
|
| 345 |
+
f"Supported: {supported}.<br>"
|
| 346 |
+
f"Download from FCP-INDI S3: <code style='color:#fb923c'>rois_cc200/</code>, <code style='color:#fb923c'>rois_aal/</code>, or <code style='color:#fb923c'>rois_ho/</code>"
|
| 347 |
+
f"</div></div>"
|
| 348 |
+
), "", "", None
|
| 349 |
bw_t, adj_t = preprocess(bold)
|
| 350 |
except Exception as e:
|
| 351 |
return f"Error loading file: {e}", "", "", None
|
| 352 |
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
atlas_cfg = _ATLAS_CFG[atlas_key]
|
| 356 |
+
models = get_models(atlas_key)
|
| 357 |
+
|
| 358 |
+
if not models:
|
| 359 |
+
atlas_label = atlas_cfg["label"]
|
| 360 |
+
return (
|
| 361 |
+
f"<div style='background:#1a1015;border-left:3px solid #f59e0b;padding:16px 20px;border-radius:8px;margin-top:14px'>"
|
| 362 |
+
f"<div style='color:#f59e0b;font-weight:600;margin-bottom:6px'>{atlas_label} models not yet available</div>"
|
| 363 |
+
f"<div style='color:#cbd5e1;font-size:0.88rem;line-height:1.6'>"
|
| 364 |
+
f"Training is in progress. CC200 models are available now — convert your data with:<br>"
|
| 365 |
+
f"<code style='color:#fb923c;font-size:0.82rem'>aws s3 cp s3://fcp-indi/.../rois_cc200/ . --no-sign-request --recursive</code>"
|
| 366 |
+
f"</div></div>"
|
| 367 |
+
), "", "", None
|
| 368 |
+
|
| 369 |
per_model = []
|
| 370 |
with torch.no_grad():
|
| 371 |
for site, task in models:
|
|
|
|
| 377 |
conf = max(p_mean, 1 - p_mean) * 100
|
| 378 |
|
| 379 |
try:
|
| 380 |
+
sal_img = _saliency_figure(
|
| 381 |
+
_compute_saliency(bw_t, adj_t, models), p_mean,
|
| 382 |
+
net_names=atlas_cfg["net_names"],
|
| 383 |
+
net_bounds=atlas_cfg["net_bounds"],
|
| 384 |
+
net_colors=atlas_cfg["net_colors"],
|
| 385 |
+
)
|
| 386 |
except Exception:
|
| 387 |
sal_img = None
|
| 388 |
|
|
|
|
| 480 |
<div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">ICD-10 Classification</div>
|
| 481 |
<div style="color:#cbd5e1;font-size:0.84rem;line-height:1.4">{icd}</div></div>
|
| 482 |
<div><div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:3px">Ensemble Confidence</div>
|
| 483 |
+
<div style="color:#cbd5e1;font-size:0.84rem">{conf:.1f}% · p(ASD) = {p_mean:.3f} · {len(models)}-model LOSO</div></div>
|
| 484 |
</div>
|
| 485 |
|
| 486 |
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:4px;font-weight:500">Impression</div>
|
|
|
|
| 658 |
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:14px 16px;min-width:130px;text-align:center;flex-shrink:0">
|
| 659 |
<div style="color:#8b95a7;font-size:0.65rem;text-transform:uppercase;letter-spacing:1px;margin-bottom:6px">Input</div>
|
| 660 |
<div style="color:#f4f4f5;font-weight:600;font-size:0.88rem">fMRI BOLD</div>
|
| 661 |
+
<div style="color:#5e6675;font-size:0.74rem;margin-top:3px">T × ROIs (CC200/AAL/HO)</div>
|
| 662 |
</div>
|
| 663 |
|
| 664 |
<div style="color:#252a35;font-size:1.4rem;padding:0 6px;flex-shrink:0">→</div>
|
|
|
|
| 719 |
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;overflow:hidden">
|
| 720 |
<table style="width:100%;border-collapse:collapse;font-size:0.85rem">
|
| 721 |
<tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Dataset</td><td style="padding:10px 16px;color:#cbd5e1">ABIDE I · 1,102 subjects · 17 acquisition sites</td></tr>
|
| 722 |
+
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Parcellation</td><td style="padding:10px 16px;color:#cbd5e1">CC200 (200 ROIs) · AAL-116 (116 ROIs) · Harvard-Oxford (111 ROIs)</td></tr>
|
| 723 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Model</td><td style="padding:10px 16px;color:#cbd5e1">AdversarialBrainModeNetwork · K=16 modes · hidden_dim=64</td></tr>
|
| 724 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Validation</td><td style="padding:10px 16px;color:#cbd5e1">LOSO AUC = <span style="color:#ef4444;font-weight:600">0.7872</span> · 529 held-out subjects · 0 confident misclassifications</td></tr>
|
| 725 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Interpretability</td><td style="padding:10px 16px;color:#cbd5e1">Real-time gradient saliency · 7-network aggregation · 3D brain surface</td></tr>
|
|
|
|
| 797 |
|
| 798 |
with gr.Tabs():
|
| 799 |
with gr.Tab("Analysis"):
|
| 800 |
+
file_input = gr.File(label="Upload fMRI time series — CC200 (200), AAL (116), or Harvard-Oxford (111) ROIs · .1D or .npz", type="filepath")
|
| 801 |
gr.HTML("<div style='color:#8b95a7;font-size:0.7rem;text-transform:uppercase;letter-spacing:1.2px;margin:14px 0 8px;font-weight:500'>Or try a real ABIDE subject from a held-out site</div>")
|
| 802 |
with gr.Row():
|
| 803 |
btn_asd = gr.Button("ASD · Stanford 0051160", size="sm")
|
app_with_llm.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BrainConnect-ASD — Scanner-site-invariant ASD detection from fMRI.
|
| 3 |
+
Full pipeline: Adversarial GCN + Qwen2.5-7B fine-tuned on AMD MI300X.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import io
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
_WINDOW_LEN = 50
|
| 15 |
+
_STEP = 3
|
| 16 |
+
_MAX_WINDOWS = 30
|
| 17 |
+
_FC_THRESHOLD = 0.2
|
| 18 |
+
|
| 19 |
+
_CKPTS = {
|
| 20 |
+
"NYU": Path("checkpoints/nyu.ckpt"),
|
| 21 |
+
"USM": Path("checkpoints/usm.ckpt"),
|
| 22 |
+
"UCLA": Path("checkpoints/ucla.ckpt"),
|
| 23 |
+
"UM": Path("checkpoints/um.ckpt"),
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
_LLM_MODEL = "Yatsuiii/asd-interpreter-lora"
|
| 27 |
+
|
| 28 |
+
SYSTEM_PROMPT = (
|
| 29 |
+
"You are a clinical AI assistant specializing in functional MRI brain "
|
| 30 |
+
"connectivity analysis for autism spectrum disorder (ASD) diagnosis support. "
|
| 31 |
+
"You interpret outputs from a validated graph neural network (GCN) trained on "
|
| 32 |
+
"the ABIDE I dataset and provide structured clinical summaries for neurologists "
|
| 33 |
+
"and psychiatrists. Your reports are informative and evidence-based but always "
|
| 34 |
+
"clarify that findings are AI-assisted and should be integrated with full "
|
| 35 |
+
"clinical assessment. You do not make a diagnosis."
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# ── preprocessing ──────────────────────────────────────────────────────────
|
| 39 |
+
|
| 40 |
+
def _zscore(bold):
|
| 41 |
+
mean = bold.mean(0, keepdims=True)
|
| 42 |
+
std = bold.std(0, keepdims=True)
|
| 43 |
+
std[std < 1e-8] = 1.0
|
| 44 |
+
return ((bold - mean) / std).astype(np.float32)
|
| 45 |
+
|
| 46 |
+
def _fc(bold):
|
| 47 |
+
fc = np.corrcoef(bold.T).astype(np.float32)
|
| 48 |
+
np.nan_to_num(fc, copy=False)
|
| 49 |
+
return fc
|
| 50 |
+
|
| 51 |
+
def _windows(bold):
|
| 52 |
+
T, N = bold.shape
|
| 53 |
+
starts = list(range(0, T - _WINDOW_LEN + 1, _STEP))
|
| 54 |
+
w = np.stack([bold[s:s+_WINDOW_LEN].std(0) for s in starts]).astype(np.float32)
|
| 55 |
+
if len(w) >= _MAX_WINDOWS:
|
| 56 |
+
return w[:_MAX_WINDOWS]
|
| 57 |
+
return np.concatenate([w, np.repeat(w[-1:], _MAX_WINDOWS - len(w), 0)])
|
| 58 |
+
|
| 59 |
+
def preprocess(bold):
|
| 60 |
+
bold = _zscore(bold)
|
| 61 |
+
fc = _fc(bold)
|
| 62 |
+
fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
|
| 63 |
+
adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
|
| 64 |
+
bw = _windows(bold)
|
| 65 |
+
return torch.FloatTensor(bw).unsqueeze(0), torch.FloatTensor(adj).unsqueeze(0)
|
| 66 |
+
|
| 67 |
+
# ── GCN model loading ──────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
_models: list | None = None
|
| 70 |
+
|
| 71 |
+
def get_models():
|
| 72 |
+
global _models
|
| 73 |
+
if _models is not None:
|
| 74 |
+
return _models
|
| 75 |
+
from brain_gcn.tasks import ClassificationTask
|
| 76 |
+
_models = []
|
| 77 |
+
for site, ckpt in _CKPTS.items():
|
| 78 |
+
if not ckpt.exists():
|
| 79 |
+
continue
|
| 80 |
+
task = ClassificationTask.load_from_checkpoint(str(ckpt), map_location="cpu", strict=False)
|
| 81 |
+
task.eval()
|
| 82 |
+
_models.append((site, task))
|
| 83 |
+
return _models
|
| 84 |
+
|
| 85 |
+
# ── LLM loading ────────────────────────────────────────────────────────────
|
| 86 |
+
|
| 87 |
+
_llm = None
|
| 88 |
+
|
| 89 |
+
def get_llm():
|
| 90 |
+
global _llm
|
| 91 |
+
if _llm is not None:
|
| 92 |
+
return _llm
|
| 93 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 94 |
+
print(f"Loading LLM: {_LLM_MODEL}")
|
| 95 |
+
tok = AutoTokenizer.from_pretrained(_LLM_MODEL)
|
| 96 |
+
tok.pad_token = tok.eos_token
|
| 97 |
+
mdl = AutoModelForCausalLM.from_pretrained(
|
| 98 |
+
_LLM_MODEL,
|
| 99 |
+
torch_dtype=torch.bfloat16,
|
| 100 |
+
device_map="auto",
|
| 101 |
+
)
|
| 102 |
+
mdl.eval()
|
| 103 |
+
_llm = (mdl, tok)
|
| 104 |
+
return _llm
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _llm_report(p_mean: float, per_model: list) -> str:
|
| 108 |
+
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 109 |
+
per_model_str = "\n".join(
|
| 110 |
+
f" {s}-blind: {'ASD' if v > 0.5 else 'TC'} (p={v:.3f})" for s, v in per_model
|
| 111 |
+
)
|
| 112 |
+
conf_label = (
|
| 113 |
+
"HIGH" if p_mean >= 0.75 else
|
| 114 |
+
"MODERATE" if p_mean >= 0.6 else
|
| 115 |
+
"LOW / UNCERTAIN" if p_mean >= 0.4 else
|
| 116 |
+
"MODERATE (TC)" if p_mean >= 0.25 else
|
| 117 |
+
"HIGH (TC)"
|
| 118 |
+
)
|
| 119 |
+
user_msg = (
|
| 120 |
+
f"Brain Connectivity GCN Analysis Report\n"
|
| 121 |
+
f"{'='*40}\n"
|
| 122 |
+
f"p(ASD) : {p_mean:.3f}\n"
|
| 123 |
+
f"Confidence Level : {conf_label}\n"
|
| 124 |
+
f"Model Consensus : {consensus}/4 site-blind models predict ASD\n\n"
|
| 125 |
+
f"Per-Model Breakdown (LOSO ensemble):\n{per_model_str}\n\n"
|
| 126 |
+
f"Please provide a structured clinical interpretation of these findings."
|
| 127 |
+
)
|
| 128 |
+
try:
|
| 129 |
+
mdl, tok = get_llm()
|
| 130 |
+
messages = [
|
| 131 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 132 |
+
{"role": "user", "content": user_msg},
|
| 133 |
+
]
|
| 134 |
+
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
| 135 |
+
inputs = tok(text, return_tensors="pt").to(next(mdl.parameters()).device)
|
| 136 |
+
with torch.no_grad():
|
| 137 |
+
out = mdl.generate(
|
| 138 |
+
**inputs,
|
| 139 |
+
max_new_tokens=512,
|
| 140 |
+
temperature=0.3,
|
| 141 |
+
do_sample=True,
|
| 142 |
+
pad_token_id=tok.eos_token_id,
|
| 143 |
+
)
|
| 144 |
+
generated = out[0][inputs["input_ids"].shape[1]:]
|
| 145 |
+
return tok.decode(generated, skip_special_tokens=True).strip()
|
| 146 |
+
except Exception as e:
|
| 147 |
+
return f"LLM unavailable: {e}"
|
| 148 |
+
|
| 149 |
+
# ── gradient saliency ──────────────────────────────────────────────────────
|
| 150 |
+
|
| 151 |
+
def _compute_saliency(bw_t: torch.Tensor, adj_t: torch.Tensor, models) -> np.ndarray:
|
| 152 |
+
maps = []
|
| 153 |
+
for _, task in models:
|
| 154 |
+
adj = adj_t.clone().requires_grad_(True)
|
| 155 |
+
logits = task.model(bw_t, adj)
|
| 156 |
+
p = torch.softmax(logits, -1)[0, 1]
|
| 157 |
+
p.backward()
|
| 158 |
+
maps.append(adj.grad[0].abs().detach().numpy())
|
| 159 |
+
sal = np.mean(maps, axis=0)
|
| 160 |
+
sal = (sal + sal.T) / 2
|
| 161 |
+
return sal
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _saliency_figure(sal: np.ndarray, p_mean: float):
|
| 165 |
+
import matplotlib
|
| 166 |
+
matplotlib.use("Agg")
|
| 167 |
+
import matplotlib.pyplot as plt
|
| 168 |
+
from PIL import Image
|
| 169 |
+
|
| 170 |
+
thresh = np.percentile(sal, 95)
|
| 171 |
+
sal_top = np.where(sal >= thresh, sal, 0.0)
|
| 172 |
+
roi_imp = sal.sum(1)
|
| 173 |
+
top20 = roi_imp.argsort()[-20:][::-1]
|
| 174 |
+
verdict_color = "#e63946" if p_mean > 0.6 else "#2dc653" if p_mean < 0.4 else "#f4a261"
|
| 175 |
+
|
| 176 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))
|
| 177 |
+
fig.patch.set_facecolor("#0d0d0d")
|
| 178 |
+
|
| 179 |
+
ax = axes[0]
|
| 180 |
+
ax.set_facecolor("#111")
|
| 181 |
+
im = ax.imshow(sal_top, cmap="inferno", aspect="auto", interpolation="nearest")
|
| 182 |
+
ax.set_title("FC Edge Saliency (top 5% connections)", color="#ccc", fontsize=11, pad=10)
|
| 183 |
+
ax.set_xlabel("ROI index", color="#777", fontsize=9)
|
| 184 |
+
ax.set_ylabel("ROI index", color="#777", fontsize=9)
|
| 185 |
+
ax.tick_params(colors="#555", labelsize=8)
|
| 186 |
+
cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 187 |
+
cb.ax.yaxis.set_tick_params(color="#555", labelsize=7)
|
| 188 |
+
plt.setp(cb.ax.yaxis.get_ticklabels(), color="#666")
|
| 189 |
+
for spine in ax.spines.values():
|
| 190 |
+
spine.set_color("#333")
|
| 191 |
+
|
| 192 |
+
ax2 = axes[1]
|
| 193 |
+
ax2.set_facecolor("#111")
|
| 194 |
+
ax2.barh(range(20), roi_imp[top20], color=verdict_color, alpha=0.75, edgecolor="none")
|
| 195 |
+
ax2.set_yticks(range(20))
|
| 196 |
+
ax2.set_yticklabels([f"ROI {i:03d}" for i in top20], fontsize=8, color="#ccc")
|
| 197 |
+
ax2.set_xlabel("Cumulative gradient magnitude", color="#777", fontsize=9)
|
| 198 |
+
ax2.set_title("Top-20 ROIs by Prediction Influence", color="#ccc", fontsize=11, pad=10)
|
| 199 |
+
ax2.tick_params(colors="#555", labelsize=8)
|
| 200 |
+
ax2.invert_yaxis()
|
| 201 |
+
for spine in ["top", "right"]:
|
| 202 |
+
ax2.spines[spine].set_visible(False)
|
| 203 |
+
for spine in ["bottom", "left"]:
|
| 204 |
+
ax2.spines[spine].set_color("#333")
|
| 205 |
+
|
| 206 |
+
fig.suptitle(
|
| 207 |
+
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · Ensemble of {len(_models)} LOSO models",
|
| 208 |
+
color="#888", fontsize=10, y=1.02,
|
| 209 |
+
)
|
| 210 |
+
plt.tight_layout()
|
| 211 |
+
buf = io.BytesIO()
|
| 212 |
+
plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0d0d0d")
|
| 213 |
+
plt.close(fig)
|
| 214 |
+
buf.seek(0)
|
| 215 |
+
return Image.open(buf).copy()
|
| 216 |
+
|
| 217 |
+
# ── inference ──────────────────────────────────────────────────────────────
|
| 218 |
+
|
| 219 |
+
def run_gcn(file_path: str | None):
|
| 220 |
+
if file_path is None:
|
| 221 |
+
return "", "", "", None, ""
|
| 222 |
+
|
| 223 |
+
path = Path(file_path)
|
| 224 |
+
try:
|
| 225 |
+
if path.suffix == ".npz":
|
| 226 |
+
d = np.load(path, allow_pickle=True)
|
| 227 |
+
fc = d["mean_fc"].astype(np.float32)
|
| 228 |
+
fc = np.arctanh(np.clip(fc, -0.9999, 0.9999))
|
| 229 |
+
adj = np.where(np.abs(fc) >= _FC_THRESHOLD, fc, 0.0).astype(np.float32)
|
| 230 |
+
bw = d["bold_windows"].astype(np.float32)
|
| 231 |
+
if len(bw) >= _MAX_WINDOWS:
|
| 232 |
+
bw = bw[:_MAX_WINDOWS]
|
| 233 |
+
else:
|
| 234 |
+
bw = np.concatenate([bw, np.repeat(bw[-1:], _MAX_WINDOWS - len(bw), 0)])
|
| 235 |
+
bw_t = torch.FloatTensor(bw).unsqueeze(0)
|
| 236 |
+
adj_t = torch.FloatTensor(adj).unsqueeze(0)
|
| 237 |
+
else:
|
| 238 |
+
bold = np.loadtxt(path, dtype=np.float32)
|
| 239 |
+
if bold.ndim != 2 or bold.shape[1] != 200:
|
| 240 |
+
return f"⚠️ Error: expected (T×200) array, got {bold.shape}", "", "", None, ""
|
| 241 |
+
bw_t, adj_t = preprocess(bold)
|
| 242 |
+
except Exception as e:
|
| 243 |
+
return f"⚠️ Error loading file: {e}", "", "", None, ""
|
| 244 |
+
|
| 245 |
+
models = get_models()
|
| 246 |
+
|
| 247 |
+
per_model = []
|
| 248 |
+
with torch.no_grad():
|
| 249 |
+
for site, task in models:
|
| 250 |
+
logits = task(bw_t, adj_t)
|
| 251 |
+
p = torch.softmax(logits, -1)[0, 1].item()
|
| 252 |
+
per_model.append((site, p))
|
| 253 |
+
|
| 254 |
+
p_mean = float(np.mean([p for _, p in per_model]))
|
| 255 |
+
consensus = sum(1 for _, p in per_model if p > 0.5)
|
| 256 |
+
conf = max(p_mean, 1 - p_mean) * 100
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
sal = _compute_saliency(bw_t, adj_t, models)
|
| 260 |
+
sal_img = _saliency_figure(sal, p_mean)
|
| 261 |
+
except Exception:
|
| 262 |
+
sal_img = None
|
| 263 |
+
|
| 264 |
+
# Verdict
|
| 265 |
+
if p_mean > 0.6:
|
| 266 |
+
verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #e63946;padding:24px 28px;border-radius:12px;margin-bottom:8px">
|
| 267 |
+
<div style="font-size:2rem;font-weight:800;color:#e63946;letter-spacing:1px">ASD INDICATED</div>
|
| 268 |
+
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | <b style="color:white">{consensus}/4</b> site-blind models agree</div>
|
| 269 |
+
</div>"""
|
| 270 |
+
elif p_mean < 0.4:
|
| 271 |
+
verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #2dc653;padding:24px 28px;border-radius:12px;margin-bottom:8px">
|
| 272 |
+
<div style="font-size:2rem;font-weight:800;color:#2dc653;letter-spacing:1px">TYPICAL CONTROL</div>
|
| 273 |
+
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | <b style="color:white">{4-consensus}/4</b> site-blind models agree</div>
|
| 274 |
+
</div>"""
|
| 275 |
+
else:
|
| 276 |
+
verdict = f"""<div style="background:#1a1a2e;border-left:6px solid #f4a261;padding:24px 28px;border-radius:12px;margin-bottom:8px">
|
| 277 |
+
<div style="font-size:2rem;font-weight:800;color:#f4a261;letter-spacing:1px">INCONCLUSIVE</div>
|
| 278 |
+
<div style="font-size:1.1rem;color:#aaa;margin-top:6px">Confidence: <b style="color:white">{conf:.1f}%</b> | p(ASD) = <b style="color:white">{p_mean:.3f}</b> | Model disagreement — clinical review required</div>
|
| 279 |
+
</div>"""
|
| 280 |
+
|
| 281 |
+
# Ensemble breakdown
|
| 282 |
+
rows = ""
|
| 283 |
+
for site, p in per_model:
|
| 284 |
+
lbl = "ASD" if p > 0.5 else "TC"
|
| 285 |
+
color = "#e63946" if p > 0.5 else "#2dc653"
|
| 286 |
+
bar_w = int(p * 100)
|
| 287 |
+
rows += f"""<tr>
|
| 288 |
+
<td style="padding:8px 12px;color:#ccc;font-weight:600">{site}-blind</td>
|
| 289 |
+
<td style="padding:8px 12px"><div style="background:#333;border-radius:4px;height:18px;width:160px">
|
| 290 |
+
<div style="background:{color};height:18px;width:{bar_w}%;border-radius:4px;opacity:0.85"></div></div></td>
|
| 291 |
+
<td style="padding:8px 12px;color:{color};font-weight:700">{lbl}</td>
|
| 292 |
+
<td style="padding:8px 12px;color:#888">p={p:.3f}</td>
|
| 293 |
+
</tr>"""
|
| 294 |
+
|
| 295 |
+
ensemble = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
|
| 296 |
+
<div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Leave-One-Site-Out Ensemble — each model never trained on that site's data</div>
|
| 297 |
+
<table style="width:100%;border-collapse:collapse">{rows}</table>
|
| 298 |
+
<div style="margin-top:14px;color:#666;font-size:0.82rem">Cross-site consensus: {consensus}/4 models agree · LOSO AUC = 0.7872 across 529 held-out subjects</div>
|
| 299 |
+
</div>"""
|
| 300 |
+
|
| 301 |
+
# LLM clinical report
|
| 302 |
+
llm_text = _llm_report(p_mean, per_model)
|
| 303 |
+
report = f"""<div style="background:#111;border-radius:10px;padding:20px;margin-top:4px">
|
| 304 |
+
<div style="color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin-bottom:14px">Clinical Report — Qwen2.5-7B fine-tuned on AMD Instinct MI300X</div>
|
| 305 |
+
<div style="color:#ddd;font-size:0.95rem;line-height:1.7;white-space:pre-wrap">{llm_text}</div>
|
| 306 |
+
<div style="background:#1a1a1a;border-radius:6px;padding:12px;color:#555;font-size:0.78rem;margin-top:16px">
|
| 307 |
+
⚕️ AI-assisted analysis only. Does not constitute a diagnosis. Integrate with clinical history, behavioral assessment, and standardized instruments (ADOS-2, ADI-R).
|
| 308 |
+
</div></div>"""
|
| 309 |
+
|
| 310 |
+
return verdict, ensemble, report, sal_img
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# ── UI ─────────────────────────────────────────────────────────────────────
|
| 314 |
+
|
| 315 |
+
css = """
|
| 316 |
+
body { background: #0d0d0d; }
|
| 317 |
+
.gradio-container { max-width: 960px; margin: auto; }
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
with gr.Blocks(title="BrainConnect-ASD", css=css, theme=gr.themes.Base()) as demo:
|
| 321 |
+
gr.HTML("""
|
| 322 |
+
<div style="text-align:center;padding:32px 0 16px">
|
| 323 |
+
<div style="font-size:2.2rem;font-weight:900;color:white;letter-spacing:-1px">BrainConnect<span style="color:#e63946">-ASD</span></div>
|
| 324 |
+
<div style="color:#888;font-size:1rem;margin-top:8px">Scanner-site-invariant ASD detection from resting-state fMRI</div>
|
| 325 |
+
<div style="display:flex;justify-content:center;gap:24px;margin-top:16px;flex-wrap:wrap">
|
| 326 |
+
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">LOSO AUC 0.7872</span>
|
| 327 |
+
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">529 held-out subjects</span>
|
| 328 |
+
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">4 independent institutions</span>
|
| 329 |
+
<span style="background:#1a1a2e;color:#aaa;padding:6px 14px;border-radius:20px;font-size:0.85rem">AMD Instinct MI300X</span>
|
| 330 |
+
</div>
|
| 331 |
+
</div>
|
| 332 |
+
""")
|
| 333 |
+
|
| 334 |
+
file_input = gr.File(label="Upload CC200 fMRI file (.1D or .npz)", type="filepath")
|
| 335 |
+
verdict_html = gr.HTML()
|
| 336 |
+
ensemble_html = gr.HTML()
|
| 337 |
+
|
| 338 |
+
gr.HTML("<div style='color:#888;font-size:0.8rem;text-transform:uppercase;letter-spacing:2px;margin:24px 0 8px'>Gradient Saliency — which brain connections drove this prediction</div>")
|
| 339 |
+
saliency_img = gr.Image(label="FC Edge Saliency & ROI Importance", type="pil")
|
| 340 |
+
|
| 341 |
+
report_html = gr.HTML()
|
| 342 |
+
|
| 343 |
+
file_input.change(
|
| 344 |
+
fn=run_gcn,
|
| 345 |
+
inputs=file_input,
|
| 346 |
+
outputs=[verdict_html, ensemble_html, report_html, saliency_img],
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
gr.HTML("""
|
| 350 |
+
<div style="text-align:center;padding:24px 0;color:#444;font-size:0.8rem">
|
| 351 |
+
Adversarial Brain-Mode GCN (k=16) · Qwen2.5-7B LoRA (AMD MI300X) · ABIDE I ·
|
| 352 |
+
<a href="https://github.com/Yatsuiii/Brain-Connectivity-GCN" style="color:#666">GitHub</a>
|
| 353 |
+
</div>
|
| 354 |
+
""")
|
| 355 |
+
|
| 356 |
+
print("Preloading GCN models...")
|
| 357 |
+
get_models()
|
| 358 |
+
print("Preloading LLM...")
|
| 359 |
+
get_llm()
|
| 360 |
+
print("All models ready.")
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
|
| 363 |
+
demo.launch()
|
checkpoints/aal_nyu.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d588597092f4b9483fdcfb0f8d5aae67030ea9d17df8cf7a0c027ef527b5657
|
| 3 |
+
size 253386
|
checkpoints/aal_ucla.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3af986f0e5909e8551c2473ee3ff876ff9cc814ea38e220de33fad96201eaa37
|
| 3 |
+
size 253386
|
checkpoints/aal_um.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b380df053b6a9a83dc6062070418cd6b8718b4f4c5c5f8b388330dcf0a9abf5
|
| 3 |
+
size 253386
|
checkpoints/aal_usm.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7410692b93237e59fe5a03726338b7f02f0b61acbb6185dc1ecc79dfbee6bdac
|
| 3 |
+
size 252813
|