cernenv / space /training /app.py
anugrah55's picture
Update CERNenv Space
7103888 verified
"""FastAPI control panel for the CERNenv trainer Space.
Endpoints:
GET / → status page (HTML)
GET /status → JSON status of the current training run
GET /metrics → JSON snapshot of reward / success rate
GET /logs → tail of the training log
POST /train → start a new training run (HTTP 409 if one is already running)
GET /health → liveness probe
Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training
work runs in a background thread so the HTTP server stays responsive.
"""
from __future__ import annotations
import json
import logging
import os
import subprocess
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module logger for the control panel and its background trainer thread.
logger = logging.getLogger(__name__)
def _resolve_repo_root() -> Path:
    """Locate the CERNenv repository root.

    Candidates are probed in order:

    1. ``$CERNENV_ROOT``, when set to a non-empty value;
    2. the hard-coded ``/home/user/app`` (presumably the Space's working dir);
    3. the directory three levels above this file.

    The first candidate that exists is returned resolved. Candidates whose
    probe raises ``OSError`` are skipped. If nothing matches, the last
    candidate is returned resolved anyway.
    """
    override = os.environ.get("CERNENV_ROOT")
    ordered = [Path(override)] if override else []
    ordered += [
        Path("/home/user/app"),
        Path(__file__).resolve().parent.parent.parent,
    ]

    def _probe(candidate: Path) -> Optional[Path]:
        # Permission problems and the like are treated as "not usable".
        try:
            return candidate.resolve() if candidate.exists() else None
        except OSError:
            return None

    hit = next((found for found in map(_probe, ordered) if found is not None), None)
    return hit if hit is not None else ordered[-1].resolve()
REPO_ROOT = _resolve_repo_root()

# Training logs live under <repo>/training/runs. If that directory cannot be
# created (e.g. a read-only filesystem), fall back to a scratch dir in /tmp.
LOG_DIR = REPO_ROOT.joinpath("training", "runs")
try:
    LOG_DIR.mkdir(parents=True, exist_ok=True)
except OSError as exc:  # pragma: no cover - read-only filesystem fallback
    logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc)
    LOG_DIR = Path("/tmp/cernenv-runs")
    LOG_DIR.mkdir(parents=True, exist_ok=True)

LOG_FILE = LOG_DIR.joinpath("training.log")
# Always resolved against the repo, even when LOG_DIR fell back to /tmp.
METRICS_FILE = REPO_ROOT.joinpath("training", "plots", "metrics_summary.json")
def _env(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset.

    A variable that is set to the empty string counts as set and is
    returned as ``""``.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value
# Training configuration, read once at import time from environment variables.
# Integer fields go through int(), so a non-numeric value raises ValueError
# when this module is imported.
CONFIG = dict(
    model_name=_env("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct"),
    difficulty=_env("DIFFICULTY", "easy"),
    total_episodes=int(_env("TOTAL_EPISODES", "400")),
    max_steps=int(_env("MAX_STEPS", "18")),
    num_generations=int(_env("NUM_GENERATIONS", "4")),
    output_dir=_env("OUTPUT_DIR", "training/runs/unsloth-grpo"),
    hf_username=_env("HF_USERNAME", "YOUR_HF_USERNAME"),
    # Unless PUSH_REPO is given, the Hub repo is derived from HF_USERNAME.
    push_repo=_env(
        "PUSH_REPO",
        _env("HF_USERNAME", "YOUR_HF_USERNAME") + "/cernenv-grpo-qwen2.5-3b",
    ),
    # Only the exact string "1" enables autostart.
    autostart=_env("AUTOSTART", "0") == "1",
)
# ── Run state ────────────────────────────────────────────────────────────
class RunState:
    """Lock-protected record of the background training run.

    The trainer thread writes it and the HTTP handlers read it. ``lock``
    guards the status, timestamp, error and config fields.
    """

    def __init__(self) -> None:
        self.lock = threading.Lock()
        # Handles to the worker thread and the child process it is streaming.
        self.thread: Optional[threading.Thread] = None
        self.process: Optional[subprocess.Popen] = None
        # Lifecycle: idle | running | finished | failed
        self.status: str = "idle"
        # ISO-8601 timestamps of the current / most recent run.
        self.started_at: Optional[str] = None
        self.finished_at: Optional[str] = None
        self.last_error: Optional[str] = None
        self.last_config: Dict[str, Any] = {}

    def to_dict(self) -> Dict[str, Any]:
        """Return the public run fields as a JSON-ready dict, read under ``lock``."""
        exported = ("status", "started_at", "finished_at", "last_error", "last_config")
        with self.lock:
            return {field: getattr(self, field) for field in exported}


# Process-wide singleton shared by the trainer thread and the endpoints.
STATE = RunState()
# ── Training pipeline ────────────────────────────────────────────────────
def _stream_subprocess(cmd: list[str], log_handle) -> int:
    """Run *cmd* from ``REPO_ROOT`` and copy its merged stdout/stderr into *log_handle*.

    The command line is echoed first, and ``[exit code N]`` is appended when
    the child exits. While the child runs, its ``Popen`` handle is exposed as
    ``STATE.process``. It is cleared again even if streaming fails, and in that
    case the child is killed rather than left running.

    Args:
        cmd: argv list, executed without a shell.
        log_handle: writable text file object supporting ``write``/``flush``.

    Returns:
        The child's exit code.
    """
    log_handle.write(f"\n$ {' '.join(cmd)}\n")
    log_handle.flush()

    # Prepend the repo to any existing PYTHONPATH instead of replacing it, so
    # import paths configured by the container image keep working.
    inherited = os.environ.get("PYTHONPATH")
    env = {
        **os.environ,
        "PYTHONPATH": os.pathsep.join(filter(None, (str(REPO_ROOT), inherited))),
    }
    proc = subprocess.Popen(
        cmd,
        cwd=str(REPO_ROOT),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        bufsize=1,  # line-buffered so lines reach the log as they are produced
        text=True,
        # Decode defensively: child output is not guaranteed to be valid in
        # the container's locale, and a UnicodeDecodeError would abort the run.
        encoding="utf-8",
        errors="replace",
        env=env,
    )
    STATE.process = proc
    try:
        assert proc.stdout is not None  # guaranteed by stdout=PIPE
        with proc.stdout:
            for line in proc.stdout:
                log_handle.write(line)
                log_handle.flush()
        rc = proc.wait()
    except BaseException:
        # Streaming failed (e.g. the log became unwritable). Don't leave an
        # orphaned child blocked on a full pipe.
        proc.kill()
        proc.wait()
        raise
    finally:
        STATE.process = None
    log_handle.write(f"[exit code {rc}]\n")
    log_handle.flush()
    return rc
def _training_pipeline(config: Dict[str, Any]) -> None:
    """Run the full training pipeline for *config* and record the outcome on ``STATE``.

    Intended to run on the background trainer thread. Stage output is
    appended to ``LOG_FILE`` between "Training started" and "Training ended"
    banners.

    On return, ``STATE.status`` is one of:

    - ``finished``: every required stage succeeded;
    - ``failed``: ``STATE.last_error`` holds the message of any exception
      that was caught.

    ``STATE.finished_at`` is always set, including when the log file itself
    cannot be created or opened.
    """
    started = datetime.now(timezone.utc).isoformat()
    with STATE.lock:
        STATE.status = "running"
        STATE.started_at = started
        STATE.finished_at = None
        STATE.last_error = None
        STATE.last_config = dict(config)

    outcome = "failed"
    error: Optional[str] = None
    ended: Optional[str] = None
    try:
        LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: the log holds non-ASCII text (e.g. the em dash in
        # the HF_TOKEN skip note) whatever the container's locale is.
        with open(LOG_FILE, "a", encoding="utf-8") as log:
            log.write(f"\n=== Training started {started} ===\n")
            log.write(json.dumps(config, indent=2) + "\n")
            log.flush()
            try:
                _run_pipeline_steps(config, log)
                outcome = "finished"
            finally:
                ended = datetime.now(timezone.utc).isoformat()
                log.write(f"\n=== Training ended {ended} ===\n")
                log.flush()
    except Exception as exc:
        logger.exception("training pipeline failed")
        outcome, error = "failed", str(exc)
    finally:
        # This runs even if the log could not be opened. Without it, status
        # would stay "running" forever and every later /train would get 409.
        # Status and finished_at are also published together under one lock.
        with STATE.lock:
            STATE.status = outcome
            STATE.last_error = error
            STATE.finished_at = ended or datetime.now(timezone.utc).isoformat()


def _run_pipeline_steps(config: Dict[str, Any], log) -> None:
    """Run the baseline → eval → GRPO → eval → plots → push stages in order.

    Each stage is a ``python -m`` subprocess streamed into *log*.

    - Required stages (pre/post evaluation and GRPO training): a non-zero
      exit code raises ``RuntimeError``.
    - Best-effort stages (baselines, plots, Hub push): a failure is logged
      and the pipeline continues.
    """
    # Read every key up front so a malformed config fails before any work.
    output_dir = config["output_dir"]
    difficulty = config["difficulty"]
    max_steps = str(config["max_steps"])
    episodes = str(config["total_episodes"])
    num_gens = str(config["num_generations"])
    model_name = config["model_name"]
    push_repo = config["push_repo"]
    eval_pre = "training/runs/eval_pre_train.jsonl"
    eval_post = "training/runs/eval_post_train.jsonl"
    plots_dir = "training/plots"

    def section(title: str) -> None:
        # One "--- title ---" banner per stage in the log.
        log.write(f"\n--- {title} ---\n")
        log.flush()

    def required(what: str, args: list[str]) -> None:
        # Abort the pipeline when a mandatory stage fails.
        rc = _stream_subprocess([sys.executable, "-m", *args], log)
        if rc != 0:
            raise RuntimeError(f"{what} failed (rc={rc})")

    def best_effort(what: str, args: list[str]) -> None:
        # Record, but tolerate, failure of an optional stage.
        rc = _stream_subprocess([sys.executable, "-m", *args], log)
        if rc != 0:
            logger.warning("%s exited with rc=%s; continuing", what, rc)

    section("baseline (heuristic / oracle / random)")
    for agent in ("random", "heuristic", "oracle"):
        best_effort(f"baseline {agent}", [
            "scripts.run_agent",
            "--agent", agent, "--difficulty", difficulty,
            "--episodes", "3", "--quiet",
        ])

    section("pre-train evaluation")
    required("pre-train eval", [
        "training.evaluate",
        "--model_name", model_name,
        "--difficulty", difficulty,
        "--episodes", "16",
        "--max_steps", max_steps,
        "--tag", "pre_train",
        "--out", eval_pre,
    ])

    section("GRPO training")
    required("training", [
        "training.training_unsloth",
        "--model_name", model_name,
        "--difficulty", difficulty,
        "--total_episodes", episodes,
        "--max_steps", max_steps,
        "--num_generations", num_gens,
        "--output_dir", output_dir,
    ])

    section("post-train evaluation")
    required("post-train eval", [
        "training.evaluate",
        "--model_name", model_name,
        "--adapter_dir", output_dir,
        "--difficulty", difficulty,
        "--episodes", "16",
        "--max_steps", max_steps,
        "--tag", "post_train",
        "--out", eval_post,
    ])

    section("plots")
    best_effort("plots", [
        "training.plots",
        "--pre", eval_pre,
        "--post", eval_post,
        "--out_dir", plots_dir,
    ])

    if os.environ.get("HF_TOKEN"):
        section("push adapters to Hub")
        best_effort("push to Hub", [
            "scripts.push_to_hub", "model",
            "--adapter_dir", output_dir,
            "--repo_id", push_repo,
            "--base_model", model_name,
        ])
    else:
        log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n")
        log.flush()
def _start_training(config: Dict[str, Any]) -> None:
    """Launch ``_training_pipeline(config)`` on a daemon thread.

    Raises:
        RuntimeError: if a training run is already in progress.
    """
    with STATE.lock:
        # The worker only sets status to "running" once it is scheduled and
        # gets the lock. Check the thread itself as well; otherwise two quick
        # /train calls could both pass the status check and start
        # concurrent runs.
        busy = STATE.status == "running" or (
            STATE.thread is not None and STATE.thread.is_alive()
        )
        if busy:
            raise RuntimeError("a training run is already in progress")
        worker = threading.Thread(
            target=_training_pipeline,
            args=(config,),
            name="cernenv-trainer",
            daemon=True,
        )
        # Starting while holding the lock is safe: Thread.start() returns once
        # the thread is alive, and the worker then waits on STATE.lock.
        worker.start()
        STATE.thread = worker
# ── FastAPI app ──────────────────────────────────────────────────────────
# ASGI application exposing the control-panel endpoints.
app = FastAPI(
    title="CERNenv Trainer",
    version="0.1.0",
)
# Single-page control panel served at "/". It polls /status and /logs every
# 5 s and posts to /train. Values from the server are inserted via
# textContent only, so error messages or config strings can never inject HTML.
_HTML = """\
<!doctype html>
<html lang=en>
<head>
<meta charset=utf-8>
<title>CERNenv Trainer</title>
<style>
body { font-family: ui-sans-serif, system-ui, sans-serif; margin: 2rem auto; max-width: 760px; color:#111 }
h1 { margin-bottom: 0 }
.muted { color:#666 }
pre { background:#0e1116; color:#e6edf3; padding:1rem; border-radius:6px; overflow-x:auto; max-height:50vh }
button { font-size:1rem; padding:.6rem 1rem; border-radius:6px; border:1px solid #888; background:#fff; cursor:pointer }
.pill { display:inline-block; padding:.1rem .5rem; border-radius:999px; background:#eef; color:#225 }
.ok { background:#dfd; color:#272 }
.fail { background:#fdd; color:#822 }
.run { background:#fdf6d8; color:#774 }
table { border-collapse:collapse; }
td { padding:.2rem .8rem .2rem 0; }
</style>
</head>
<body>
<h1>⚛️ CERNenv Trainer</h1>
<p class=muted>GRPO + Unsloth + LoRA on the CERNenv LHC discovery environment.</p>
<h3>Status: <span id=status class=pill>?</span></h3>
<table id=meta></table>
<p>
<button onclick="startRun()">▶ Start training</button>
<button onclick="refresh()">↻ Refresh</button>
</p>
<h3>Logs (tail)</h3>
<pre id=logs>loading…</pre>
<script>
// Append a "<b>key</b> | <code>value</code>" row. Text is set via
// textContent so server-provided values are never parsed as HTML.
function addRow(table, key, value) {
  const tr = document.createElement('tr');
  const keyCell = document.createElement('td');
  const keyText = document.createElement('b');
  keyText.textContent = key;
  keyCell.appendChild(keyText);
  const valueCell = document.createElement('td');
  const valueText = document.createElement('code');
  valueText.textContent = String(value);
  valueCell.appendChild(valueText);
  tr.appendChild(keyCell);
  tr.appendChild(valueCell);
  table.appendChild(tr);
}
async function refresh() {
  try {
    const s = await fetch('/status').then(r => r.json());
    const pill = document.getElementById('status');
    pill.textContent = s.status;
    pill.className = 'pill ' + ({idle:'',running:'run',finished:'ok',failed:'fail'}[s.status] || '');
    const meta = document.getElementById('meta');
    meta.innerHTML = '';
    for (const [k, v] of Object.entries({
      started_at: s.started_at, finished_at: s.finished_at, error: s.last_error,
      ...(s.last_config || {}),
    })) {
      if (v == null || v === '') continue;
      addRow(meta, k, v);
    }
    const logs = await fetch('/logs?tail=200').then(r => r.text());
    document.getElementById('logs').textContent = logs || '(no logs yet)';
  } catch (err) {
    console.error('refresh failed', err);
  }
}
async function startRun() {
  try {
    const r = await fetch('/train', {method:'POST'});
    if (!r.ok) {
      const body = await r.json().catch(() => ({}));
      alert(body.detail || ('Could not start training (HTTP ' + r.status + ')'));
    }
  } catch (err) {
    alert('Could not reach the server: ' + err);
  }
  setTimeout(refresh, 500);
}
refresh();
setInterval(refresh, 5000);
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
def index() -> HTMLResponse:
    """Serve the static control-panel page."""
    page = _HTML
    return HTMLResponse(content=page)
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always reports ``{"status": "ok"}``."""
    return dict(status="ok")
@app.get("/status")
def status() -> JSONResponse:
    """Return a JSON snapshot of the current run state (``RunState.to_dict``)."""
    snapshot = STATE.to_dict()
    return JSONResponse(content=snapshot)
@app.get("/metrics")
def metrics() -> JSONResponse:
    """Return the contents of ``METRICS_FILE`` as JSON.

    Returns ``{"pre": null, "post": null}`` when the file does not exist yet.
    Responds 500 with ``{"error": "metrics file unreadable"}`` when it exists
    but cannot be read, decoded or re-serialised. The cause is logged.
    """
    try:
        payload = json.loads(METRICS_FILE.read_text(encoding="utf-8"))
        # Built inside the try: JSONResponse renders immediately and rejects
        # NaN/Infinity floats (which json.loads accepts) with ValueError.
        return JSONResponse(payload)
    except FileNotFoundError:
        return JSONResponse({"pre": None, "post": None})
    except (OSError, ValueError):
        logger.exception("could not load metrics from %s", METRICS_FILE)
        return JSONResponse({"error": "metrics file unreadable"}, status_code=500)
@app.get("/logs", response_class=PlainTextResponse)
def logs(tail: int = 400) -> PlainTextResponse:
    """Return the last *tail* lines of the training log as plain text.

    ``tail`` is clamped to at least 1, and the body is empty when no log
    exists yet. The file is streamed through a bounded deque, so memory grows
    with *tail* rather than with the append-only log, which accumulates every
    run.
    """
    from collections import deque

    keep = max(tail, 1)
    try:
        # errors="replace": a multi-byte character that the trainer thread is
        # still writing must not turn this endpoint into a 500.
        with open(LOG_FILE, encoding="utf-8", errors="replace") as fh:
            last_lines = deque((line.rstrip("\n") for line in fh), maxlen=keep)
    except FileNotFoundError:
        return PlainTextResponse("")
    return PlainTextResponse("\n".join(last_lines))
@app.post("/train")
def train() -> JSONResponse:
    """Start a training run with a snapshot of ``CONFIG``.

    Returns ``{"status": "started", "config": ...}`` with the exact config the
    run was started with. Responds 409 when a run is already in progress.
    """
    config = dict(CONFIG)
    try:
        _start_training(config)
    except RuntimeError as exc:
        # Chain the cause so the original error stays visible in tracebacks.
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return JSONResponse({"status": "started", "config": config})
@app.on_event("startup")
def _maybe_autostart() -> None:
    """On server startup, launch a training run if ``CONFIG["autostart"]`` is set.

    A run that is already in progress is not an error; it is logged as a
    warning and skipped.
    """
    if not CONFIG["autostart"]:
        return
    try:
        _start_training(dict(CONFIG))
    except RuntimeError as exc:
        logger.warning("autostart skipped: %s", exc)
    else:
        logger.info("autostarted training run")