Yatsuiii committed on
Commit
786af16
·
verified ·
1 Parent(s): 8e0b796

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +62 -44
app.py CHANGED
@@ -11,7 +11,7 @@ import numpy as np
11
  import torch
12
  import gradio as gr
13
 
14
- from _charts import VAL_B64, AUC_B64, AMD_BENCH_B64
15
 
16
  _WINDOW_LEN = 50
17
  _STEP = 3
@@ -378,24 +378,27 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
378
  for s, e in zip(_nb[:-1], _nb[1:])
379
  ])
380
 
381
- fig = plt.figure(figsize=(18, 5.5))
382
  fig.patch.set_facecolor("#0e1015")
 
 
 
383
  axes = [
384
- fig.add_subplot(1, 3, 1),
385
- fig.add_subplot(1, 3, 2),
386
- fig.add_subplot(1, 3, 3, projection="3d"),
387
  ]
388
 
389
  # ── Left: 7×7 network heatmap ──────────────────────────────────────────
390
  ax = axes[0]
391
  ax.set_facecolor("#161922")
392
  im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
393
- ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=11, pad=14, fontweight="bold")
394
 
395
  ax.set_xticks(range(n_nets))
396
  ax.set_yticks(range(n_nets))
397
- ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=9, color="#ccc")
398
- ax.set_yticklabels(_nn, fontsize=9, color="#ccc")
399
  ax.tick_params(colors="#555", length=0)
400
  for sp in ax.spines.values():
401
  sp.set_color("#222")
@@ -421,7 +424,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
421
  for j in range(n_nets):
422
  txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#666"
423
  ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
424
- fontsize=6.5, color=txt_color, zorder=3)
425
  if (i, j) in top5_cells:
426
  rect = plt.Rectangle((j - 0.48, i - 0.48), 0.96, 0.96,
427
  linewidth=1.8, edgecolor="#ffffff",
@@ -433,7 +436,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
433
  label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
434
  ax.annotate(label,
435
  xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
436
- fontsize=6, color="#fb923c", fontweight="600",
437
  arrowprops=dict(arrowstyle="-", color="#fb923c",
438
  lw=0.7, connectionstyle="arc3,rad=0.1"),
439
  ha="left", va="center", zorder=5)
@@ -441,7 +444,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
441
  cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
442
  cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
443
  plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
444
- cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=7.5)
445
 
446
  # ── Right: network importance bar chart ────────────────────────────────
447
  ax2 = axes[1]
@@ -452,9 +455,9 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
452
  bars = ax2.barh(range(n_nets), net_imp[order],
453
  color=[_nc[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
454
  ax2.set_yticks(range(n_nets))
455
- ax2.set_yticklabels([_nn[i] for i in order], fontsize=9.5, color="#ddd")
456
- ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=9)
457
- ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=11, pad=14, fontweight="bold")
458
  ax2.invert_yaxis()
459
  for sp in ["top", "right"]:
460
  ax2.spines[sp].set_visible(False)
@@ -465,14 +468,14 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
465
  x_max = net_imp.max()
466
  for bar, val in zip(bars, net_imp[order]):
467
  ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
468
- f"{val:.4f}", va="center", color="#555", fontsize=7.5)
469
 
470
  # ── 3D Brain Surface — top connections ────────────────────────────────────
471
  ax3 = axes[2]
472
  ax3.set_facecolor("#0e1015")
473
  ax3.grid(False)
474
  ax3.set_axis_off()
475
- ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=11, pad=4, fontweight="bold")
476
 
477
  # Transparent brain ellipsoid wireframe (MNI space approx)
478
  u = np.linspace(0, 2 * np.pi, 32)
@@ -489,7 +492,7 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
489
  size = 60 + imp_norm[k] * 260
490
  ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
491
  edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
492
- ax3.text(x, y, z + 7, name, fontsize=5.5, color=color,
493
  ha="center", va="bottom", fontweight="600", zorder=6)
494
 
495
  # Draw top-5 inter-network connections as lines, thickness ∝ saliency
@@ -508,11 +511,10 @@ def _saliency_figure(sal, p_mean, net_names=None, net_bounds=None, net_colors=No
508
 
509
  fig.suptitle(
510
  f"Gradient Saliency · p(ASD) = {p_mean:.3f} · 20-model LOSO ensemble · CC200 → Yeo-7 networks",
511
- color="#444", fontsize=8.5, y=1.02,
512
  )
513
- plt.tight_layout()
514
  buf = io.BytesIO()
515
- plt.savefig(buf, format="png", dpi=140, bbox_inches="tight", facecolor="#0e1015")
516
  plt.close(fig)
517
  buf.seek(0)
518
  return Image.open(buf).copy()
@@ -1008,35 +1010,50 @@ ARCHITECTURE = """
1008
  </div>
1009
  """
1010
 
1011
- AMD = f"""
1012
  <div>
1013
 
1014
- <!-- Benchmark chart first — the most impressive thing -->
1015
- <img src="data:image/png;base64,{AMD_BENCH_B64}" style="width:100%;border-radius:8px;margin-bottom:20px;border:1px solid #252a35"/>
1016
-
1017
- <!-- Two-column layout: stat grid left, pipeline right -->
1018
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:14px;margin-bottom:18px">
1019
-
1020
- <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
1021
- <div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:14px;font-weight:500">Hardware</div>
1022
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:14px">
1023
- <div><div style="font-size:1.5rem;font-weight:700;color:#fb923c;font-variant-numeric:tabular-nums">192<span style="font-size:0.75rem;color:#5e6675;font-weight:400"> GB</span></div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">HBM3 unified mem</div></div>
1024
- <div><div style="font-size:1.5rem;font-weight:700;color:#fb923c">bf16</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">Full precision</div></div>
1025
- <div><div style="font-size:1.5rem;font-weight:700;color:#fb923c">30×</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">Faster than CPU</div></div>
1026
- <div><div style="font-size:1.5rem;font-weight:700;color:#fb923c">94ms</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">Per subject</div></div>
1027
- </div>
 
 
1028
  </div>
 
 
 
 
 
 
1029
 
1030
- <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px">
1031
- <div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:14px;font-weight:500">LoRA Fine-Tune</div>
1032
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:14px">
1033
- <div><div style="font-size:1.5rem;font-weight:700;color:#f4f4f5">7B</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">Qwen2.5 params</div></div>
1034
- <div><div style="font-size:1.5rem;font-weight:700;color:#f4f4f5">r=16</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">LoRA rank</div></div>
1035
- <div><div style="font-size:1.5rem;font-weight:700;color:#f4f4f5">2K</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">Domain examples</div></div>
1036
- <div><div style="font-size:1.5rem;font-weight:700;color:#f4f4f5">3</div><div style="color:#8b95a7;font-size:0.68rem;margin-top:3px;text-transform:uppercase;letter-spacing:0.8px">Epochs</div></div>
 
 
 
 
 
 
 
 
1037
  </div>
1038
  </div>
1039
-
1040
  </div>
1041
 
1042
  <!-- Fine-tune spec table -->
@@ -1044,7 +1061,8 @@ AMD = f"""
1044
  <table style="width:100%;border-collapse:collapse;font-size:0.85rem">
1045
  <tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Base model</td><td style="padding:10px 16px;color:#cbd5e1">Qwen/Qwen2.5-7B-Instruct <span style="color:#5e6675">· AMD partner model · ROCm native</span></td></tr>
1046
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Method</td><td style="padding:10px 16px;color:#cbd5e1">LoRA r=16 α=32 · q, k, v, o, gate, up, down projections · bf16 — no quantization needed</td></tr>
1047
- <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Training task</td><td style="padding:10px 16px;color:#cbd5e1">GCN ensemble output structured clinical referral letter with ICD-10 codes</td></tr>
 
1048
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Why MI300X?</td><td style="padding:10px 16px;color:#cbd5e1">192 GB unified HBM3 fits the full 7B model in bf16 without sharding — impossible on consumer GPUs. ROCm enables native PyTorch training with zero code changes.</td></tr>
1049
  </table>
1050
  </div>
 
11
  import torch
12
  import gradio as gr
13
 
14
+ from _charts import VAL_B64, AUC_B64
15
 
16
  _WINDOW_LEN = 50
17
  _STEP = 3
 
378
  for s, e in zip(_nb[:-1], _nb[1:])
379
  ])
380
 
381
+ fig = plt.figure(figsize=(20, 18))
382
  fig.patch.set_facecolor("#0e1015")
383
+ import matplotlib.gridspec as gridspec
384
+ gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.38, wspace=0.32,
385
+ height_ratios=[1.1, 1.0])
386
  axes = [
387
+ fig.add_subplot(gs[0, 0]), # heatmap
388
+ fig.add_subplot(gs[0, 1]), # bar chart
389
+ fig.add_subplot(gs[1, :], projection="3d"), # 3D brain — full bottom row
390
  ]
391
 
392
  # ── Left: 7×7 network heatmap ──────────────────────────────────────────
393
  ax = axes[0]
394
  ax.set_facecolor("#161922")
395
  im = ax.imshow(net_sal, cmap="inferno", aspect="auto", interpolation="nearest")
396
+ ax.set_title("FC Saliency by Brain Network", color="#bbb", fontsize=14, pad=16, fontweight="bold")
397
 
398
  ax.set_xticks(range(n_nets))
399
  ax.set_yticks(range(n_nets))
400
+ ax.set_xticklabels(_nn, rotation=40, ha="right", fontsize=12, color="#ccc")
401
+ ax.set_yticklabels(_nn, fontsize=12, color="#ccc")
402
  ax.tick_params(colors="#555", length=0)
403
  for sp in ax.spines.values():
404
  sp.set_color("#222")
 
424
  for j in range(n_nets):
425
  txt_color = "#111" if net_sal[i, j] > 0.6 * vmax else "#666"
426
  ax.text(j, i, f"{net_sal[i, j]:.3f}", ha="center", va="center",
427
+ fontsize=9, color=txt_color, zorder=3)
428
  if (i, j) in top5_cells:
429
  rect = plt.Rectangle((j - 0.48, i - 0.48), 0.96, 0.96,
430
  linewidth=1.8, edgecolor="#ffffff",
 
436
  label = f"#{rank+1} {_nn[i]}↔{_nn[j]}"
437
  ax.annotate(label,
438
  xy=(j, i), xytext=(n_nets - 0.3, rank * 0.85 - 0.3),
439
+ fontsize=8.5, color="#fb923c", fontweight="600",
440
  arrowprops=dict(arrowstyle="-", color="#fb923c",
441
  lw=0.7, connectionstyle="arc3,rad=0.1"),
442
  ha="left", va="center", zorder=5)
 
444
  cb = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
445
  cb.ax.yaxis.set_tick_params(color="#444", labelsize=7)
446
  plt.setp(cb.ax.yaxis.get_ticklabels(), color="#555")
447
+ cb.set_label("Mean |∂p(ASD)/∂FC|", color="#444", fontsize=10)
448
 
449
  # ── Right: network importance bar chart ────────────────────────────────
450
  ax2 = axes[1]
 
455
  bars = ax2.barh(range(n_nets), net_imp[order],
456
  color=[_nc[i] for i in order], alpha=0.88, edgecolor="none", height=0.65)
457
  ax2.set_yticks(range(n_nets))
458
+ ax2.set_yticklabels([_nn[i] for i in order], fontsize=12, color="#ddd")
459
+ ax2.set_xlabel("Mean gradient magnitude", color="#555", fontsize=11)
460
+ ax2.set_title("Network Importance for This Prediction", color="#bbb", fontsize=14, pad=16, fontweight="bold")
461
  ax2.invert_yaxis()
462
  for sp in ["top", "right"]:
463
  ax2.spines[sp].set_visible(False)
 
468
  x_max = net_imp.max()
469
  for bar, val in zip(bars, net_imp[order]):
470
  ax2.text(val + x_max * 0.015, bar.get_y() + bar.get_height() / 2,
471
+ f"{val:.4f}", va="center", color="#555", fontsize=10)
472
 
473
  # ── 3D Brain Surface — top connections ────────────────────────────────────
474
  ax3 = axes[2]
475
  ax3.set_facecolor("#0e1015")
476
  ax3.grid(False)
477
  ax3.set_axis_off()
478
+ ax3.set_title("Top Connections · 3D Brain", color="#bbb", fontsize=14, pad=8, fontweight="bold")
479
 
480
  # Transparent brain ellipsoid wireframe (MNI space approx)
481
  u = np.linspace(0, 2 * np.pi, 32)
 
492
  size = 60 + imp_norm[k] * 260
493
  ax3.scatter([x], [y], [z], c=color, s=size, zorder=5,
494
  edgecolors="#ffffff", linewidths=0.5, alpha=0.92)
495
+ ax3.text(x, y, z + 7, name, fontsize=8, color=color,
496
  ha="center", va="bottom", fontweight="600", zorder=6)
497
 
498
  # Draw top-5 inter-network connections as lines, thickness ∝ saliency
 
511
 
512
  fig.suptitle(
513
  f"Gradient Saliency · p(ASD) = {p_mean:.3f} · 20-model LOSO ensemble · CC200 → Yeo-7 networks",
514
+ color="#888", fontsize=12, y=1.01,
515
  )
 
516
  buf = io.BytesIO()
517
+ plt.savefig(buf, format="png", dpi=120, bbox_inches="tight", facecolor="#0e1015")
518
  plt.close(fig)
519
  buf.seek(0)
520
  return Image.open(buf).copy()
 
1010
  </div>
1011
  """
1012
 
1013
+ AMD = """
1014
  <div>
1015
 
1016
+ <!-- Stat grid: real benchmarked numbers -->
1017
+ <div style="display:grid;grid-template-columns:repeat(4,1fr);gap:14px;margin-bottom:18px">
1018
+ <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
1019
+ <div style="font-size:1.8rem;font-weight:700;color:#fb923c;font-variant-numeric:tabular-nums">~20<span style="font-size:0.8rem;color:#5e6675;font-weight:400"> ms</span></div>
1020
+ <div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">End-to-end per patient</div>
1021
+ <div style="color:#5e6675;font-size:0.64rem;margin-top:3px">preprocess + 20-model ensemble</div>
1022
+ </div>
1023
+ <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
1024
+ <div style="font-size:1.8rem;font-weight:700;color:#fb923c">192<span style="font-size:0.8rem;color:#5e6675;font-weight:400"> GB</span></div>
1025
+ <div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">HBM3 unified memory</div>
1026
+ <div style="color:#5e6675;font-size:0.64rem;margin-top:3px">7B model fits in bf16, no sharding</div>
1027
+ </div>
1028
+ <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
1029
+ <div style="font-size:1.8rem;font-weight:700;color:#fb923c">60</div>
1030
+ <div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">Models trained on MI300X</div>
1031
+ <div style="color:#5e6675;font-size:0.64rem;margin-top:3px">20 folds × 3 atlases · 150 epochs each</div>
1032
  </div>
1033
+ <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 16px;text-align:center">
1034
+ <div style="font-size:1.8rem;font-weight:700;color:#fb923c">ROCm<span style="font-size:0.8rem;color:#5e6675;font-weight:400"> 7.0</span></div>
1035
+ <div style="color:#8b95a7;font-size:0.67rem;margin-top:5px;text-transform:uppercase;letter-spacing:0.8px">Native PyTorch stack</div>
1036
+ <div style="color:#5e6675;font-size:0.64rem;margin-top:3px">zero code changes from CUDA</div>
1037
+ </div>
1038
+ </div>
1039
 
1040
+ <!-- How AMD is used -->
1041
+ <div style="background:#161922;border:1px solid #252a35;border-radius:8px;padding:18px 20px;margin-bottom:14px">
1042
+ <div style="color:#8b95a7;font-size:0.68rem;text-transform:uppercase;letter-spacing:1.5px;margin-bottom:14px;font-weight:500">AMD MI300X Usage in This Project</div>
1043
+ <div style="display:grid;grid-template-columns:1fr 1fr 1fr;gap:16px">
1044
+ <div>
1045
+ <div style="color:#fb923c;font-size:0.78rem;font-weight:600;margin-bottom:6px">① GCN Training</div>
1046
+ <div style="color:#cbd5e1;font-size:0.82rem;line-height:1.55">60 site-holdout models trained via PyTorch Lightning on ROCm 7.0. K=32 modes, 150 epochs, 3 atlases (CC200, AAL, HO).</div>
1047
+ </div>
1048
+ <div>
1049
+ <div style="color:#fb923c;font-size:0.78rem;font-weight:600;margin-bottom:6px">② LLM Fine-Tuning</div>
1050
+ <div style="color:#cbd5e1;font-size:0.82rem;line-height:1.55">Qwen2.5-7B fine-tuned with LoRA (r=16, bf16) on 2K domain examples. 192GB HBM3 fits the full model without sharding.</div>
1051
+ </div>
1052
+ <div>
1053
+ <div style="color:#fb923c;font-size:0.78rem;font-weight:600;margin-bottom:6px">③ Live LLM Inference</div>
1054
+ <div style="color:#cbd5e1;font-size:0.82rem;line-height:1.55">Clinical reports generated in real-time via vLLM on the MI300X droplet. Every report you see is live AMD inference.</div>
1055
  </div>
1056
  </div>
 
1057
  </div>
1058
 
1059
  <!-- Fine-tune spec table -->
 
1061
  <table style="width:100%;border-collapse:collapse;font-size:0.85rem">
1062
  <tr><td style="padding:10px 16px;color:#8b95a7;width:150px;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Base model</td><td style="padding:10px 16px;color:#cbd5e1">Qwen/Qwen2.5-7B-Instruct <span style="color:#5e6675">· AMD partner model · ROCm native</span></td></tr>
1063
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Method</td><td style="padding:10px 16px;color:#cbd5e1">LoRA r=16 α=32 · q, k, v, o, gate, up, down projections · bf16 — no quantization needed</td></tr>
1064
+ <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">GCN inference</td><td style="padding:10px 16px;color:#cbd5e1">~20ms end-to-end per patient · benchmarked on AMD MI300X · ROCm 7.0 · PyTorch 2.5.1</td></tr>
1065
+ <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">LLM serving</td><td style="padding:10px 16px;color:#cbd5e1">vLLM on AMD MI300X · OpenAI-compatible API · live inference for every clinical report</td></tr>
1066
  <tr style="border-top:1px solid #252a35"><td style="padding:10px 16px;color:#8b95a7;font-size:0.76rem;text-transform:uppercase;letter-spacing:0.5px">Why MI300X?</td><td style="padding:10px 16px;color:#cbd5e1">192 GB unified HBM3 fits the full 7B model in bf16 without sharding — impossible on consumer GPUs. ROCm enables native PyTorch training with zero code changes.</td></tr>
1067
  </table>
1068
  </div>