Spaces:
Paused
Paused
sft+reward-fix: space/training/app.py
Browse files- space/training/app.py +130 -5
space/training/app.py
CHANGED
|
@@ -28,7 +28,7 @@ from datetime import datetime, timezone
|
|
| 28 |
from pathlib import Path
|
| 29 |
from typing import Any, Dict, List, Optional
|
| 30 |
|
| 31 |
-
from fastapi import FastAPI, HTTPException
|
| 32 |
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, Response
|
| 33 |
from fastapi.staticfiles import StaticFiles
|
| 34 |
|
|
@@ -97,6 +97,10 @@ def _detect_gpus() -> int:
|
|
| 97 |
_NUM_GPUS = _detect_gpus()
|
| 98 |
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
CONFIG = {
|
| 101 |
"training_backend": _env("TRAINING_BACKEND", "vanilla"),
|
| 102 |
"model_name": _env("MODEL_NAME", "HuggingFaceTB/SmolLM2-360M-Instruct"),
|
|
@@ -119,6 +123,15 @@ CONFIG = {
|
|
| 119 |
f"{_env('HF_USERNAME', 'anugrahhu')}/cernenv-grpo-smollm2-360m",
|
| 120 |
),
|
| 121 |
"autostart": _env("AUTOSTART", "0") == "1",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
}
|
| 123 |
|
| 124 |
|
|
@@ -177,8 +190,35 @@ def _stream_subprocess(cmd: list[str], log_handle) -> int:
|
|
| 177 |
return rc
|
| 178 |
|
| 179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
|
| 181 |
-
"""Compose the selected training launcher.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
backend = str(config.get("training_backend", "vanilla")).lower()
|
| 183 |
if backend == "vanilla":
|
| 184 |
python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
|
|
@@ -360,6 +400,31 @@ def _training_pipeline(config: Dict[str, Any]) -> None:
|
|
| 360 |
log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
|
| 361 |
log.flush()
|
| 362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
log.write(f"\n--- GRPO training ({backend}, {config['num_gpus']} GPU process(es)) ---\n")
|
| 364 |
log.flush()
|
| 365 |
rc = _stream_subprocess(_build_training_cmd(config), log)
|
|
@@ -813,6 +878,9 @@ _HTML = """\
|
|
| 813 |
<img id=dist src="/evidence/reward_distribution.png" onerror="this.style.display='none'">
|
| 814 |
<div id=dist_missing class=muted style="display:none">(generated after post-train eval)</div>
|
| 815 |
</div>
|
|
|
|
|
|
|
|
|
|
| 816 |
</div>
|
| 817 |
|
| 818 |
<h2>Before / after metrics</h2>
|
|
@@ -888,6 +956,27 @@ async function refresh() {
|
|
| 888 |
probe.src = baseSrc + bust;
|
| 889 |
}
|
| 890 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 891 |
const logs = await fetch('/logs?tail=200').then(r => r.text());
|
| 892 |
document.getElementById('logs').textContent = logs || '(no logs yet)';
|
| 893 |
}
|
|
@@ -929,6 +1018,23 @@ def metrics() -> JSONResponse:
|
|
| 929 |
return JSONResponse({"pre": None, "post": None, "delta": None})
|
| 930 |
|
| 931 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 932 |
@app.get("/evidence")
|
| 933 |
def evidence_index() -> JSONResponse:
|
| 934 |
"""List every evidence artifact currently on disk."""
|
|
@@ -986,12 +1092,31 @@ def logs(tail: int = 400) -> PlainTextResponse:
|
|
| 986 |
|
| 987 |
|
| 988 |
@app.post("/train")
|
| 989 |
-
def train() -> JSONResponse:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 990 |
try:
|
| 991 |
-
_start_training(
|
| 992 |
except RuntimeError as exc:
|
| 993 |
raise HTTPException(status_code=409, detail=str(exc))
|
| 994 |
-
return JSONResponse({"status": "started", "config":
|
| 995 |
|
| 996 |
|
| 997 |
@app.on_event("startup")
|
|
|
|
| 28 |
from pathlib import Path
|
| 29 |
from typing import Any, Dict, List, Optional
|
| 30 |
|
| 31 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 32 |
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, Response
|
| 33 |
from fastapi.staticfiles import StaticFiles
|
| 34 |
|
|
|
|
| 97 |
_NUM_GPUS = _detect_gpus()
|
| 98 |
|
| 99 |
|
| 100 |
+
def _bool_env(name: str, default: str) -> bool:
    """Interpret the env var *name* (defaulting to *default*) as a boolean.

    Truthy spellings, compared case-insensitively after stripping
    surrounding whitespace: "1", "true", "yes", "on". Everything else
    is falsy.
    """
    raw_value = _env(name, default)
    normalized = raw_value.strip().lower()
    return normalized in {"1", "true", "yes", "on"}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
CONFIG = {
|
| 105 |
"training_backend": _env("TRAINING_BACKEND", "vanilla"),
|
| 106 |
"model_name": _env("MODEL_NAME", "HuggingFaceTB/SmolLM2-360M-Instruct"),
|
|
|
|
| 123 |
f"{_env('HF_USERNAME', 'anugrahhu')}/cernenv-grpo-smollm2-360m",
|
| 124 |
),
|
| 125 |
"autostart": _env("AUTOSTART", "0") == "1",
|
| 126 |
+
# ── SFT warm-start phase (defeats v1's claim-avoidance reward hack
|
| 127 |
+
# by giving GRPO a non-zero prior over correct trajectories) ─────
|
| 128 |
+
"sft_warmstart": _bool_env("SFT_WARMSTART", "false"),
|
| 129 |
+
"sft_num_episodes": int(_env("SFT_NUM_EPISODES", "200")),
|
| 130 |
+
"sft_max_steps": int(_env("SFT_MAX_STEPS", "8")),
|
| 131 |
+
"sft_epochs": int(_env("SFT_EPOCHS", "1")),
|
| 132 |
+
"sft_lr": float(_env("SFT_LR", "1e-5")),
|
| 133 |
+
"sft_difficulty": _env("SFT_DIFFICULTY", "mixed"),
|
| 134 |
+
"sft_out_dir": _env("SFT_OUT_DIR", "runs/sft-warmstart"),
|
| 135 |
}
|
| 136 |
|
| 137 |
|
|
|
|
| 190 |
return rc
|
| 191 |
|
| 192 |
|
| 193 |
+
def _build_sft_warmstart_cmd(config: Dict[str, Any]) -> list[str]:
|
| 194 |
+
"""Compose the SFT-warm-start subprocess command.
|
| 195 |
+
|
| 196 |
+
Always uses the system Python so GRPO + SFT share the same
|
| 197 |
+
transformers + trl pin in space/training/requirements.txt.
|
| 198 |
+
"""
|
| 199 |
+
python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
|
| 200 |
+
return [
|
| 201 |
+
python_bin, "-m", "training.sft_warmstart",
|
| 202 |
+
"--out_dir", config["sft_out_dir"],
|
| 203 |
+
"--num_episodes", str(config["sft_num_episodes"]),
|
| 204 |
+
"--max_steps", str(config["sft_max_steps"]),
|
| 205 |
+
"--epochs", str(config["sft_epochs"]),
|
| 206 |
+
"--lr", str(config["sft_lr"]),
|
| 207 |
+
"--base_model", config["model_name"],
|
| 208 |
+
"--difficulty", config["sft_difficulty"],
|
| 209 |
+
"--evidence_dir", config["evidence_dir"],
|
| 210 |
+
]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
|
| 214 |
+
"""Compose the selected training launcher.
|
| 215 |
+
|
| 216 |
+
When ``sft_warmstart`` is on, ``model_name`` is expected to already
|
| 217 |
+
have been overwritten with the SFT output directory by the caller
|
| 218 |
+
(``_training_pipeline``), so this function never has to know about
|
| 219 |
+
the SFT phase explicitly — it just trains GRPO from whatever path
|
| 220 |
+
is sitting in ``model_name``.
|
| 221 |
+
"""
|
| 222 |
backend = str(config.get("training_backend", "vanilla")).lower()
|
| 223 |
if backend == "vanilla":
|
| 224 |
python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
|
|
|
|
| 400 |
log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
|
| 401 |
log.flush()
|
| 402 |
|
| 403 |
+
if config.get("sft_warmstart"):
|
| 404 |
+
# Phase 1 — SFT warm-start. Produces a *full* causal-LM
|
| 405 |
+
# checkpoint at config['sft_out_dir'] (LoRA adapters are
|
| 406 |
+
# merged in by training/sft_warmstart.py) so we can hand
|
| 407 |
+
# it to GRPO as a drop-in --model_name.
|
| 408 |
+
sft_out = config["sft_out_dir"]
|
| 409 |
+
log.write(
|
| 410 |
+
f"\n--- SFT warm-start ({config['sft_num_episodes']} oracle "
|
| 411 |
+
f"episodes, epochs={config['sft_epochs']}, → {sft_out}) ---\n"
|
| 412 |
+
)
|
| 413 |
+
log.flush()
|
| 414 |
+
sft_rc = _stream_subprocess(_build_sft_warmstart_cmd(config), log)
|
| 415 |
+
if sft_rc != 0:
|
| 416 |
+
raise RuntimeError(f"SFT warm-start failed (rc={sft_rc})")
|
| 417 |
+
log.write(
|
| 418 |
+
f"\n[ok] SFT done; switching GRPO base model "
|
| 419 |
+
f"{config['model_name']} → {sft_out}\n"
|
| 420 |
+
)
|
| 421 |
+
log.flush()
|
| 422 |
+
config["model_name"] = sft_out
|
| 423 |
+
# Keep the *base* HF id around for evaluator commands —
|
| 424 |
+
# tokenizer files in the SFT directory are saved by the
|
| 425 |
+
# SFT script, but evaluation will load from this dir
|
| 426 |
+
# directly, so no further path bookkeeping is required.
|
| 427 |
+
|
| 428 |
log.write(f"\n--- GRPO training ({backend}, {config['num_gpus']} GPU process(es)) ---\n")
|
| 429 |
log.flush()
|
| 430 |
rc = _stream_subprocess(_build_training_cmd(config), log)
|
|
|
|
| 878 |
<img id=dist src="/evidence/reward_distribution.png" onerror="this.style.display='none'">
|
| 879 |
<div id=dist_missing class=muted style="display:none">(generated after post-train eval)</div>
|
| 880 |
</div>
|
| 881 |
+
<div class=card><b>Warm-start (SFT)</b><br>
|
| 882 |
+
<div id=sft_card class=muted>(SFT_WARMSTART disabled — set the env var to enable)</div>
|
| 883 |
+
</div>
|
| 884 |
</div>
|
| 885 |
|
| 886 |
<h2>Before / after metrics</h2>
|
|
|
|
| 956 |
probe.src = baseSrc + bust;
|
| 957 |
}
|
| 958 |
|
| 959 |
+
// SFT warm-start card. /sft_summary returns 404 until the SFT phase
|
| 960 |
+
// has written evidence/sft_summary.json — when it does, render the
|
| 961 |
+
// headline numbers (final loss, oracle success rate, duration) so a
|
| 962 |
+
// reviewer can sanity-check the warm-start at a glance.
|
| 963 |
+
const sft_resp = await fetch('/sft_summary');
|
| 964 |
+
const sft_card = document.getElementById('sft_card');
|
| 965 |
+
if (sft_resp.ok) {
|
| 966 |
+
try {
|
| 967 |
+
const sft = await sft_resp.json();
|
| 968 |
+
sft_card.classList.remove('muted');
|
| 969 |
+
sft_card.innerHTML =
|
| 970 |
+
`<table>` +
|
| 971 |
+
`<tr><td><b>final loss</b></td><td><code>${fmt(sft.final_loss)}</code></td></tr>` +
|
| 972 |
+
`<tr><td><b>oracle success</b></td><td><code>${fmt(sft.oracle_success_rate)}</code></td></tr>` +
|
| 973 |
+
`<tr><td><b>transitions trained</b></td><td><code>${sft.num_train_rows ?? '–'}</code></td></tr>` +
|
| 974 |
+
`<tr><td><b>duration</b></td><td><code>${fmt(sft.duration_s)} s</code></td></tr>` +
|
| 975 |
+
`<tr><td><b>base → SFT dir</b></td><td><code>${sft.base_model} → ${sft.out_dir}</code></td></tr>` +
|
| 976 |
+
`</table>`;
|
| 977 |
+
} catch (e) { /* keep placeholder */ }
|
| 978 |
+
}
|
| 979 |
+
|
| 980 |
const logs = await fetch('/logs?tail=200').then(r => r.text());
|
| 981 |
document.getElementById('logs').textContent = logs || '(no logs yet)';
|
| 982 |
}
|
|
|
|
| 1018 |
return JSONResponse({"pre": None, "post": None, "delta": None})
|
| 1019 |
|
| 1020 |
|
| 1021 |
+
@app.get("/sft_summary")
def sft_summary() -> JSONResponse:
    """Return the SFT warm-start summary if it exists.

    Powers the dashboard's "Warm-start (SFT)" card: shows the final
    training loss, oracle success rate, and wall-clock duration once
    the SFT phase has written ``evidence/sft_summary.json``.

    Returns 404 while the file is absent, 500 if it exists but cannot
    be read or parsed.
    """
    path = EVIDENCE_DIR / "sft_summary.json"
    # EAFP: attempt the read and handle the miss, rather than an
    # exists() pre-check that races with the SFT subprocess writing
    # the file. Exceptions are narrowed so genuine bugs still surface.
    try:
        return JSONResponse(json.loads(path.read_text()))
    except FileNotFoundError:
        return JSONResponse({}, status_code=404)
    except (OSError, json.JSONDecodeError, UnicodeDecodeError):
        return JSONResponse({"error": "sft_summary unreadable"}, status_code=500)
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
@app.get("/evidence")
|
| 1039 |
def evidence_index() -> JSONResponse:
|
| 1040 |
"""List every evidence artifact currently on disk."""
|
|
|
|
| 1092 |
|
| 1093 |
|
| 1094 |
@app.post("/train")
async def train(request: Request) -> JSONResponse:
    """Start a training run.

    The request body (JSON) is merged into the global ``CONFIG`` for
    *this* run only, so future API-only triggers can flip
    ``sft_warmstart`` (or any other config key) without redeploying
    the Space. Unknown keys are accepted as-is — type coercion is the
    caller's responsibility.

    Raises HTTP 400 on a malformed body and HTTP 409 when a run is
    already in progress (``_start_training`` signals via RuntimeError).
    """
    overrides: Dict[str, Any] = {}
    body = await request.body()
    if body:
        # json.JSONDecodeError subclasses ValueError, so one clause
        # covers both decode failures and the non-object case below.
        try:
            overrides = json.loads(body)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=f"bad request body: {exc}") from exc
        if not isinstance(overrides, dict):
            raise HTTPException(
                status_code=400,
                detail="bad request body: request body must be a JSON object",
            )
    # Per-run copy: the module-level CONFIG stays untouched.
    cfg = {**CONFIG, **overrides}
    try:
        _start_training(cfg)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return JSONResponse({"status": "started", "config": cfg})
|
| 1120 |
|
| 1121 |
|
| 1122 |
@app.on_event("startup")
|