"""FastAPI control panel for the CERNenv trainer Space. Endpoints: GET / → status page (HTML) GET /status → JSON status of the current training run GET /metrics → JSON snapshot of reward / success rate GET /logs → tail of the training log POST /train → start (or restart) a training run GET /health → liveness probe Designed to run on a Hugging Face Space with `sdk: docker`. Heavy training work runs in a background thread so the HTTP server stays responsive. """ from __future__ import annotations import json import logging import os import subprocess import sys import threading import time from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, Optional from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger(__name__) def _resolve_repo_root() -> Path: env_root = os.environ.get("CERNENV_ROOT") candidates = [] if env_root: candidates.append(Path(env_root)) candidates.extend([ Path("/home/user/app"), Path(__file__).resolve().parent.parent.parent, ]) for p in candidates: try: if p.exists(): return p.resolve() except OSError: continue return candidates[-1].resolve() REPO_ROOT = _resolve_repo_root() LOG_DIR = REPO_ROOT / "training" / "runs" try: LOG_DIR.mkdir(parents=True, exist_ok=True) except OSError as exc: # pragma: no cover - read-only filesystem fallback logger.warning("could not create %s (%s); using /tmp", LOG_DIR, exc) LOG_DIR = Path("/tmp/cernenv-runs") LOG_DIR.mkdir(parents=True, exist_ok=True) LOG_FILE = LOG_DIR / "training.log" METRICS_FILE = REPO_ROOT / "training" / "plots" / "metrics_summary.json" def _env(name: str, default: str) -> str: return os.environ.get(name, default) CONFIG = { "model_name": _env("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct"), "difficulty": _env("DIFFICULTY", "easy"), "total_episodes": int(_env("TOTAL_EPISODES", "400")), "max_steps": int(_env("MAX_STEPS", "18")), "num_generations": int(_env("NUM_GENERATIONS", "4")), "output_dir": _env("OUTPUT_DIR", "training/runs/unsloth-grpo"), "hf_username": _env("HF_USERNAME", "YOUR_HF_USERNAME"), "push_repo": _env( "PUSH_REPO", f"{_env('HF_USERNAME', 'YOUR_HF_USERNAME')}/cernenv-grpo-qwen2.5-3b", ), "autostart": _env("AUTOSTART", "0") == "1", } # ── Run state ──────────────────────────────────────────────────────────── class RunState: def __init__(self) -> None: self.lock = threading.Lock() self.thread: Optional[threading.Thread] = None self.process: Optional[subprocess.Popen] = None self.status: str = "idle" # idle | running | finished | failed self.started_at: Optional[str] = None self.finished_at: Optional[str] = None self.last_error: Optional[str] = None self.last_config: Dict[str, Any] = {} def to_dict(self) -> Dict[str, Any]: with self.lock: return { "status": self.status, "started_at": self.started_at, "finished_at": self.finished_at, "last_error": self.last_error, "last_config": self.last_config, } STATE = RunState() # ── Training pipeline ──────────────────────────────────────────────────── def _stream_subprocess(cmd: list[str], log_handle) -> int: log_handle.write(f"\n$ {' '.join(cmd)}\n") log_handle.flush() proc = subprocess.Popen( cmd, cwd=str(REPO_ROOT), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=1, universal_newlines=True, env={**os.environ, "PYTHONPATH": str(REPO_ROOT)}, ) STATE.process = proc assert proc.stdout is not None for line in proc.stdout: log_handle.write(line) log_handle.flush() rc = proc.wait() log_handle.write(f"[exit code {rc}]\n") log_handle.flush() STATE.process = None return rc def _training_pipeline(config: Dict[str, Any]) -> None: started = datetime.now(timezone.utc).isoformat() with STATE.lock: STATE.status = "running" STATE.started_at = started STATE.finished_at = None STATE.last_error = None STATE.last_config = dict(config) LOG_FILE.parent.mkdir(parents=True, exist_ok=True) with open(LOG_FILE, "a") as log: log.write(f"\n=== Training started {started} ===\n") log.write(json.dumps(config, indent=2) + "\n") log.flush() try: output_dir = config["output_dir"] difficulty = config["difficulty"] max_steps = str(config["max_steps"]) episodes = str(config["total_episodes"]) num_gens = str(config["num_generations"]) model_name = config["model_name"] push_repo = config["push_repo"] eval_pre = "training/runs/eval_pre_train.jsonl" eval_post = "training/runs/eval_post_train.jsonl" plots_dir = "training/plots" log.write("\n--- baseline (heuristic / oracle / random) ---\n") log.flush() for agent in ("random", "heuristic", "oracle"): _stream_subprocess( [ sys.executable, "-m", "scripts.run_agent", "--agent", agent, "--difficulty", difficulty, "--episodes", "3", "--quiet", ], log, ) log.write("\n--- pre-train evaluation ---\n") log.flush() rc = _stream_subprocess( [ sys.executable, "-m", "training.evaluate", "--model_name", model_name, "--difficulty", difficulty, "--episodes", "16", "--max_steps", max_steps, "--tag", "pre_train", "--out", eval_pre, ], log, ) if rc != 0: raise RuntimeError(f"pre-train eval failed (rc={rc})") log.write("\n--- GRPO training ---\n") log.flush() rc = _stream_subprocess( [ sys.executable, "-m", "training.training_unsloth", "--model_name", model_name, "--difficulty", difficulty, "--total_episodes", episodes, "--max_steps", max_steps, "--num_generations", num_gens, "--output_dir", output_dir, ], log, ) if rc != 0: raise RuntimeError(f"training failed (rc={rc})") log.write("\n--- post-train evaluation ---\n") log.flush() rc = _stream_subprocess( [ sys.executable, "-m", "training.evaluate", "--model_name", model_name, "--adapter_dir", output_dir, "--difficulty", difficulty, "--episodes", "16", "--max_steps", max_steps, "--tag", "post_train", "--out", eval_post, ], log, ) if rc != 0: raise RuntimeError(f"post-train eval failed (rc={rc})") log.write("\n--- plots ---\n") log.flush() _stream_subprocess( [ sys.executable, "-m", "training.plots", "--pre", eval_pre, "--post", eval_post, "--out_dir", plots_dir, ], log, ) if os.environ.get("HF_TOKEN"): log.write("\n--- push adapters to Hub ---\n") log.flush() _stream_subprocess( [ sys.executable, "-m", "scripts.push_to_hub", "model", "--adapter_dir", output_dir, "--repo_id", push_repo, "--base_model", model_name, ], log, ) else: log.write("\n[skip] HF_TOKEN not set — not pushing to Hub\n") log.flush() with STATE.lock: STATE.status = "finished" except Exception as exc: logger.exception("training pipeline failed") with STATE.lock: STATE.status = "failed" STATE.last_error = str(exc) finally: finished = datetime.now(timezone.utc).isoformat() log.write(f"\n=== Training ended {finished} ===\n") log.flush() with STATE.lock: STATE.finished_at = finished def _start_training(config: Dict[str, Any]) -> None: with STATE.lock: if STATE.status == "running": raise RuntimeError("a training run is already in progress") STATE.thread = threading.Thread( target=_training_pipeline, args=(config,), name="cernenv-trainer", daemon=True, ) STATE.thread.start() # ── FastAPI app ────────────────────────────────────────────────────────── app = FastAPI(title="CERNenv Trainer", version="0.1.0") _HTML = """\ CERNenv Trainer

⚛️ CERNenv Trainer

GRPO + Unsloth + LoRA on the CERNenv LHC discovery environment.

Status: ?

Logs (tail)

loading…
""" @app.get("/", response_class=HTMLResponse) def index() -> HTMLResponse: return HTMLResponse(_HTML) @app.get("/health") def health() -> Dict[str, str]: return {"status": "ok"} @app.get("/status") def status() -> JSONResponse: return JSONResponse(STATE.to_dict()) @app.get("/metrics") def metrics() -> JSONResponse: if METRICS_FILE.exists(): try: return JSONResponse(json.loads(METRICS_FILE.read_text())) except Exception: return JSONResponse({"error": "metrics file unreadable"}, status_code=500) return JSONResponse({"pre": None, "post": None}) @app.get("/logs", response_class=PlainTextResponse) def logs(tail: int = 400) -> PlainTextResponse: if not LOG_FILE.exists(): return PlainTextResponse("") text = LOG_FILE.read_text() lines = text.splitlines() return PlainTextResponse("\n".join(lines[-max(tail, 1):])) @app.post("/train") def train() -> JSONResponse: try: _start_training(dict(CONFIG)) except RuntimeError as exc: raise HTTPException(status_code=409, detail=str(exc)) return JSONResponse({"status": "started", "config": CONFIG}) @app.on_event("startup") def _maybe_autostart() -> None: if CONFIG["autostart"]: try: _start_training(dict(CONFIG)) logger.info("autostarted training run") except RuntimeError as exc: logger.warning("autostart skipped: %s", exc)