| """FastAPI control panel for the CERNenv trainer Space.
|
|
|
| Endpoints:
|
| GET / → status page (HTML)
|
| GET /status → JSON status of the current training run
|
| GET /metrics → JSON snapshot of reward / success rate
|
GET /logs → tail of the training log (?tail=N lines, default 400)
|
POST /train → start a training run (HTTP 409 if one is already in progress)
|
| GET /health → liveness probe
|
|
|
Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training
work runs as subprocesses driven from a background thread, so the HTTP server
stays responsive.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import json
|
| import logging
|
| import os
|
| import subprocess
|
| import sys
|
| import threading
|
| import time
|
| from datetime import datetime, timezone
|
| from pathlib import Path
|
| from typing import Any, Dict, Optional
|
|
|
| from fastapi import FastAPI, HTTPException
|
| from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
|
|
|
|
|
# Process-wide logging setup (INFO and above, timestamped); this module then
# logs through its own named logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
def _resolve_repo_root() -> Path:
    """Return the repository root directory.

    Candidates are tried in priority order:

    1. ``$CERNENV_ROOT`` (only when set to a non-empty value),
    2. ``/home/user/app``,
    3. the directory three levels above this file.

    The first candidate that exists is returned, resolved. A candidate whose
    existence check raises ``OSError`` is skipped. If none exists, the last
    candidate is returned (resolved) anyway.
    """
    override = os.environ.get("CERNENV_ROOT")
    search_order = [Path(override)] if override else []
    search_order += [
        Path("/home/user/app"),
        Path(__file__).resolve().parent.parent.parent,
    ]

    for candidate in search_order:
        try:
            if candidate.exists():
                return candidate.resolve()
        except OSError:
            pass  # unreadable location: try the next candidate

    return search_order[-1].resolve()
|
|
|
|
|
# Repository root, resolved once at import time. Pipeline subprocesses run with
# this as their working directory and PYTHONPATH.
REPO_ROOT = _resolve_repo_root()

# Training logs go under <repo>/training/runs. If that directory cannot be
# created (any OSError, e.g. a read-only location), fall back to a /tmp dir.
LOG_DIR = REPO_ROOT / "training" / "runs"
try:
    LOG_DIR.mkdir(parents=True, exist_ok=True)
except OSError as exc:
    logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc)
    LOG_DIR = Path("/tmp/cernenv-runs")
    LOG_DIR.mkdir(parents=True, exist_ok=True)

# Single log file shared by all runs.
LOG_FILE = LOG_DIR / "training.log"
# NOTE: unlike LOG_DIR, this path is always under REPO_ROOT and does not
# follow the /tmp fallback above.
METRICS_FILE = REPO_ROOT / "training" / "plots" / "metrics_summary.json"
|
|
|
|
|
def _env(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when unset or blank.

    Blank (empty or whitespace-only) values are treated as unset. CONFIG passes
    several of these values straight to ``int()``, so a defined-but-empty
    variable would otherwise crash the module at import time with ValueError.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is missing or blank.

    Returns:
        The variable's value (unmodified) or *default*.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    return value
|
|
|
|
|
# Run configuration, read once at import time from environment variables
# (the first argument of each _env call names the variable).
CONFIG = dict(
    model_name=_env("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct"),
    difficulty=_env("DIFFICULTY", "easy"),
    total_episodes=int(_env("TOTAL_EPISODES", "400")),
    max_steps=int(_env("MAX_STEPS", "18")),
    num_generations=int(_env("NUM_GENERATIONS", "4")),
    output_dir=_env("OUTPUT_DIR", "training/runs/unsloth-grpo"),
    hf_username=_env("HF_USERNAME", "YOUR_HF_USERNAME"),
    # Defaults to "<HF_USERNAME>/cernenv-grpo-qwen2.5-3b" unless PUSH_REPO is set.
    push_repo=_env(
        "PUSH_REPO",
        f"{_env('HF_USERNAME', 'YOUR_HF_USERNAME')}/cernenv-grpo-qwen2.5-3b",
    ),
    # Only the exact string "1" enables autostart.
    autostart=_env("AUTOSTART", "0") == "1",
)
|
|
|
|
|
|
|
|
|
|
|
class RunState:
    """Mutable record of the background training run.

    ``to_dict`` reads the status fields under ``lock``. In this module the
    training pipeline updates those fields under the same lock.
    """

    def __init__(self) -> None:
        self.lock = threading.Lock()
        # Worker thread of the most recent run (None before the first run).
        self.thread: Optional[threading.Thread] = None
        # Child process currently being streamed, if any.
        self.process: Optional[subprocess.Popen] = None
        # Values assigned in this module: idle / running / finished / failed.
        self.status: str = "idle"
        self.started_at: Optional[str] = None   # ISO-8601 UTC timestamp
        self.finished_at: Optional[str] = None  # ISO-8601 UTC timestamp
        self.last_error: Optional[str] = None
        self.last_config: Dict[str, Any] = {}

    def to_dict(self) -> Dict[str, Any]:
        """Return a consistent, JSON-serialisable snapshot of the run state.

        ``last_config`` is shallow-copied so callers cannot mutate the live
        state after the lock is released.
        """
        with self.lock:
            return {
                "status": self.status,
                "started_at": self.started_at,
                "finished_at": self.finished_at,
                "last_error": self.last_error,
                "last_config": dict(self.last_config),
            }
|
|
|
|
|
# Module-wide run state, shared by the pipeline thread and the HTTP handlers.
STATE = RunState()
|
|
|
|
|
|
|
|
|
|
|
def _stream_subprocess(cmd: list[str], log_handle) -> int:
    """Run *cmd* from REPO_ROOT, streaming its combined stdout/stderr to *log_handle*.

    The command line is echoed first and an ``[exit code N]`` trailer is
    appended afterwards. While the child runs it is published as
    ``STATE.process``. That attribute is always cleared again, even on error.

    Args:
        cmd: Argument vector (no shell is involved).
        log_handle: Writable text file; every line is flushed immediately so
            the /logs endpoint sees live output.

    Returns:
        The child's exit code.

    Raises:
        Any exception raised while streaming (e.g. a log write failure) is
        re-raised after the child has been killed and reaped, so no orphaned
        process is left running.
    """
    log_handle.write(f"\n$ {' '.join(cmd)}\n")
    log_handle.flush()

    # Prepend the repo root instead of replacing any PYTHONPATH the
    # container already defines.
    env = dict(os.environ)
    inherited = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        f"{REPO_ROOT}{os.pathsep}{inherited}" if inherited else str(REPO_ROOT)
    )

    proc = subprocess.Popen(
        cmd,
        cwd=str(REPO_ROOT),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        bufsize=1,
        text=True,
        # Undecodable bytes from the child must not abort the reader loop.
        encoding="utf-8",
        errors="replace",
        env=env,
    )
    STATE.process = proc
    try:
        assert proc.stdout is not None  # guaranteed by stdout=PIPE
        with proc.stdout:
            for line in proc.stdout:
                log_handle.write(line)
                log_handle.flush()
        rc = proc.wait()
    except BaseException:
        if proc.poll() is None:
            proc.kill()
        proc.wait()
        raise
    finally:
        STATE.process = None

    log_handle.write(f"[exit code {rc}]\n")
    log_handle.flush()
    return rc
|
|
|
|
|
def _log_section(log, title: str) -> None:
    """Write a ``--- title ---`` section banner to the training log and flush it."""
    log.write(f"\n--- {title} ---\n")
    log.flush()


def _run_pipeline_steps(config: Dict[str, Any], log) -> None:
    """Run every pipeline stage in order, streaming all output into *log*.

    Stages, in order:

    1. baseline agents
    2. pre-train eval
    3. GRPO training
    4. post-train eval
    5. plots
    6. Hub push (optional)

    Raises:
        KeyError: if *config* lacks a required key.
        RuntimeError: if pre-train eval, training, or post-train eval exits
            non-zero. The baseline runs, plots and Hub push are best-effort:
            their exit codes are only recorded in the log.
    """
    output_dir = config["output_dir"]
    difficulty = config["difficulty"]
    max_steps = str(config["max_steps"])
    episodes = str(config["total_episodes"])
    num_gens = str(config["num_generations"])
    model_name = config["model_name"]
    push_repo = config["push_repo"]
    # Relative paths: _stream_subprocess runs every child with cwd=REPO_ROOT.
    eval_pre = "training/runs/eval_pre_train.jsonl"
    eval_post = "training/runs/eval_post_train.jsonl"
    plots_dir = "training/plots"

    _log_section(log, "baseline (heuristic / oracle / random)")
    for agent in ("random", "heuristic", "oracle"):
        _stream_subprocess(
            [
                sys.executable, "-m", "scripts.run_agent",
                "--agent", agent, "--difficulty", difficulty,
                "--episodes", "3", "--quiet",
            ],
            log,
        )

    _log_section(log, "pre-train evaluation")
    rc = _stream_subprocess(
        [
            sys.executable, "-m", "training.evaluate",
            "--model_name", model_name,
            "--difficulty", difficulty,
            "--episodes", "16",
            "--max_steps", max_steps,
            "--tag", "pre_train",
            "--out", eval_pre,
        ],
        log,
    )
    if rc != 0:
        raise RuntimeError(f"pre-train eval failed (rc={rc})")

    _log_section(log, "GRPO training")
    rc = _stream_subprocess(
        [
            sys.executable, "-m", "training.training_unsloth",
            "--model_name", model_name,
            "--difficulty", difficulty,
            "--total_episodes", episodes,
            "--max_steps", max_steps,
            "--num_generations", num_gens,
            "--output_dir", output_dir,
        ],
        log,
    )
    if rc != 0:
        raise RuntimeError(f"training failed (rc={rc})")

    _log_section(log, "post-train evaluation")
    rc = _stream_subprocess(
        [
            sys.executable, "-m", "training.evaluate",
            "--model_name", model_name,
            "--adapter_dir", output_dir,
            "--difficulty", difficulty,
            "--episodes", "16",
            "--max_steps", max_steps,
            "--tag", "post_train",
            "--out", eval_post,
        ],
        log,
    )
    if rc != 0:
        raise RuntimeError(f"post-train eval failed (rc={rc})")

    _log_section(log, "plots")
    _stream_subprocess(
        [
            sys.executable, "-m", "training.plots",
            "--pre", eval_pre,
            "--post", eval_post,
            "--out_dir", plots_dir,
        ],
        log,
    )

    if os.environ.get("HF_TOKEN"):
        _log_section(log, "push adapters to Hub")
        _stream_subprocess(
            [
                sys.executable, "-m", "scripts.push_to_hub", "model",
                "--adapter_dir", output_dir,
                "--repo_id", push_repo,
                "--base_model", model_name,
            ],
            log,
        )
    else:
        log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n")
        log.flush()


def _training_pipeline(config: Dict[str, Any]) -> None:
    """Thread target: run the full pipeline and record its outcome in STATE.

    Flow:

    - Mark STATE as "running".
    - Append a header, all stage output and an end banner to LOG_FILE.
    - Finally set STATE.status to "finished" or "failed" (with
      ``last_error``), together with ``finished_at``, in one locked update.

    Every failure is caught, including failure to create or open the log
    itself. A run therefore can never be left stuck in "running", which would
    otherwise block all future runs.
    """
    started = datetime.now(timezone.utc).isoformat()
    with STATE.lock:
        STATE.status = "running"
        STATE.started_at = started
        STATE.finished_at = None
        STATE.last_error = None
        STATE.last_config = dict(config)

    outcome = "failed"
    error: Optional[str] = None
    finished: Optional[str] = None
    try:
        LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: the log carries non-ASCII text (e.g. the skip message).
        with open(LOG_FILE, "a", encoding="utf-8") as log:
            log.write(f"\n=== Training started {started} ===\n")
            log.write(json.dumps(config, indent=2) + "\n")
            log.flush()
            try:
                _run_pipeline_steps(config, log)
                outcome = "finished"
            except Exception as exc:
                logger.exception("training pipeline failed")
                error = str(exc)
            finally:
                finished = datetime.now(timezone.utc).isoformat()
                log.write(f"\n=== Training ended {finished} ===\n")
                log.flush()
    except Exception as exc:
        # Log setup or log I/O failed; stage failures were handled above.
        logger.exception("training pipeline failed")
        outcome, error = "failed", str(exc)
    finally:
        with STATE.lock:
            STATE.status = outcome
            STATE.last_error = error
            STATE.finished_at = finished or datetime.now(timezone.utc).isoformat()
|
|
|
|
|
def _start_training(config: Dict[str, Any]) -> None:
    """Launch ``_training_pipeline(config)`` on a daemon background thread.

    Raises:
        RuntimeError: if the previous run's worker thread is still alive.
            HTTP callers map this to 409.
    """
    with STATE.lock:
        # Gate on worker-thread liveness rather than STATE.status. The
        # pipeline only sets status to "running" once its thread executes, so
        # two back-to-back requests could both see the old status and start
        # two concurrent runs. The thread is started and recorded while
        # holding the lock, so a second caller always sees it alive. This
        # also lets a run whose thread died without resetting status be
        # restarted.
        previous = STATE.thread
        if previous is not None and previous.is_alive():
            raise RuntimeError("a training run is already in progress")
        worker = threading.Thread(
            target=_training_pipeline,
            args=(config,),
            name="cernenv-trainer",
            daemon=True,
        )
        worker.start()
        STATE.thread = worker
|
|
|
|
|
|
|
|
|
|
|
# The ASGI application; the endpoint functions below register their routes on
# it via decorators.
app = FastAPI(title="CERNenv Trainer", version="0.1.0")
|
|
|
|
|
# Single-page control panel served at "/". It polls /status and /logs every
# 5 s and posts to /train. Server-provided values (config, error text) are
# inserted with textContent, never innerHTML, so they cannot inject markup.
_HTML = """\
<!doctype html>
<html lang=en>
<head>
<meta charset=utf-8>
<title>CERNenv Trainer</title>
<style>
  body { font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto; max-width: 760px; color:#111 }
  h1 { margin-bottom: 0 }
  .muted { color:#666 }
  pre { background:#0e1116; color:#e6edf3; padding:1rem; border-radius:6px; overflow-x:auto; max-height:50vh }
  button { font-size:1rem; padding:.6rem 1rem; border-radius:6px; border:1px solid #888; background:#fff; cursor:pointer }
  .pill { display:inline-block; padding:.1rem .5rem; border-radius:999px; background:#eef; color:#225 }
  .ok { background:#dfd; color:#272 }
  .fail { background:#fdd; color:#822 }
  .run { background:#fdf6d8; color:#774 }
  table { border-collapse:collapse; }
  td { padding:.2rem .8rem .2rem 0; }
</style>
</head>
<body>
<h1>⚛️ CERNenv Trainer</h1>
<p class=muted>GRPO + Unsloth + LoRA on the CERNenv LHC discovery environment.</p>

<h3>Status: <span id=status class=pill>?</span></h3>
<table id=meta></table>

<p>
  <button onclick="startRun()">▶ Start training</button>
  <button onclick="refresh()">↻ Refresh</button>
</p>

<h3>Logs (tail)</h3>
<pre id=logs>loading…</pre>

<script>
// Build a <td> wrapping a <tag> whose text is set safely via textContent.
function cell(tag, text) {
  const td = document.createElement('td');
  const inner = document.createElement(tag);
  inner.textContent = String(text);
  td.appendChild(inner);
  return td;
}

async function refresh() {
  try {
    const s = await fetch('/status').then(r => r.json());
    const pill = document.getElementById('status');
    pill.textContent = s.status;
    pill.className = 'pill ' + ({idle:'',running:'run',finished:'ok',failed:'fail'}[s.status] || '');

    const meta = document.getElementById('meta');
    meta.innerHTML = '';
    for (const [k, v] of Object.entries({
      started_at: s.started_at, finished_at: s.finished_at, error: s.last_error,
      ...(s.last_config || {}),
    })) {
      if (v == null || v === '') continue;
      const tr = document.createElement('tr');
      tr.appendChild(cell('b', k));
      tr.appendChild(cell('code', v));
      meta.appendChild(tr);
    }

    const logs = await fetch('/logs?tail=200').then(r => r.text());
    document.getElementById('logs').textContent = logs || '(no logs yet)';
  } catch (err) {
    // Transient network errors are ignored; the next poll retries.
    console.warn('refresh failed', err);
  }
}

async function startRun() {
  const r = await fetch('/train', {method: 'POST'});
  if (!r.ok) {
    const body = await r.json().catch(() => ({}));
    alert(body.detail || ('Could not start training (HTTP ' + r.status + ')'));
  }
  setTimeout(refresh, 500);
}

refresh();
setInterval(refresh, 5000);
</script>
</body>
</html>
"""
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def index() -> HTMLResponse:
    """Serve the static control-panel page (HTML with inline JS)."""
    page = _HTML
    return HTMLResponse(content=page)
|
|
|
|
|
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: a constant ``{"status": "ok"}`` payload."""
    return dict(status="ok")
|
|
|
|
|
@app.get("/status")
def status() -> JSONResponse:
    """Report the current run state (status, timestamps, last error/config)."""
    snapshot = STATE.to_dict()
    return JSONResponse(content=snapshot)
|
|
|
|
|
@app.get("/metrics")
def metrics() -> JSONResponse:
    """Return the metrics summary JSON from METRICS_FILE.

    Responses:

    - ``{"pre": null, "post": null}`` when no summary exists yet.
    - HTTP 500 ``{"error": "metrics file unreadable"}`` when the file exists
      but cannot be read, decoded, parsed, or re-encoded as strict JSON
      (e.g. it contains NaN, which JSONResponse refuses to render).
    """
    # EAFP: reading directly avoids the exists()/read race with a concurrent
    # writer or deleter.
    try:
        return JSONResponse(json.loads(METRICS_FILE.read_text(encoding="utf-8")))
    except FileNotFoundError:
        return JSONResponse({"pre": None, "post": None})
    except (OSError, ValueError):
        # OSError: permissions / not a regular file. ValueError covers
        # UnicodeDecodeError, JSONDecodeError and non-finite floats.
        logger.warning("could not load metrics from %s", METRICS_FILE, exc_info=True)
        return JSONResponse({"error": "metrics file unreadable"}, status_code=500)
|
|
|
|
|
@app.get("/logs", response_class=PlainTextResponse)
def logs(tail: int = 400) -> PlainTextResponse:
    """Return the last *tail* lines of the training log (at least one line).

    The pipeline only ever appends to LOG_FILE, so it grows across runs. The
    file is therefore streamed through a bounded deque instead of being read
    whole into memory on every poll. Undecodable bytes are replaced rather
    than failing the request. A missing log yields an empty body.
    """
    from collections import deque

    try:
        with open(LOG_FILE, encoding="utf-8", errors="replace") as fh:
            last_lines = deque(fh, maxlen=max(tail, 1))
    except FileNotFoundError:
        return PlainTextResponse("")
    return PlainTextResponse("\n".join(line.rstrip("\n") for line in last_lines))
|
|
|
|
|
@app.post("/train")
def train() -> JSONResponse:
    """Start a training run with a snapshot of the current CONFIG.

    Returns ``{"status": "started", "config": ...}`` with the exact config
    handed to the pipeline. Responds with HTTP 409 when ``_start_training``
    refuses (a run is already in progress).
    """
    config = dict(CONFIG)
    try:
        _start_training(config)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return JSONResponse({"status": "started", "config": config})
|
|
|
|
|
@app.on_event("startup")
def _maybe_autostart() -> None:
    """On server startup, launch a run if ``CONFIG["autostart"]`` is true.

    A refusal from ``_start_training`` (RuntimeError) is logged, not raised,
    so it never blocks the server from starting.
    """
    if not CONFIG["autostart"]:
        return
    try:
        _start_training(dict(CONFIG))
    except RuntimeError as exc:
        logger.warning("autostart skipped: %s", exc)
    else:
        logger.info("autostarted training run")
|
|
|