Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
| 11 |
import torch
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
-
from _charts import VAL_B64, AUC_B64
|
| 15 |
|
| 16 |
_WINDOW_LEN = 50
|
| 17 |
_STEP = 3
|
|
@@ -378,24 +378,27 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 378 |
for s, e in zip(_nb[:-1], _nb[1:])
|
| 379 |
])
|
| 380 |
|
| 381 |
-
fig = plt.figure(figsize=(
|
| 382 |
fig.patch.set_facecolor("#0e1015")
|
|
|
|
|
|
|
|
|
|
| 383 |
axes = [
|
| 384 |
-
fig.add_subplot(
|
| 385 |
-
fig.add_subplot(
|
| 386 |
-
fig.add_subplot(1,
|
| 387 |
]
|
| 388 |
|
| 389 |
# ── Left: 7×7 network heatmap ──────────────────────────────────────────
|
| 390 |
ax = axes[0]
|
| 391 |
ax.set_facecolor("#161922")
|
| 392 |
im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
|
| 393 |
-
ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=
|
| 394 |
|
| 395 |
ax.set_xticks(range(n_nets))
|
| 396 |
ax.set_yticks(range(n_nets))
|
| 397 |
-
ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=
|
| 398 |
-
ax.set_yticklabels(_nn, fontsize=
|
| 399 |
ax.tick_params(colors="#555", length=0)
|
| 400 |
for sp in ax.spines.values():
|
| 401 |
sp.set_color("#222")
|
|
@@ -421,7 +424,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 421 |
for j in range(n_nets):
|
| 422 |
txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#666"
|
| 423 |
ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
|
| 424 |
-
fontsize=
|
| 425 |
if (i, j) in top5_cells:
|
| 426 |
rect = plt.Rectangle((j - 0.48, i - 0.48), 0.96, 0.96,
|
| 427 |
linewidth=1.8, edgecolor="#ffffff",
|
|
@@ -433,7 +436,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 433 |
label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
|
| 434 |
ax.annotate(label,
|
| 435 |
xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
|
| 436 |
-
fontsize=
|
| 437 |
arrowprops=dict(arrowstyle="-", color="#fb923c",
|
| 438 |
lw=0.7, connectionstyle="arc3,rad=0.1"),
|
| 439 |
ha="left", va="center", zorder=5)
|
|
@@ -441,7 +444,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 441 |
cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 442 |
cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
|
| 443 |
plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
|
| 444 |
-
cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=
|
| 445 |
|
| 446 |
# ── Right: network importance bar chart ────────────────────────────────
|
| 447 |
ax2 = axes[1]
|
|
@@ -452,9 +455,9 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 452 |
bars = ax2.barh(range(n_nets), net_imp[order],
|
| 453 |
color=[_nc[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
|
| 454 |
ax2.set_yticks(range(n_nets))
|
| 455 |
-
ax2.set_yticklabels([_nn[i] for i in order], fontsize=
|
| 456 |
-
ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=
|
| 457 |
-
ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=
|
| 458 |
ax2.invert_yaxis()
|
| 459 |
for sp in ["top", "right"]:
|
| 460 |
ax2.spines[sp].set_visible(False)
|
|
@@ -465,14 +468,14 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 465 |
x_max = net_imp.max()
|
| 466 |
for bar, val in zip(bars, net_imp[order]):
|
| 467 |
ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
|
| 468 |
-
f"{val:.4f}", va="center", color="#555", fontsize=
|
| 469 |
|
| 470 |
# ── 3D Brain Surface — top connections ────────────────────────────────────
|
| 471 |
ax3 = axes[2]
|
| 472 |
ax3.set_facecolor("#0e1015")
|
| 473 |
ax3.grid(False)
|
| 474 |
ax3.set_axis_off()
|
| 475 |
-
ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=
|
| 476 |
|
| 477 |
# Transparent brain ellipsoid wireframe (MNI space approx)
|
| 478 |
u = np.linspace(0, 2 * np.pi, 32)
|
|
@@ -489,7 +492,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 489 |
size = 60 + imp_norm[k] * 260
|
| 490 |
ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
|
| 491 |
edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
|
| 492 |
-
ax3.text(x, y, z + 7, name, fontsize=
|
| 493 |
ha="center", va="bottom", fontweight="600", zorder=6)
|
| 494 |
|
| 495 |
# Draw top-5 inter-network connections as lines, thickness ∝ saliency
|
|
@@ -508,11 +511,10 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
|
|
| 508 |
|
| 509 |
fig.suptitle(
|
| 510 |
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · 20-model LOSO ensemble · CC200 → Yeo-7 networks",
|
| 511 |
-
color="#
|
| 512 |
)
|
| 513 |
-
plt.tight_layout()
|
| 514 |
buf = io.BytesIO()
|
| 515 |
-
plt.savefig(buf, format="png", dpi=
|
| 516 |
plt.close(fig)
|
| 517 |
buf.seek(0)
|
| 518 |
return Image.open(buf).copy()
|
|
@@ -1008,35 +1010,50 @@ ARCHITECTURE = """
|
|
| 1008 |
</div>
|
| 1009 |
"""
|
| 1010 |
|
| 1011 |
-
AMD =
|
| 1012 |
<div>
|
| 1013 |
|
| 1014 |
-
<!--
|
| 1015 |
-
<
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
<div
|
| 1021 |
-
|
| 1022 |
-
<div style="
|
| 1023 |
-
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
</div>
|
|
|
|
|
|
|
| 1028 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1029 |
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
<div
|
| 1036 |
-
<div
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1037 |
</div>
|
| 1038 |
</div>
|
| 1039 |
-
|
| 1040 |
</div>
|
| 1041 |
|
| 1042 |
<!-- Fine-tune spec table -->
|
|
@@ -1044,7 +1061,8 @@ AMD = f"""
|
|
| 1044 |
<table style="width:100%;border-collapse:collapse;font-size:0.85rem">
|
| 1045 |
<tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Base model</td><td style="padding:10px 16px;color:#cbd5e1">Qwen/Qwen2.5-7B-Instruct <span style="color:#5e6675">· AMD partner model · ROCm native</span></td></tr>
|
| 1046 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Method</td><td style="padding:10px 16px;color:#cbd5e1">LoRA r=16 α=32 · q, k, v, o, gate, up, down projections · bf16 — no quantization needed</td></tr>
|
| 1047 |
-
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">
|
|
|
|
| 1048 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Why MI300X?</td><td style="padding:10px 16px;color:#cbd5e1">192 GB unified HBM3 fits the full 7B model in bf16 without sharding — impossible on consumer GPUs. ROCm enables native PyTorch training with zero code changes.</td></tr>
|
| 1049 |
</table>
|
| 1050 |
</div>
|
|
|
|
| 11 |
import torch
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
+
from _charts import VAL_B64, AUC_B64
|
| 15 |
|
| 16 |
_WINDOW_LEN = 50
|
| 17 |
_STEP = 3
|
|
|
|
| 378 |
for s, e in zip(_nb[:-1], _nb[1:])
|
| 379 |
])
|
| 380 |
|
| 381 |
+
fig = plt.figure(figsize=(20, 18))
|
| 382 |
fig.patch.set_facecolor("#0e1015")
|
| 383 |
+
import matplotlib.gridspec as gridspec
|
| 384 |
+
gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.38, wspace=0.32,
|
| 385 |
+
height_ratios=[1.1, 1.0])
|
| 386 |
axes = [
|
| 387 |
+
fig.add_subplot(gs[0, 0]), # heatmap
|
| 388 |
+
fig.add_subplot(gs[0, 1]), # bar chart
|
| 389 |
+
fig.add_subplot(gs[1, :], projection="3d"), # 3D brain — full bottom row
|
| 390 |
]
|
| 391 |
|
| 392 |
# ── Left: 7×7 network heatmap ──────────────────────────────────────────
|
| 393 |
ax = axes[0]
|
| 394 |
ax.set_facecolor("#161922")
|
| 395 |
im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
|
| 396 |
+
ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=14, pad=16, fontweight="bold")
|
| 397 |
|
| 398 |
ax.set_xticks(range(n_nets))
|
| 399 |
ax.set_yticks(range(n_nets))
|
| 400 |
+
ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=12, color="#ccc")
|
| 401 |
+
ax.set_yticklabels(_nn, fontsize=12, color="#ccc")
|
| 402 |
ax.tick_params(colors="#555", length=0)
|
| 403 |
for sp in ax.spines.values():
|
| 404 |
sp.set_color("#222")
|
|
|
|
| 424 |
for j in range(n_nets):
|
| 425 |
txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#666"
|
| 426 |
ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
|
| 427 |
+
fontsize=9, color=txt_color, zorder=3)
|
| 428 |
if (i, j) in top5_cells:
|
| 429 |
rect = plt.Rectangle((j - 0.48, i - 0.48), 0.96, 0.96,
|
| 430 |
linewidth=1.8, edgecolor="#ffffff",
|
|
|
|
| 436 |
label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
|
| 437 |
ax.annotate(label,
|
| 438 |
xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
|
| 439 |
+
fontsize=8.5, color="#fb923c", fontweight="600",
|
| 440 |
arrowprops=dict(arrowstyle="-", color="#fb923c",
|
| 441 |
lw=0.7, connectionstyle="arc3,rad=0.1"),
|
| 442 |
ha="left", va="center", zorder=5)
|
|
|
|
| 444 |
cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 445 |
cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
|
| 446 |
plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
|
| 447 |
+
cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=10)
|
| 448 |
|
| 449 |
# ── Right: network importance bar chart ────────────────────────────────
|
| 450 |
ax2 = axes[1]
|
|
|
|
| 455 |
bars = ax2.barh(range(n_nets), net_imp[order],
|
| 456 |
color=[_nc[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
|
| 457 |
ax2.set_yticks(range(n_nets))
|
| 458 |
+
ax2.set_yticklabels([_nn[i] for i in order], fontsize=12, color="#ddd")
|
| 459 |
+
ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=11)
|
| 460 |
+
ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=14, pad=16, fontweight="bold")
|
| 461 |
ax2.invert_yaxis()
|
| 462 |
for sp in ["top", "right"]:
|
| 463 |
ax2.spines[sp].set_visible(False)
|
|
|
|
| 468 |
x_max = net_imp.max()
|
| 469 |
for bar, val in zip(bars, net_imp[order]):
|
| 470 |
ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
|
| 471 |
+
f"{val:.4f}", va="center", color="#555", fontsize=10)
|
| 472 |
|
| 473 |
# ── 3D Brain Surface — top connections ────────────────────────────────────
|
| 474 |
ax3 = axes[2]
|
| 475 |
ax3.set_facecolor("#0e1015")
|
| 476 |
ax3.grid(False)
|
| 477 |
ax3.set_axis_off()
|
| 478 |
+
ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=14, pad=8, fontweight="bold")
|
| 479 |
|
| 480 |
# Transparent brain ellipsoid wireframe (MNI space approx)
|
| 481 |
u = np.linspace(0, 2 * np.pi, 32)
|
|
|
|
| 492 |
size = 60 + imp_norm[k] * 260
|
| 493 |
ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
|
| 494 |
edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
|
| 495 |
+
ax3.text(x, y, z + 7, name, fontsize=8, color=color,
|
| 496 |
ha="center", va="bottom", fontweight="600", zorder=6)
|
| 497 |
|
| 498 |
# Draw top-5 inter-network connections as lines, thickness ∝ saliency
|
|
|
|
| 511 |
|
| 512 |
fig.suptitle(
|
| 513 |
f"Gradient Saliency · p(ASD) = {p_mean:.3f} · 20-model LOSO ensemble · CC200 → Yeo-7 networks",
|
| 514 |
+
color="#888", fontsize=12, y=1.01,
|
| 515 |
)
|
|
|
|
| 516 |
buf = io.BytesIO()
|
| 517 |
+
plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0e1015")
|
| 518 |
plt.close(fig)
|
| 519 |
buf.seek(0)
|
| 520 |
return Image.open(buf).copy()
|
|
|
|
| 1010 |
</div>
|
| 1011 |
"""
|
| 1012 |
|
| 1013 |
+
AMD = """
|
| 1014 |
<div>
|
| 1015 |
|
| 1016 |
+
<!-- Stat grid: real benchmarked numbers -->
|
| 1017 |
+
<div style="display:grid;grid-template-columns:repeat(4,1fr);gap:14px;margin-bottom:18px">
|
| 1018 |
+
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
|
| 1019 |
+
<div style="font-size:1.8rem;font-weight:700;color:#fb923c;font-variant-numeric:tabular-nums">~20<span style="font-size:0.8rem;color:#5e6675;font-weight:400"> ms</span></div>
|
| 1020 |
+
<div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">End-to-end per patient</div>
|
| 1021 |
+
<div style="color:#5e6675;font-size:0.64rem;margin-top:3px">preprocess + 20-model ensemble</div>
|
| 1022 |
+
</div>
|
| 1023 |
+
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
|
| 1024 |
+
<div style="font-size:1.8rem;font-weight:700;color:#fb923c">192<span style="font-size:0.8rem;color:#5e6675;font-weight:400"> GB</span></div>
|
| 1025 |
+
<div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">HBM3 unified memory</div>
|
| 1026 |
+
<div style="color:#5e6675;font-size:0.64rem;margin-top:3px">7B model fits in bf16, no sharding</div>
|
| 1027 |
+
</div>
|
| 1028 |
+
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
|
| 1029 |
+
<div style="font-size:1.8rem;font-weight:700;color:#fb923c">60</div>
|
| 1030 |
+
<div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">Models trained on MI300X</div>
|
| 1031 |
+
<div style="color:#5e6675;font-size:0.64rem;margin-top:3px">20 folds × 3 atlases · 150 epochs each</div>
|
| 1032 |
</div>
|
| 1033 |
+
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
|
| 1034 |
+
<div style="font-size:1.8rem;font-weight:700;color:#fb923c">ROCm<span style="font-size:0.8rem;color:#5e6675;font-weight:400"> 7.0</span></div>
|
| 1035 |
+
<div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">Native PyTorch stack</div>
|
| 1036 |
+
<div style="color:#5e6675;font-size:0.64rem;margin-top:3px">zero code changes from CUDA</div>
|
| 1037 |
+
</div>
|
| 1038 |
+
</div>
|
| 1039 |
|
| 1040 |
+
<!-- How AMD is used -->
|
| 1041 |
+
<div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px;margin-bottom:14px">
|
| 1042 |
+
<div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:14px;font-weight:500">AMD MI300X Usage in This Project</div>
|
| 1043 |
+
<div style="display:grid;grid-template-columns:1fr 1fr 1fr;gap:16px">
|
| 1044 |
+
<div>
|
| 1045 |
+
<div style="color:#fb923c;font-size:0.78rem;font-weight:600;margin-bottom:6px">① GCN Training</div>
|
| 1046 |
+
<div style="color:#cbd5e1;font-size:0.82rem;line-height:1.55">60 site-holdout models trained via PyTorch Lightning on ROCm 7.0. K=32 modes, 150 epochs, 3 atlases (CC200, AAL, HO).</div>
|
| 1047 |
+
</div>
|
| 1048 |
+
<div>
|
| 1049 |
+
<div style="color:#fb923c;font-size:0.78rem;font-weight:600;margin-bottom:6px">② LLM Fine-Tuning</div>
|
| 1050 |
+
<div style="color:#cbd5e1;font-size:0.82rem;line-height:1.55">Qwen2.5-7B fine-tuned with LoRA (r=16, bf16) on 2K domain examples. 192GB HBM3 fits the full model without sharding.</div>
|
| 1051 |
+
</div>
|
| 1052 |
+
<div>
|
| 1053 |
+
<div style="color:#fb923c;font-size:0.78rem;font-weight:600;margin-bottom:6px">③ Live LLM Inference</div>
|
| 1054 |
+
<div style="color:#cbd5e1;font-size:0.82rem;line-height:1.55">Clinical reports generated in real-time via vLLM on the MI300X droplet. Every report you see is live AMD inference.</div>
|
| 1055 |
</div>
|
| 1056 |
</div>
|
|
|
|
| 1057 |
</div>
|
| 1058 |
|
| 1059 |
<!-- Fine-tune spec table -->
|
|
|
|
| 1061 |
<table style="width:100%;border-collapse:collapse;font-size:0.85rem">
|
| 1062 |
<tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Base model</td><td style="padding:10px 16px;color:#cbd5e1">Qwen/Qwen2.5-7B-Instruct <span style="color:#5e6675">· AMD partner model · ROCm native</span></td></tr>
|
| 1063 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Method</td><td style="padding:10px 16px;color:#cbd5e1">LoRA r=16 α=32 · q, k, v, o, gate, up, down projections · bf16 — no quantization needed</td></tr>
|
| 1064 |
+
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">GCN inference</td><td style="padding:10px 16px;color:#cbd5e1">~20ms end-to-end per patient · benchmarked on AMD MI300X · ROCm 7.0 · PyTorch 2.5.1</td></tr>
|
| 1065 |
+
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">LLM serving</td><td style="padding:10px 16px;color:#cbd5e1">vLLM on AMD MI300X · OpenAI-compatible API · live inference for every clinical report</td></tr>
|
| 1066 |
<tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Why MI300X?</td><td style="padding:10px 16px;color:#cbd5e1">192 GB unified HBM3 fits the full 7B model in bf16 without sharding — impossible on consumer GPUs. ROCm enables native PyTorch training with zero code changes.</td></tr>
|
| 1067 |
</table>
|
| 1068 |
</div>
|