anugrahhu committed on
Commit
c2c4674
·
verified ·
1 Parent(s): 2b97998

sft+reward-fix: space/training/app.py

Browse files
Files changed (1) hide show
  1. space/training/app.py +130 -5
space/training/app.py CHANGED
@@ -28,7 +28,7 @@ from datetime import datetime, timezone
28
  from pathlib import Path
29
  from typing import Any, Dict, List, Optional
30
 
31
- from fastapi import FastAPI, HTTPException
32
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, Response
33
  from fastapi.staticfiles import StaticFiles
34
 
@@ -97,6 +97,10 @@ def _detect_gpus() -> int:
97
  _NUM_GPUS = _detect_gpus()
98
 
99
 
 
 
 
 
100
  CONFIG = {
101
  "training_backend": _env("TRAINING_BACKEND", "vanilla"),
102
  "model_name": _env("MODEL_NAME", "HuggingFaceTB/SmolLM2-360M-Instruct"),
@@ -119,6 +123,15 @@ CONFIG = {
119
  f"{_env('HF_USERNAME', 'anugrahhu')}/cernenv-grpo-smollm2-360m",
120
  ),
121
  "autostart": _env("AUTOSTART", "0") == "1",
 
 
 
 
 
 
 
 
 
122
  }
123
 
124
 
@@ -177,8 +190,35 @@ def _stream_subprocess(cmd: list[str], log_handle) -> int:
177
  return rc
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
181
- """Compose the selected training launcher."""
 
 
 
 
 
 
 
182
  backend = str(config.get("training_backend", "vanilla")).lower()
183
  if backend == "vanilla":
184
  python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
@@ -360,6 +400,31 @@ def _training_pipeline(config: Dict[str, Any]) -> None:
360
  log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
361
  log.flush()
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  log.write(f"\n--- GRPO training ({backend}, {config['num_gpus']} GPU process(es)) ---\n")
364
  log.flush()
365
  rc = _stream_subprocess(_build_training_cmd(config), log)
@@ -813,6 +878,9 @@ _HTML = """\
813
  <img id=dist src="/evidence/reward_distribution.png" onerror="this.style.display='none'">
814
  <div id=dist_missing class=muted style="display:none">(generated after post-train eval)</div>
815
  </div>
 
 
 
816
  </div>
817
 
818
  <h2>Before / after metrics</h2>
@@ -888,6 +956,27 @@ async function refresh() {
888
  probe.src = baseSrc + bust;
889
  }
890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
891
  const logs = await fetch('/logs?tail=200').then(r => r.text());
892
  document.getElementById('logs').textContent = logs || '(no logs yet)';
893
  }
@@ -929,6 +1018,23 @@ def metrics() -> JSONResponse:
929
  return JSONResponse({"pre": None, "post": None, "delta": None})
930
 
931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
932
  @app.get("/evidence")
933
  def evidence_index() -> JSONResponse:
934
  """List every evidence artifact currently on disk."""
@@ -986,12 +1092,31 @@ def logs(tail: int = 400) -> PlainTextResponse:
986
 
987
 
988
  @app.post("/train")
989
- def train() -> JSONResponse:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
990
  try:
991
- _start_training(dict(CONFIG))
992
  except RuntimeError as exc:
993
  raise HTTPException(status_code=409, detail=str(exc))
994
- return JSONResponse({"status": "started", "config": CONFIG})
995
 
996
 
997
  @app.on_event("startup")
 
28
  from pathlib import Path
29
  from typing import Any, Dict, List, Optional
30
 
31
+ from fastapi import FastAPI, HTTPException, Request
32
  from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, Response
33
  from fastapi.staticfiles import StaticFiles
34
 
 
97
  _NUM_GPUS = _detect_gpus()
98
 
99
 
100
def _bool_env(name: str, default: str) -> bool:
    """Read env var *name* via ``_env`` and interpret it as a boolean flag.

    Accepts "1", "true", "yes", "on" (case-insensitive, surrounding
    whitespace ignored); anything else is False.
    """
    raw = _env(name, default)
    truthy = {"1", "true", "yes", "on"}
    return raw.strip().lower() in truthy
102
+
103
+
104
  CONFIG = {
105
  "training_backend": _env("TRAINING_BACKEND", "vanilla"),
106
  "model_name": _env("MODEL_NAME", "HuggingFaceTB/SmolLM2-360M-Instruct"),
 
123
  f"{_env('HF_USERNAME', 'anugrahhu')}/cernenv-grpo-smollm2-360m",
124
  ),
125
  "autostart": _env("AUTOSTART", "0") == "1",
126
+ # ── SFT warm-start phase (defeats v1's claim-avoidance reward hack
127
+ # by giving GRPO a non-zero prior over correct trajectories) ─────
128
+ "sft_warmstart": _bool_env("SFT_WARMSTART", "false"),
129
+ "sft_num_episodes": int(_env("SFT_NUM_EPISODES", "200")),
130
+ "sft_max_steps": int(_env("SFT_MAX_STEPS", "8")),
131
+ "sft_epochs": int(_env("SFT_EPOCHS", "1")),
132
+ "sft_lr": float(_env("SFT_LR", "1e-5")),
133
+ "sft_difficulty": _env("SFT_DIFFICULTY", "mixed"),
134
+ "sft_out_dir": _env("SFT_OUT_DIR", "runs/sft-warmstart"),
135
  }
136
 
137
 
 
190
  return rc
191
 
192
 
193
+ def _build_sft_warmstart_cmd(config: Dict[str, Any]) -> list[str]:
194
+ """Compose the SFT-warm-start subprocess command.
195
+
196
+ Always uses the system Python so GRPO + SFT share the same
197
+ transformers + trl pin in space/training/requirements.txt.
198
+ """
199
+ python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
200
+ return [
201
+ python_bin, "-m", "training.sft_warmstart",
202
+ "--out_dir", config["sft_out_dir"],
203
+ "--num_episodes", str(config["sft_num_episodes"]),
204
+ "--max_steps", str(config["sft_max_steps"]),
205
+ "--epochs", str(config["sft_epochs"]),
206
+ "--lr", str(config["sft_lr"]),
207
+ "--base_model", config["model_name"],
208
+ "--difficulty", config["sft_difficulty"],
209
+ "--evidence_dir", config["evidence_dir"],
210
+ ]
211
+
212
+
213
  def _build_training_cmd(config: Dict[str, Any]) -> list[str]:
214
+ """Compose the selected training launcher.
215
+
216
+ When ``sft_warmstart`` is on, ``model_name`` is expected to already
217
+ have been overwritten with the SFT output directory by the caller
218
+ (``_training_pipeline``), so this function never has to know about
219
+ the SFT phase explicitly — it just trains GRPO from whatever path
220
+ is sitting in ``model_name``.
221
+ """
222
  backend = str(config.get("training_backend", "vanilla")).lower()
223
  if backend == "vanilla":
224
  python_bin = "/usr/local/bin/python" if Path("/usr/local/bin/python").exists() else sys.executable
 
400
  log.write(f"\n[warn] pre-train eval failed (rc={rc}); continuing without baseline\n")
401
  log.flush()
402
 
403
+ if config.get("sft_warmstart"):
404
+ # Phase 1 — SFT warm-start. Produces a *full* causal-LM
405
+ # checkpoint at config['sft_out_dir'] (LoRA adapters are
406
+ # merged in by training/sft_warmstart.py) so we can hand
407
+ # it to GRPO as a drop-in --model_name.
408
+ sft_out = config["sft_out_dir"]
409
+ log.write(
410
+ f"\n--- SFT warm-start ({config['sft_num_episodes']} oracle "
411
+ f"episodes, epochs={config['sft_epochs']}, → {sft_out}) ---\n"
412
+ )
413
+ log.flush()
414
+ sft_rc = _stream_subprocess(_build_sft_warmstart_cmd(config), log)
415
+ if sft_rc != 0:
416
+ raise RuntimeError(f"SFT warm-start failed (rc={sft_rc})")
417
+ log.write(
418
+ f"\n[ok] SFT done; switching GRPO base model "
419
+ f"{config['model_name']} → {sft_out}\n"
420
+ )
421
+ log.flush()
422
+ config["model_name"] = sft_out
423
+ # Keep the *base* HF id around for evaluator commands —
424
+ # tokenizer files in the SFT directory are saved by the
425
+ # SFT script, but evaluation will load from this dir
426
+ # directly, so no further path bookkeeping is required.
427
+
428
  log.write(f"\n--- GRPO training ({backend}, {config['num_gpus']} GPU process(es)) ---\n")
429
  log.flush()
430
  rc = _stream_subprocess(_build_training_cmd(config), log)
 
878
  <img id=dist src="/evidence/reward_distribution.png" onerror="this.style.display='none'">
879
  <div id=dist_missing class=muted style="display:none">(generated after post-train eval)</div>
880
  </div>
881
+ <div class=card><b>Warm-start (SFT)</b><br>
882
+ <div id=sft_card class=muted>(SFT_WARMSTART disabled — set the env var to enable)</div>
883
+ </div>
884
  </div>
885
 
886
  <h2>Before / after metrics</h2>
 
956
  probe.src = baseSrc + bust;
957
  }
958
 
959
+ // SFT warm-start card. /sft_summary returns 404 until the SFT phase
960
+ // has written evidence/sft_summary.json — when it does, render the
961
+ // headline numbers (final loss, oracle success rate, duration) so a
962
+ // reviewer can sanity-check the warm-start at a glance.
963
+ const sft_resp = await fetch('/sft_summary');
964
+ const sft_card = document.getElementById('sft_card');
965
+ if (sft_resp.ok) {
966
+ try {
967
+ const sft = await sft_resp.json();
968
+ sft_card.classList.remove('muted');
969
+ sft_card.innerHTML =
970
+ `<table>` +
971
+ `<tr><td><b>final loss</b></td><td><code>${fmt(sft.final_loss)}</code></td></tr>` +
972
+ `<tr><td><b>oracle success</b></td><td><code>${fmt(sft.oracle_success_rate)}</code></td></tr>` +
973
+ `<tr><td><b>transitions trained</b></td><td><code>${sft.num_train_rows ?? '–'}</code></td></tr>` +
974
+ `<tr><td><b>duration</b></td><td><code>${fmt(sft.duration_s)} s</code></td></tr>` +
975
+ `<tr><td><b>base → SFT dir</b></td><td><code>${sft.base_model} → ${sft.out_dir}</code></td></tr>` +
976
+ `</table>`;
977
+ } catch (e) { /* keep placeholder */ }
978
+ }
979
+
980
  const logs = await fetch('/logs?tail=200').then(r => r.text());
981
  document.getElementById('logs').textContent = logs || '(no logs yet)';
982
  }
 
1018
  return JSONResponse({"pre": None, "post": None, "delta": None})
1019
 
1020
 
1021
@app.get("/sft_summary")
def sft_summary() -> JSONResponse:
    """Return the SFT warm-start summary if it exists.

    Powers the dashboard's "Warm-start (SFT)" card: shows the final
    training loss, oracle success rate, and wall-clock duration once
    the SFT phase has written ``evidence/sft_summary.json``.
    """
    summary_path = EVIDENCE_DIR / "sft_summary.json"
    if not summary_path.exists():
        # Nothing written yet — the dashboard treats 404 as "disabled/pending".
        return JSONResponse({}, status_code=404)
    try:
        payload = json.loads(summary_path.read_text())
    except Exception:
        return JSONResponse({"error": "sft_summary unreadable"}, status_code=500)
    return JSONResponse(payload)
1036
+
1037
+
1038
  @app.get("/evidence")
1039
  def evidence_index() -> JSONResponse:
1040
  """List every evidence artifact currently on disk."""
 
1092
 
1093
 
1094
@app.post("/train")
async def train(request: Request) -> JSONResponse:
    """Start a training run.

    The request body (JSON object, optional) is merged into a copy of
    the global ``CONFIG`` for *this* run only, so API callers can flip
    ``sft_warmstart`` (or any other config key) without redeploying the
    Space. Unknown keys are accepted as-is — type coercion is the
    caller's responsibility.

    Raises:
        HTTPException: 400 if the body is malformed JSON or not a JSON
            object; 409 if a run is already in progress (propagated
            from ``_start_training``).
    """
    overrides: Dict[str, Any] = {}
    try:
        body = await request.body()
        if body:
            overrides = json.loads(body)
            if not isinstance(overrides, dict):
                raise ValueError("request body must be a JSON object")
    # json.JSONDecodeError subclasses ValueError, so one clause covers both;
    # chain the cause so the original parse error survives in tracebacks.
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"bad request body: {exc}") from exc
    cfg = dict(CONFIG)
    cfg.update(overrides)
    try:
        _start_training(cfg)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return JSONResponse({"status": "started", "config": cfg})
1120
 
1121
 
1122
  @app.on_event("startup")