"""FastAPI control panel for the CERNenv trainer Space. Endpoints: GET / → status page (HTML) GET /status → JSON status of the current training run GET /metrics → JSON snapshot of reward / success rate GET /logs → tail of the training log POST /train → start (or restart) a training run GET /health → liveness probe Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training work runs in a background thread so the HTTP server stays responsive. """ from __future__ import annotations import ast import io import json import logging import os import re import subprocess import sys import threading import time from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional from fastapi import FastAPI, HTTPException, Request from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, Response from fastapi.staticfiles import StaticFiles logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) def _resolve_repo_root() -> Path: env_root = os.environ.get("CERNENV_ROOT") candidates = [] if env_root: candidates.append(Path(env_root)) candidates.extend([ Path("/home/user/app"), Path(__file__).resolve().parent.parent.parent, ]) for p in candidates: try: if p.exists(): return p.resolve() except OSError: continue return candidates[-1].resolve() REPO_ROOT = _resolve_repo_root() LOG_DIR = REPO_ROOT / "training" / "runs" try: LOG_DIR.mkdir(parents=True, exist_ok=True) except OSError as exc: # pragma: no cover - read-only filesystem fallback logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc) LOG_DIR = Path("/tmp/cernenv-runs") LOG_DIR.mkdir(parents=True, exist_ok=True) LOG_FILE = LOG_DIR / "training.log" EVIDENCE_DIR = REPO_ROOT / "evidence" try: EVIDENCE_DIR.mkdir(parents=True, exist_ok=True) except OSError: # pragma: no cover EVIDENCE_DIR = Path("/tmp/cernenv-evidence") EVIDENCE_DIR.mkdir(parents=True, exist_ok=True) METRICS_FILE = EVIDENCE_DIR / "before_after_metrics.json" def _env(name: str, default: str) -> str: return os.environ.get(name, default) def _detect_gpus() -> int: try: import torch # type: ignore if torch.cuda.is_available(): return torch.cuda.device_count() except Exception: pass try: out = subprocess.run( ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], capture_output=True, text=True, timeout=5, ) return len([l for l in out.stdout.splitlines() if l.strip()]) except Exception: return 0 _NUM_GPUS = _detect_gpus() def _bool_env(name: str, default: str) -> bool: return _env(name, default).strip().lower() in ("1", "true", "yes", "on") CONFIG = { "training_backend": _env("TRAINING_BACKEND", "vanilla"), "model_name": _env("MODEL_NAME", "HuggingFaceTB/SmolLM2-360M-Instruct"), "difficulty": _env("DIFFICULTY", "easy"), "curriculum": _env("CURRICULUM", "0") == "1", "curriculum_promote": float(_env("CURRICULUM_PROMOTE", "0.55")), "curriculum_demote": float(_env("CURRICULUM_DEMOTE", "0.10")), "total_episodes": int(_env("TOTAL_EPISODES", "120")), "max_steps": int(_env("MAX_STEPS", "12")), "num_generations": int(_env("NUM_GENERATIONS", "4")), "checkpoint_eval_steps": int(_env("CHECKPOINT_EVAL_STEPS", "25")), "checkpoint_eval_episodes": int(_env("CHECKPOINT_EVAL_EPISODES", "8")), "eval_episodes": int(_env("EVAL_EPISODES", "8")), "output_dir": _env("OUTPUT_DIR", "runs/vanilla-grpo"), "evidence_dir": _env("EVIDENCE_DIR", "evidence"), "num_gpus": int(_env("NUM_GPUS", "1")), "hf_username": _env("HF_USERNAME", "anugrahhu"), 
"push_repo": _env( "PUSH_REPO", f"{_env('HF_USERNAME', 'anugrahhu')}/cernenv-grpo-smollm2-360m", ), "autostart": _env("AUTOSTART", "0") == "1", # ── SFT warm-start phase (defeats v1's claim-avoidance reward hack # by giving GRPO a non-zero prior over correct trajectories) ───── "sft_warmstart": _bool_env("SFT_WARMSTART", "false"), "sft_num_episodes": int(_env("SFT_NUM_EPISODES", "200")), "sft_max_steps": int(_env("SFT_MAX_STEPS", "8")), "sft_epochs": int(_env("SFT_EPOCHS", "1")), "sft_lr": float(_env("SFT_LR", "1e-5")), "sft_difficulty": _env("SFT_DIFFICULTY", "mixed"), "sft_out_dir": _env("SFT_OUT_DIR", "runs/sft-warmstart"), } # ── Run state ──────────────────────────────────────────────────────────── class RunState: def __init__(self) -> None: self.lock = threading.Lock() self.thread: Optional[threading.Thread] = None self.process: Optional[subprocess.Popen] = None self.status: str = "idle" # idle | running | finished | failed self.started_at: Optional[str] = None self.finished_at: Optional[str] = None self.last_error: Optional[str] = None self.last_config: Dict[str, Any] = {} def to_dict(self) -> Dict[str, Any]: with self.lock: return { "status": self.status, "started_at": self.started_at, "finished_at": self.finished_at, "last_error": self.last_error, "last_config": self.last_config, } STATE = RunState() # ── Training pipeline ──────────────────────────────────────────────────── def _stream_subprocess(cmd: list[str], log_handle) -> int: log_handle.write(f"\n$ {' '.join(cmd)}\n") log_handle.flush() proc = subprocess.Popen( cmd, cwd=str(REPO_ROOT), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True, env={**os.environ, "PYTHONPATH": str(REPO_ROOT)}, ) STATE.process = proc assert proc.stdout is not None for line in proc.stdout: log_handle.write(line) log_handle.flush() rc = proc.wait() log_handle.write(f"[exit code {rc}]\n") log_handle.flush() STATE.process = None return rc def _build_sft_warmstart_cmd(config: Dict[str, Any]) -> list[str]: """Compose the SFT-warm-start subprocess command. Always uses the system Python so GRPO + SFT share the same transformers + trl pin in space/training/requirements.txt. """ python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable return [ python_bin, "-m", "training.sft_warmstart", "--out_dir", config["sft_out_dir"], "--num_episodes", str(config["sft_num_episodes"]), "--max_steps", str(config["sft_max_steps"]), "--epochs", str(config["sft_epochs"]), "--lr", str(config["sft_lr"]), "--base_model", config["model_name"], "--difficulty", config["sft_difficulty"], "--evidence_dir", config["evidence_dir"], ] def _build_training_cmd(config: Dict[str, Any]) -> list[str]: """Compose the selected training launcher. When ``sft_warmstart`` is on, ``model_name`` is expected to already have been overwritten with the SFT output directory by the caller (``_training_pipeline``), so this function never has to know about the SFT phase explicitly — it just trains GRPO from whatever path is sitting in ``model_name``. """ backend = str(config.get("training_backend", "vanilla")).lower() if backend == "vanilla": python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable # vanilla now accepts --evidence_dir / --checkpoint_eval_* so the # backported EvidenceCallback writes evidence/*.csv + plots into # the same directory the dashboard serves from. 
        return [
            python_bin, "-m", "training.training_script",
            "--model_name", config["model_name"],
            "--difficulty", config["difficulty"],
            "--total_episodes", str(config["total_episodes"]),
            "--max_steps", str(config["max_steps"]),
            "--num_generations", str(config["num_generations"]),
            "--checkpoint_eval_steps", str(config["checkpoint_eval_steps"]),
            "--checkpoint_eval_episodes", str(config["checkpoint_eval_episodes"]),
            "--output_dir", config["output_dir"],
            "--evidence_dir", config["evidence_dir"],
        ]

    if backend != "unsloth":
        raise ValueError(f"unknown TRAINING_BACKEND={backend!r}")

    base = [
        "-m", "training.training_unsloth",
        "--model_name", config["model_name"],
        "--difficulty", config["difficulty"],
        "--total_episodes", str(config["total_episodes"]),
        "--max_steps", str(config["max_steps"]),
        "--num_generations", str(config["num_generations"]),
        "--checkpoint_eval_steps", str(config["checkpoint_eval_steps"]),
        "--checkpoint_eval_episodes", str(config["checkpoint_eval_episodes"]),
        "--output_dir", config["output_dir"],
        "--evidence_dir", config["evidence_dir"],
    ]
    if config.get("curriculum"):
        base.extend([
            "--curriculum",
            "--curriculum_promote", str(config["curriculum_promote"]),
            "--curriculum_demote", str(config["curriculum_demote"]),
        ])
    n = max(int(config.get("num_gpus", 1)), 1)
    if n > 1:
        return ["accelerate", "launch", "--num_processes", str(n), "--mixed_precision", "bf16"] + base
    return [sys.executable] + base


def _build_eval_cmd(
    *,
    model_name: str,
    difficulty: str,
    episodes: str,
    max_steps: str,
    tag: str,
    out: str,
    backend: str,
    adapter_dir: Optional[str] = None,
) -> list[str]:
    cmd = [
        sys.executable, "-m", "training.evaluate",
        "--model_name", model_name,
        "--difficulty", difficulty,
        "--episodes", episodes,
        "--max_steps", max_steps,
        "--tag", tag,
        "--out", out,
    ]
    if adapter_dir:
        cmd.extend(["--adapter_dir", adapter_dir])
    if backend == "vanilla":
        cmd.append("--no_unsloth")
    return cmd


def _push_model_folder_to_hub(*, output_dir: Path, repo_id: str, base_model: str, log) -> None:
    """Upload a vanilla transformers model directory to the Hub."""
    token = os.environ.get("HF_TOKEN")
    if not token:
        log.write("\n[skip] HF_TOKEN not set — model not pushed\n")
        log.flush()
        return
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
        api.upload_folder(
            folder_path=str(output_dir),
            repo_id=repo_id,
            repo_type="model",
            commit_message=f"Upload vanilla GRPO model based on {base_model}",
        )
        log.write(f"\n[ok] uploaded model → https://huggingface.co/{repo_id}\n")
        log.flush()
    except Exception as exc:
        log.write(f"\n[warn] model push failed: {exc}\n")
        log.flush()


def _push_evidence_to_hub(*, evidence_dir: Path, repo_id: str, log) -> None:
    """Upload the entire evidence/ directory to the model repo."""
    token = os.environ.get("HF_TOKEN")
    if not token:
        log.write("\n[skip] HF_TOKEN not set — evidence not pushed\n")
        log.flush()
        return
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=token)
        api.upload_folder(
            folder_path=str(evidence_dir),
            repo_id=repo_id,
            repo_type="model",
            path_in_repo="evidence",
            commit_message="Upload CERNenv training evidence (curves, evals, plots)",
        )
        log.write(f"\n[ok] uploaded evidence/ → https://huggingface.co/{repo_id}/tree/main/evidence\n")
        log.flush()
    except Exception as exc:
        log.write(f"\n[warn] evidence push failed: {exc}\n")
        log.flush()


def _training_pipeline(config: Dict[str, Any]) -> None:
    started = datetime.now(timezone.utc).isoformat()
    with STATE.lock:
        STATE.status = "running"
        STATE.started_at = started
        STATE.finished_at = None
        STATE.last_error = None
        STATE.last_config = dict(config)

    evidence_dir = Path(config["evidence_dir"]).resolve()
    evidence_dir.mkdir(parents=True, exist_ok=True)
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)

    with open(LOG_FILE, "a") as log:
        log.write(f"\n=== Training started {started} ===\n")
        log.write(json.dumps(config, indent=2) + "\n")
        log.flush()
        try:
            output_dir = config["output_dir"]
            difficulty = config["difficulty"]
            max_steps = str(config["max_steps"])
            eval_episodes = str(config["eval_episodes"])
            model_name = config["model_name"]
            push_repo = config["push_repo"]
            evidence_str = config["evidence_dir"]
            backend = str(config.get("training_backend", "vanilla")).lower()
            pre_jsonl = f"{evidence_str}/pre_eval.jsonl"
            post_jsonl = f"{evidence_str}/post_eval.jsonl"

            log.write("\n--- baseline sanity check (random / heuristic / oracle) ---\n")
            log.flush()
            for agent in ("random", "heuristic", "oracle"):
                _stream_subprocess(
                    [
                        sys.executable, "-m", "scripts.run_agent",
                        "--agent", agent,
                        "--difficulty", difficulty,
                        "--episodes", "3",
                        "--quiet",
                    ],
                    log,
                )

            log.write(f"\n--- pre-train evaluation ({eval_episodes} eps) ---\n")
            log.flush()
            rc = _stream_subprocess(
                _build_eval_cmd(
                    model_name=model_name,
                    difficulty=difficulty,
                    episodes=eval_episodes,
                    max_steps=max_steps,
                    tag="pre_train",
                    out=pre_jsonl,
                    backend=backend,
                ),
                log,
            )
            if rc != 0:
                # don't abort — we still want training + post-eval evidence.
                log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
                log.flush()

            if config.get("sft_warmstart"):
                # Phase 1 — SFT warm-start. Produces a *full* causal-LM
                # checkpoint at config['sft_out_dir'] (LoRA adapters are
                # merged in by training/sft_warmstart.py) so we can hand
                # it to GRPO as a drop-in --model_name.
                sft_out = config["sft_out_dir"]
                log.write(
                    f"\n--- SFT warm-start ({config['sft_num_episodes']} oracle "
                    f"episodes, epochs={config['sft_epochs']}, → {sft_out}) ---\n"
                )
                log.flush()
                sft_rc = _stream_subprocess(_build_sft_warmstart_cmd(config), log)
                if sft_rc != 0:
                    raise RuntimeError(f"SFT warm-start failed (rc={sft_rc})")
                log.write(
                    f"\n[ok] SFT done; switching GRPO base model "
                    f"{config['model_name']} → {sft_out}\n"
                )
                log.flush()
                config["model_name"] = sft_out
                # Keep the *base* HF id around for evaluator commands —
                # tokenizer files in the SFT directory are saved by the
                # SFT script, but evaluation will load from this dir
                # directly, so no further path bookkeeping is required.

            log.write(f"\n--- GRPO training ({backend}, {config['num_gpus']} GPU process(es)) ---\n")
            log.flush()
            rc = _stream_subprocess(_build_training_cmd(config), log)
            if rc != 0:
                raise RuntimeError(f"training failed (rc={rc})")

            # Cold-load the trained artifact before burning time on post-eval.
            log.write(
                f"\n--- trained artifact smoke test "
                f"(loading {output_dir} cold-start, 2 eps) ---\n"
            )
            log.flush()
            smoke_model = output_dir if backend == "vanilla" else model_name
            smoke_adapter = None if backend == "vanilla" else output_dir
            rc = _stream_subprocess(
                _build_eval_cmd(
                    model_name=smoke_model,
                    adapter_dir=smoke_adapter,
                    difficulty=difficulty,
                    episodes="2",
                    max_steps=max_steps,
                    tag="smoke",
                    out=f"{evidence_str}/smoke_eval.jsonl",
                    backend=backend,
                ),
                log,
            )
            if rc != 0:
                raise RuntimeError(
                    f"trained artifact smoke test failed (rc={rc}); refusing to push "
                    f"unloadable output to the Hub. Inspect {output_dir}."
) log.write(f"\n--- post-train evaluation ({eval_episodes} eps) ---\n") log.flush() post_model = output_dir if backend == "vanilla" else model_name post_adapter = None if backend == "vanilla" else output_dir rc = _stream_subprocess( _build_eval_cmd( model_name=post_model, adapter_dir=post_adapter, difficulty=difficulty, episodes=eval_episodes, max_steps=max_steps, tag="post_train", out=post_jsonl, backend=backend, ), log, ) if rc != 0: log.write(f"\n[warn] post-train eval failed (rc={rc}); evidence will be partial\n") log.flush() log.write("\n--- evidence: before/after summary, distribution, trajectories ---\n") log.flush() try: from training.evidence import ( EvidencePaths, render_before_after, render_sample_trajectories, render_training_curve, render_reward_components, render_checkpoint_progression, ) paths = EvidencePaths(root=Path(evidence_str)) paths.ensure() metrics = render_before_after( pre_jsonl=Path(pre_jsonl), post_jsonl=Path(post_jsonl), summary_png=paths.before_after_summary_png, distribution_png=paths.reward_distribution_png, metrics_json=paths.before_after_metrics_json, ) render_sample_trajectories( pre_jsonl=Path(pre_jsonl), post_jsonl=Path(post_jsonl), md_path=paths.sample_trajectories_md, ) render_training_curve(paths.training_log_csv, paths.training_curve_png) render_reward_components( paths.reward_components_csv, paths.reward_components_png, ) render_checkpoint_progression( paths.checkpoint_evals_csv, paths.checkpoint_progression_png, ) log.write(json.dumps(metrics, indent=2) + "\n") log.flush() except Exception as exc: log.write(f"[warn] evidence rendering failed: {exc}\n") log.flush() if os.environ.get("HF_TOKEN"): if backend == "vanilla": log.write("\n--- push vanilla model to Hub ---\n") log.flush() _push_model_folder_to_hub( output_dir=Path(output_dir), repo_id=push_repo, base_model=model_name, log=log, ) else: log.write("\n--- push adapters to Hub ---\n") log.flush() _stream_subprocess( [ sys.executable, "-m", "scripts.push_to_hub", "model", "--adapter_dir", output_dir, "--repo_id", push_repo, "--base_model", model_name, ], log, ) _push_evidence_to_hub( evidence_dir=evidence_dir, repo_id=push_repo, log=log, ) else: log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n") log.flush() with STATE.lock: STATE.status = "finished" except Exception as exc: logger.exception("training pipeline failed") with STATE.lock: STATE.status = "failed" STATE.last_error = str(exc) finally: finished = datetime.now(timezone.utc).isoformat() log.write(f"\n=== Training ended {finished} ===\n") log.flush() with STATE.lock: STATE.finished_at = finished def _start_training(config: Dict[str, Any]) -> None: with STATE.lock: if STATE.status == "running": raise RuntimeError("a training run is already in progress") STATE.thread = threading.Thread( target=_training_pipeline, args=(config,), name="cernenv-trainer", daemon=True, ) STATE.thread.start() # ── On-demand evidence-PNG synthesis ───────────────────────────────────── # # The vanilla GRPO backend (training/training_script.py) does not register # an EvidenceCallback, so it never writes training_log.csv / # reward_components.csv mid-run. The unsloth backend does, but a Space that # happens to be running the vanilla path leaves those evidence cards empty # until post-eval — and even then they stay empty because the underlying # CSVs were never produced. 
#
# To keep the dashboard live without restarting the in-flight run, we
# synthesise both PNGs on demand by parsing the TRL log dicts that the
# trainer prints to stdout (captured in training/runs/training.log by
# _stream_subprocess). The unsloth path still gets its richer
# component-level CSVs as before; this only kicks in when the file is
# missing or older than the captured log.

# Matches a tqdm progress line like " 53%|█████▎ | 190/360 [12:31<10:06,
# 3.57s/it]" emitted just before each TRL log dict, so we can attribute a
# dict to the correct global_step instead of guessing from logging_steps.
_TQDM_PROGRESS_RE = re.compile(r"\b(\d+)\s*/\s*(\d+)\s*\[")


def _parse_training_log_dicts(text: str) -> List[Dict[str, Any]]:
    """Extract per-log-step rows from a captured TRL stdout log.

    TRL prints a Python dict-repr on each ``logging_steps`` boundary. We pair
    each dict with the most recent tqdm progress line so the plotted x-axis
    reflects ``global_step`` rather than dict-arrival order. Lines that do
    not parse cleanly are silently skipped.
    """
    rows: List[Dict[str, Any]] = []
    last_step: Optional[int] = None
    for raw in text.splitlines():
        m = _TQDM_PROGRESS_RE.search(raw)
        if m:
            try:
                last_step = int(m.group(1))
            except ValueError:
                pass
            continue
        s = raw.strip()
        if not (s.startswith("{") and s.endswith("}")):
            continue
        if "'loss'" not in s and "'reward'" not in s and "'kl'" not in s:
            continue
        try:
            d = ast.literal_eval(s)
        except (ValueError, SyntaxError):
            continue
        if not isinstance(d, dict):
            continue
        reward = (
            d.get("reward")
            or d.get("rewards/mean")
            or d.get("rewards/reward_fn/mean")
        )
        reward_std = (
            d.get("reward_std")
            or d.get("rewards/std")
            or d.get("rewards/reward_fn/std")
        )
        rows.append({
            "step": last_step if last_step is not None else len(rows),
            "loss": d.get("loss"),
            "reward": reward,
            "reward_std": reward_std,
            "kl": d.get("kl"),
            "grad_norm": d.get("grad_norm"),
            "learning_rate": d.get("learning_rate"),
            "epoch": d.get("epoch"),
            "frac_reward_zero_std": d.get("frac_reward_zero_std"),
            "completions_mean_length": d.get("completions/mean_length"),
            "completions_clipped_ratio": d.get("completions/clipped_ratio"),
        })
    return rows


def _try_matplotlib():
    try:
        import matplotlib  # type: ignore

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt  # type: ignore

        return plt
    except Exception as exc:  # pragma: no cover - plotting is best-effort
        logger.warning("matplotlib unavailable: %s", exc)
        return None


def _png_bytes(fig) -> bytes:
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=140)
    return buf.getvalue()


def _read_log_text() -> Optional[str]:
    if not LOG_FILE.exists():
        return None
    try:
        return LOG_FILE.read_text(errors="replace")
    except OSError:
        return None


def _synth_training_curve_png() -> Optional[bytes]:
    """Render a 2-panel reward/loss curve from the captured TRL stdout log."""
    text = _read_log_text()
    if not text:
        return None
    rows = _parse_training_log_dicts(text)
    if not rows:
        return None
    plt = _try_matplotlib()
    if plt is None:
        return None
    steps = [r["step"] for r in rows]
    rewards = [(s, r["reward"]) for s, r in zip(steps, rows) if r["reward"] is not None]
    losses = [(s, r["loss"]) for s, r in zip(steps, rows) if r["loss"] is not None]

    fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
    if rewards:
        axes[0].plot([x for x, _ in rewards], [y for _, y in rewards], lw=1.6, color="#1d4ed8")
        axes[0].set_ylabel("mean reward")
        axes[0].set_title(
            "CERNenv GRPO training — reward over steps "
            f"(synthesised from {len(rewards)} log events)"
        )
        axes[0].grid(alpha=0.25)
    if losses:
        axes[1].plot([x for x, _ in losses], [y for _, y in losses], lw=1.6, color="#c026d3")
        axes[1].set_ylabel("GRPO loss")
    axes[1].set_xlabel("training step")
    axes[1].grid(alpha=0.25)
    fig.tight_layout()
    try:
        return _png_bytes(fig)
    finally:
        plt.close(fig)


def _synth_reward_components_png() -> Optional[bytes]:
    """Best-effort reward-components view derived from TRL stdout.

    The unsloth callback writes a true terminal-vs-shaping split into
    reward_components.csv. The vanilla backend only emits aggregate reward
    in the TRL log dict, so here we fall back to plotting reward mean ± std
    (group dispersion) and KL on a second axis. This still surfaces the
    "watch dispersion, not just the mean" view the FAQ recommends — at least
    until a real callback writes a richer CSV.
    """
    text = _read_log_text()
    if not text:
        return None
    rows = _parse_training_log_dicts(text)
    if not rows:
        return None
    plt = _try_matplotlib()
    if plt is None:
        return None
    steps = [r["step"] for r in rows]
    rmean = [r.get("reward") for r in rows]
    rstd = [r.get("reward_std") for r in rows]
    kls = [r.get("kl") for r in rows]
    fzero = [r.get("frac_reward_zero_std") for r in rows]
    clen = [r.get("completions_mean_length") for r in rows]

    fig, axes = plt.subplots(2, 1, figsize=(8, 6.5), sharex=True)

    band = [(s, m, sd) for s, m, sd in zip(steps, rmean, rstd) if m is not None]
    if band:
        sx = [b[0] for b in band]
        rm = [b[1] for b in band]
        rs = [b[2] if b[2] is not None else 0.0 for b in band]
        axes[0].plot(sx, rm, lw=2.0, color="#0f172a", label="reward (group mean)")
        axes[0].fill_between(
            sx,
            [m - s for m, s in zip(rm, rs)],
            [m + s for m, s in zip(rm, rs)],
            alpha=0.18,
            color="#1d4ed8",
            label="±1 std (group dispersion)",
        )
        axes[0].set_ylabel("reward at logging step")
        axes[0].set_title(
            "CERNenv reward — group mean ± dispersion "
            "(stdout-derived; install EvidenceCallback for terminal vs shaping split)"
        )
        axes[0].grid(alpha=0.25)
        axes[0].legend(loc="lower right", fontsize=9)

    kl_pts = [(s, k) for s, k in zip(steps, kls) if k is not None]
    if kl_pts:
        axes[1].plot([p[0] for p in kl_pts], [p[1] for p in kl_pts], lw=1.5,
                     color="#9333ea", label="KL divergence")
        axes[1].set_ylabel("KL", color="#9333ea")
    fz_pts = [(s, f) for s, f in zip(steps, fzero) if f is not None]
    cl_pts = [(s, c) for s, c in zip(steps, clen) if c is not None]
    if fz_pts or cl_pts:
        ax2 = axes[1].twinx()
        if fz_pts:
            ax2.plot([p[0] for p in fz_pts], [p[1] for p in fz_pts], "o-", lw=1.0, ms=3,
                     color="#ea580c", label="frac rollouts with zero-std (saturation)")
            ax2.set_ylim(-0.02, 1.05)
        if cl_pts:
            cmax = max(p[1] for p in cl_pts) or 1.0
            ax2.plot([p[0] for p in cl_pts], [p[1] / cmax for p in cl_pts], "x:", lw=1.0, ms=4,
                     color="#16a34a", label=f"completion mean length / {cmax:.0f}")
        ax2.set_ylabel("auxiliary (right axis, normalised)", color="#475569")
        ax2.legend(loc="upper right", fontsize=8)
    axes[1].set_xlabel("training step")
    axes[1].grid(alpha=0.25)
    fig.tight_layout()
    try:
        return _png_bytes(fig)
    finally:
        plt.close(fig)


_SYNTH_HANDLERS = {
    "training_curve.png": _synth_training_curve_png,
    "reward_components.png": _synth_reward_components_png,
}

# ── FastAPI app ──────────────────────────────────────────────────────────

app = FastAPI(title="CERNenv Trainer", version="0.1.0")

_HTML = """\
<!doctype html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>CERNenv Trainer</title>
  </head>
  <body>
    <h1>CERNenv Trainer</h1>
    <p>GRPO + Unsloth + LoRA on the CERNenv LHC discovery environment.
       Multi-GPU on Hugging Face Spaces.</p>
    <p>Status: ?</p>
    <p>Auto-updated as training runs. All artifacts are also saved to
       evidence/ and pushed to the model repo on the Hub.</p>
    <table>
      <tr><th>metric</th><th>pre</th><th>post</th><th>Δ</th></tr>
    </table>
    <pre>
loading…</pre>
  </body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
def index() -> HTMLResponse:
    return HTMLResponse(_HTML)


@app.get("/health")
def health() -> Dict[str, str]:
    return {"status": "ok"}


@app.get("/status")
def status() -> JSONResponse:
    return JSONResponse(STATE.to_dict())


@app.get("/metrics")
def metrics() -> JSONResponse:
    if METRICS_FILE.exists():
        try:
            return JSONResponse(json.loads(METRICS_FILE.read_text()))
        except Exception:
            return JSONResponse({"error": "metrics file unreadable"}, status_code=500)
    return JSONResponse({"pre": None, "post": None, "delta": None})


@app.get("/sft_summary")
def sft_summary() -> JSONResponse:
    """Return the SFT warm-start summary if it exists.

    Powers the dashboard's "Warm-start (SFT)" card: shows the final training
    loss, oracle success rate, and wall-clock duration once the SFT phase
    has written ``evidence/sft_summary.json``.
    """
    path = EVIDENCE_DIR / "sft_summary.json"
    if path.exists():
        try:
            return JSONResponse(json.loads(path.read_text()))
        except Exception:
            return JSONResponse({"error": "sft_summary unreadable"}, status_code=500)
    return JSONResponse({}, status_code=404)


@app.get("/evidence")
def evidence_index() -> JSONResponse:
    """List every evidence artifact currently on disk."""
    files = []
    if EVIDENCE_DIR.exists():
        for p in sorted(EVIDENCE_DIR.iterdir()):
            if p.is_file():
                files.append({
                    "name": p.name,
                    "size": p.stat().st_size,
                    "url": f"/evidence/{p.name}",
                })
    return JSONResponse({"dir": str(EVIDENCE_DIR), "files": files})


@app.get("/evidence/{name}")
def evidence_file(name: str):
    """Serve a single evidence artifact (PNG/CSV/JSON/MD) by filename.

    For ``training_curve.png`` and ``reward_components.png`` we fall back to
    on-demand synthesis from the captured TRL stdout log when the underlying
    file does not yet exist on disk — which is the normal state of affairs
    when the vanilla backend is running and no EvidenceCallback has had a
    chance to write the source CSV.
    """
    if "/" in name or ".." in name:
        raise HTTPException(status_code=400, detail="invalid name")
    target = EVIDENCE_DIR / name
    if target.exists() and target.is_file():
        return FileResponse(target)
    handler = _SYNTH_HANDLERS.get(name)
    if handler is not None:
        try:
            png = handler()
        except Exception as exc:  # pragma: no cover - synthesis is best-effort
            logger.warning("on-demand synthesis of %s failed: %s", name, exc)
            png = None
        if png:
            return Response(
                content=png,
                media_type="image/png",
                headers={"Cache-Control": "no-store, max-age=0"},
            )
    raise HTTPException(status_code=404, detail=f"{name} not found")


@app.get("/logs", response_class=PlainTextResponse)
def logs(tail: int = 400) -> PlainTextResponse:
    if not LOG_FILE.exists():
        return PlainTextResponse("")
    text = LOG_FILE.read_text()
    lines = text.splitlines()
    return PlainTextResponse("\n".join(lines[-max(tail, 1):]))


@app.post("/train")
async def train(request: Request) -> JSONResponse:
    """Start a training run.

    The request body (JSON) is merged into the global ``CONFIG`` for *this*
    run only, so future API-only triggers can flip ``sft_warmstart`` (or any
    other config key) without redeploying the Space. Unknown keys are
    accepted as-is — type coercion is the caller's responsibility.
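
    For example (illustrative values only), POSTing

        {"sft_warmstart": true, "total_episodes": 240, "difficulty": "medium"}

    runs the SFT warm-start first and then GRPO for 240 episodes on the
    medium difficulty, while the deployed defaults stay untouched for
    subsequent runs.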
""" overrides: Dict[str, Any] = {} try: body = await request.body() if body: overrides = json.loads(body) if not isinstance(overrides, dict): raise ValueError("request body must be a JSON object") except (ValueError, json.JSONDecodeError) as exc: raise HTTPException(status_code=400, detail=f"bad request body: {exc}") cfg = dict(CONFIG) cfg.update(overrides) try: _start_training(cfg) except RuntimeError as exc: raise HTTPException(status_code=409, detail=str(exc)) return JSONResponse({"status": "started", "config": cfg}) @app.on_event("startup") def _maybe_autostart() -> None: if CONFIG["autostart"]: try: _start_training(dict(CONFIG)) logger.info("autostarted training run") except RuntimeError as exc: logger.warning("autostart skipped: %s", exc)