anugrahhu committed on
Commit
3080a66
·
verified ·
1 Parent(s): eb2a494

dashboard: synthesize PNGs on demand + cache-bust + pass --evidence_dir to vanilla

Browse files
Files changed (1) hide show
  1. space/training/app.py +263 -6
space/training/app.py CHANGED
@@ -14,19 +14,22 @@ work runs in a background thread so the HTTP server stays responsive.
14
 
15
  from __future__ import annotations
16
 
 
 
17
  import json
18
  import logging
19
  import os
 
20
  import subprocess
21
  import sys
22
  import threading
23
  import time
24
  from datetime import datetime, timezone
25
  from pathlib import Path
26
- from typing import Any, Dict, Optional
27
 
28
  from fastapi import FastAPI, HTTPException
29
- from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse
30
  from fastapi.staticfiles import StaticFiles
31
 
32
 
@@ -179,6 +182,9 @@ def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
179
  backend = str(config.get("training_backend", "vanilla")).lower()
180
  if backend == "vanilla":
181
  python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
 
 
 
182
  return [
183
  python_bin, "-m", "training.training_script",
184
  "--model_name", config["model_name"],
@@ -186,7 +192,10 @@ def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
186
  "--total_episodes", str(config["total_episodes"]),
187
  "--max_steps", str(config["max_steps"]),
188
  "--num_generations", str(config["num_generations"]),
 
 
189
  "--output_dir", config["output_dir"],
 
190
  ]
191
 
192
  if backend != "unsloth":
@@ -501,6 +510,232 @@ def _start_training(config: Dict[str, Any]) -> None:
501
  STATE.thread.start()
502
 
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  # ── FastAPI app ──────────────────────────────────────────────────────────
505
 
506
 
@@ -513,6 +748,7 @@ _HTML = """\
513
  <head>
514
  <meta charset=utf-8>
515
  <title>CERNenv Trainer</title>
 
516
  <style>
517
  body { font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto;
518
  max-width: 1000px; color:#111; padding: 0 1rem; line-height:1.5 }
@@ -710,13 +946,34 @@ def evidence_index() -> JSONResponse:
710
 
711
  @app.get("/evidence/{name}")
712
  def evidence_file(name: str):
713
- """Serve a single evidence artifact (PNG/CSV/JSON/MD) by filename."""
 
 
 
 
 
 
 
714
  if "/" in name or ".." in name:
715
  raise HTTPException(status_code=400, detail="invalid name")
716
  target = EVIDENCE_DIR / name
717
- if not target.exists() or not target.is_file():
718
- raise HTTPException(status_code=404, detail=f"{name} not found")
719
- return FileResponse(target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720
 
721
 
722
  @app.get("/logs", response_class=PlainTextResponse)
 
14
 
15
  from __future__ import annotations
16
 
17
+ import ast
18
+ import io
19
  import json
20
  import logging
21
  import os
22
+ import re
23
  import subprocess
24
  import sys
25
  import threading
26
  import time
27
  from datetime import datetime, timezone
28
  from pathlib import Path
29
+ from typing import Any, Dict, List, Optional
30
 
31
  from fastapi import FastAPI, HTTPException
32
+ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, Response
33
  from fastapi.staticfiles import StaticFiles
34
 
35
 
 
182
  backend = str(config.get("training_backend", "vanilla")).lower()
183
  if backend == "vanilla":
184
  python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
185
+ # vanilla now accepts --evidence_dir / --checkpoint_eval_* so the
186
+ # backported EvidenceCallback writes evidence/*.csv + plots into
187
+ # the same directory the dashboard serves from.
188
  return [
189
  python_bin, "-m", "training.training_script",
190
  "--model_name", config["model_name"],
 
192
  "--total_episodes", str(config["total_episodes"]),
193
  "--max_steps", str(config["max_steps"]),
194
  "--num_generations", str(config["num_generations"]),
195
+ "--checkpoint_eval_steps", str(config["checkpoint_eval_steps"]),
196
+ "--checkpoint_eval_episodes", str(config["checkpoint_eval_episodes"]),
197
  "--output_dir", config["output_dir"],
198
+ "--evidence_dir", config["evidence_dir"],
199
  ]
200
 
201
  if backend != "unsloth":
 
510
  STATE.thread.start()
511
 
512
 
513
+ # ── On-demand evidence-PNG synthesis ─────────────────────────────────────
514
+ #
515
+ # The vanilla GRPO backend (training/training_script.py) does not register
516
+ # an EvidenceCallback, so it never writes training_log.csv /
517
+ # reward_components.csv mid-run. The unsloth backend does, but a Space that
518
+ # happens to be running the vanilla path leaves those evidence cards empty
519
+ # until post-eval — and even then they stay empty because the underlying
520
+ # CSVs were never produced.
521
+ #
522
+ # To keep the dashboard live without restarting the in-flight run, we
523
+ # synthesise both PNGs on demand by parsing the TRL log dicts that the
524
+ # trainer prints to stdout (captured in training/runs/training.log by
525
+ # _stream_subprocess). The unsloth path still gets its richer
526
+ # component-level CSVs as before; this only kicks in when the file is
527
+ # missing on disk (file age is not compared against the captured log).
528
+
529
+ # Matches a tqdm progress line like " 53%|█████▎ | 190/360 [12:31<10:06,
530
+ # 3.57s/it]" emitted just before each TRL log dict, so we can attribute a
531
+ # dict to the correct global_step instead of guessing from logging_steps.
532
+ _TQDM_PROGRESS_RE = re.compile(r"\b(\d+)\s*/\s*(\d+)\s*\[")
533
+
534
+
535
+ def _parse_training_log_dicts(text: str) -> List[Dict[str, Any]]:
536
+ """Extract per-log-step rows from a captured TRL stdout log.
537
+
538
+ TRL prints a Python dict-repr on each ``logging_steps`` boundary.
539
+ We pair each dict with the most recent tqdm progress line so the
540
+ plotted x-axis reflects ``global_step`` rather than dict-arrival
541
+ order. Lines that do not parse cleanly are silently skipped.
542
+ """
543
+ rows: List[Dict[str, Any]] = []
544
+ last_step: Optional[int] = None
545
+ for raw in text.splitlines():
546
+ m = _TQDM_PROGRESS_RE.search(raw)
547
+ if m:
548
+ try:
549
+ last_step = int(m.group(1))
550
+ except ValueError:
551
+ pass
552
+ continue
553
+ s = raw.strip()
554
+ if not (s.startswith("{") and s.endswith("}")):
555
+ continue
556
+ if "'loss'" not in s and "'reward'" not in s and "'kl'" not in s:
557
+ continue
558
+ try:
559
+ d = ast.literal_eval(s)
560
+ except (ValueError, SyntaxError):
561
+ continue
562
+ if not isinstance(d, dict):
563
+ continue
564
+ reward = (
565
+ d.get("reward")
566
+ or d.get("rewards/mean")
567
+ or d.get("rewards/reward_fn/mean")
568
+ )
569
+ reward_std = (
570
+ d.get("reward_std")
571
+ or d.get("rewards/std")
572
+ or d.get("rewards/reward_fn/std")
573
+ )
574
+ rows.append({
575
+ "step": last_step if last_step is not None else len(rows),
576
+ "loss": d.get("loss"),
577
+ "reward": reward,
578
+ "reward_std": reward_std,
579
+ "kl": d.get("kl"),
580
+ "grad_norm": d.get("grad_norm"),
581
+ "learning_rate": d.get("learning_rate"),
582
+ "epoch": d.get("epoch"),
583
+ "frac_reward_zero_std": d.get("frac_reward_zero_std"),
584
+ "completions_mean_length": d.get("completions/mean_length"),
585
+ "completions_clipped_ratio": d.get("completions/clipped_ratio"),
586
+ })
587
+ return rows
588
+
589
+
590
def _try_matplotlib():
    """Return ``matplotlib.pyplot`` on the Agg backend, or None on failure.

    Any exception raised while importing or selecting the backend is
    logged as a warning and swallowed, because plotting is best-effort.
    """
    try:
        import matplotlib  # type: ignore

        matplotlib.use("Agg")
        from matplotlib import pyplot  # type: ignore
    except Exception as exc:  # pragma: no cover - plotting is best-effort
        logger.warning("matplotlib unavailable: %s", exc)
        return None
    return pyplot
599
+
600
+
601
+ def _png_bytes(fig) -> bytes:
602
+ buf = io.BytesIO()
603
+ fig.savefig(buf, format="png", dpi=140)
604
+ return buf.getvalue()
605
+
606
+
607
def _read_log_text() -> Optional[str]:
    """Return the contents of ``LOG_FILE``, or None if it cannot be read.

    Undecodable bytes are replaced rather than raising. A missing file
    is reported the same way as any other OS-level read failure.
    """
    try:
        # FileNotFoundError is an OSError, so a missing log also yields None.
        contents = LOG_FILE.read_text(errors="replace")
    except OSError:
        return None
    return contents
614
+
615
+
616
def _synth_training_curve_png() -> Optional[bytes]:
    """Render a 2-panel reward/loss curve from the captured TRL stdout log.

    Returns:
        PNG bytes, or None when the log is missing/empty, contains no
        parseable log dicts, or matplotlib is unavailable.
    """
    text = _read_log_text()
    if not text:
        return None
    rows = _parse_training_log_dicts(text)
    if not rows:
        return None
    plt = _try_matplotlib()
    if plt is None:
        return None

    rewards = [(r["step"], r["reward"]) for r in rows if r["reward"] is not None]
    losses = [(r["step"], r["loss"]) for r in rows if r["loss"] is not None]

    fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
    # pyplot keeps every open figure registered until it is closed, so the
    # close must also run when plotting itself raises, not only savefig.
    try:
        # NOTE(review): the nesting of the label/title/grid calls was
        # reconstructed from a flattened listing — confirm they are meant
        # to run even when a series is empty.
        if rewards:
            axes[0].plot([x for x, _ in rewards], [y for _, y in rewards],
                         lw=1.6, color="#1d4ed8")
        axes[0].set_ylabel("mean reward")
        axes[0].set_title(
            "CERNenv GRPO training — reward over steps "
            f"(synthesised from {len(rewards)} log events)"
        )
        axes[0].grid(alpha=0.25)
        if losses:
            axes[1].plot([x for x, _ in losses], [y for _, y in losses],
                         lw=1.6, color="#c026d3")
        axes[1].set_ylabel("GRPO loss")
        axes[1].set_xlabel("training step")
        axes[1].grid(alpha=0.25)
        fig.tight_layout()
        return _png_bytes(fig)
    finally:
        plt.close(fig)
653
+
654
+
655
def _synth_reward_components_png() -> Optional[bytes]:
    """Best-effort reward-components view derived from TRL stdout.

    The parsed log rows only carry aggregate reward, so the top panel
    shows reward mean ± std (a missing std is drawn as zero width) and
    the bottom panel shows KL, with the zero-std fraction and the
    completion mean length (scaled by its maximum) on a secondary axis.

    Returns:
        PNG bytes, or None when the log is missing/empty, contains no
        parseable log dicts, or matplotlib is unavailable.
    """
    text = _read_log_text()
    if not text:
        return None
    rows = _parse_training_log_dicts(text)
    if not rows:
        return None
    plt = _try_matplotlib()
    if plt is None:
        return None

    steps = [r["step"] for r in rows]
    rmean = [r.get("reward") for r in rows]
    rstd = [r.get("reward_std") for r in rows]
    kls = [r.get("kl") for r in rows]
    fzero = [r.get("frac_reward_zero_std") for r in rows]
    clen = [r.get("completions_mean_length") for r in rows]

    fig, axes = plt.subplots(2, 1, figsize=(8, 6.5), sharex=True)
    # pyplot keeps every open figure registered until it is closed, so the
    # close must also run when plotting itself raises, not only savefig.
    try:
        # NOTE(review): the nesting of the axis-decoration calls (labels,
        # title, grid, set_ylim) was reconstructed from a flattened
        # listing — confirm against the repository copy.
        band = [(s, m, sd) for s, m, sd in zip(steps, rmean, rstd) if m is not None]
        if band:
            sx = [b[0] for b in band]
            rm = [b[1] for b in band]
            rs = [b[2] if b[2] is not None else 0.0 for b in band]
            axes[0].plot(sx, rm, lw=2.0, color="#0f172a", label="reward (group mean)")
            axes[0].fill_between(
                sx,
                [m - s for m, s in zip(rm, rs)],
                [m + s for m, s in zip(rm, rs)],
                alpha=0.18, color="#1d4ed8", label="±1 std (group dispersion)",
            )
            axes[0].legend(loc="lower right", fontsize=9)
        axes[0].set_ylabel("reward at logging step")
        axes[0].set_title(
            "CERNenv reward — group mean ± dispersion "
            "(stdout-derived; install EvidenceCallback for terminal vs shaping split)"
        )
        axes[0].grid(alpha=0.25)

        kl_pts = [(s, k) for s, k in zip(steps, kls) if k is not None]
        if kl_pts:
            axes[1].plot([p[0] for p in kl_pts], [p[1] for p in kl_pts],
                         lw=1.5, color="#9333ea", label="KL divergence")
        axes[1].set_ylabel("KL", color="#9333ea")

        fz_pts = [(s, f) for s, f in zip(steps, fzero) if f is not None]
        cl_pts = [(s, c) for s, c in zip(steps, clen) if c is not None]
        if fz_pts or cl_pts:
            ax2 = axes[1].twinx()
            if fz_pts:
                ax2.plot([p[0] for p in fz_pts], [p[1] for p in fz_pts],
                         "o-", lw=1.0, ms=3, color="#ea580c",
                         label="frac rollouts with zero-std (saturation)")
                ax2.set_ylim(-0.02, 1.05)
            if cl_pts:
                # Scale lengths by their maximum so they share the right
                # axis; fall back to 1.0 to avoid dividing by zero.
                cmax = max(p[1] for p in cl_pts) or 1.0
                ax2.plot([p[0] for p in cl_pts], [p[1] / cmax for p in cl_pts],
                         "x:", lw=1.0, ms=4, color="#16a34a",
                         label=f"completion mean length / {cmax:.0f}")
            ax2.set_ylabel("auxiliary (right axis, normalised)", color="#475569")
            ax2.legend(loc="upper right", fontsize=8)
        axes[1].set_xlabel("training step")
        axes[1].grid(alpha=0.25)
        fig.tight_layout()
        return _png_bytes(fig)
    finally:
        plt.close(fig)
731
+
732
+
733
# Evidence filenames that can be rendered on demand, keyed to the
# zero-argument function that returns their PNG bytes (or None).
_SYNTH_HANDLERS = {
    "training_curve.png": _synth_training_curve_png,
    "reward_components.png": _synth_reward_components_png,
}
737
+
738
+
739
  # ── FastAPI app ──────────────────────────────────────────────────────────
740
 
741
 
 
748
  <head>
749
  <meta charset=utf-8>
750
  <title>CERNenv Trainer</title>
751
+ <meta http-equiv="refresh" content="60">
752
  <style>
753
  body { font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto;
754
  max-width: 1000px; color:#111; padding: 0 1rem; line-height:1.5 }
 
946
 
947
@app.get("/evidence/{name}")
def evidence_file(name: str):
    """Serve a single evidence artifact (PNG/CSV/JSON/MD) by filename.

    A file that exists under ``EVIDENCE_DIR`` is always served as-is.
    Otherwise, for names registered in ``_SYNTH_HANDLERS``, the PNG is
    synthesised on demand and returned with caching disabled.

    Args:
        name: Bare filename of the artifact (no directory components).

    Raises:
        HTTPException: 400 when *name* is not a bare filename; 404 when
            no file exists and no PNG could be synthesised.
    """
    # Beyond the "/" and ".." checks, Path(name).name != name also rejects
    # "." and, on Windows, backslash- or drive-anchored names that would
    # otherwise resolve outside EVIDENCE_DIR when joined.
    if "/" in name or ".." in name or Path(name).name != name:
        raise HTTPException(status_code=400, detail="invalid name")
    target = EVIDENCE_DIR / name
    if target.is_file():
        return FileResponse(target)

    handler = _SYNTH_HANDLERS.get(name)
    if handler is not None:
        try:
            png = handler()
        except Exception as exc:  # pragma: no cover - synthesis is best-effort
            # Keep the traceback: a bare message is rarely enough to
            # diagnose a plotting failure.
            logger.warning(
                "on-demand synthesis of %s failed: %s", name, exc, exc_info=True
            )
            png = None
        if png:
            return Response(
                content=png,
                media_type="image/png",
                headers={"Cache-Control": "no-store, max-age=0"},
            )
    raise HTTPException(status_code=404, detail=f"{name} not found")
977
 
978
 
979
  @app.get("/logs", response_class=PlainTextResponse)