siddeshwar-kagatikar committed on
Commit 2e14f6d · 1 Parent(s): 4aca4f5

Make self-play training resilient to HF Space restarts


The L40S full run was getting killed at ~step 84/120 of the generator
phase because uvicorn (PID 1) caught a SIGTERM from the Space and the
backgrounded training subshell died with the container, with no
recovery path. The answerer reward graph was empty because training
never reached the answerer phase.

scripts/space_start.sh: invert the lifecycle. uvicorn now runs in the
background (so HF healthchecks pass) and self-play training runs in
the foreground as the script's blocking child. A SIGTERM from the
platform now propagates to the trainer instead of cleanly tearing down
PID 1, and the EXIT trap stops uvicorn afterwards.
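
In outline, a simplified sketch of the new ordering (the real script wraps these
steps in helper functions, honors a RUN_SPACE_API_SERVER toggle, and redirects
uvicorn output to a log file):

    uvicorn server:app --host 0.0.0.0 --port "${PORT:-7860}" &   # healthcheck endpoint only
    UVICORN_PID=$!
    trap 'kill "${UVICORN_PID}" 2>/dev/null || true' EXIT INT TERM
    # blocking foreground child; a SIGTERM to the script reaches this process
    osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}"
    wait "${UVICORN_PID}"   # keep the Space reachable after training ends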

src/osint_env/training/self_play.py: _train_grpo_phase now resumes
work on every restart. It (1) skips the phase entirely if final_model
is already on disk, (2) auto-detects the latest local checkpoint-* and
calls trainer.train(resume_from_checkpoint=...) when full state
(optimizer.pt + trainer_state.json) is present, (3) falls back to
downloading the latest checkpoint for this phase from
OSINT_HF_CHECKPOINT_REPO_ID via huggingface_hub.snapshot_download and
warm-starts from those weights when only inference state is available.
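
The decision order, as a minimal self-contained sketch (plan_resume is a
hypothetical name for illustration; the real code also consults
_maybe_download_phase_checkpoints_from_hf before falling back to a fresh start
and then builds the GRPOTrainer accordingly):

    from pathlib import Path

    def plan_resume(output_dir: Path) -> tuple[str, Path | None]:
        """Return ("skip" | "resume" | "warm_start" | "fresh", path)."""
        final_dir = output_dir / "final_model"
        if any(final_dir.glob("*.safetensors")) or any(final_dir.glob("pytorch_model*.bin")):
            return "skip", final_dir                      # (1) phase already finished
        ckpts = [p for p in output_dir.glob("checkpoint-*")
                 if p.is_dir() and p.name.split("-")[-1].isdigit()]
        if not ckpts:
            return "fresh", None                          # (3) would first try the HF Hub copy
        ckpt = max(ckpts, key=lambda p: int(p.name.split("-")[-1]))
        full_state = (ckpt / "optimizer.pt").exists() and (ckpt / "trainer_state.json").exists()
        return ("resume" if full_state else "warm_start"), ckpt   # (2) full state vs weights-only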

config/self_play_training_hf_l40s_full.json: drop logging_steps from 5
to 1 (matches the smoke run so the W&B reward graph is dense for both
GeneratorRewardFunction and AnswererRewardFunction) and drop save_steps
from 120 to 30 so a mid-phase restart loses at most ~30 steps.

Made-with: Cursor

config/self_play_training_hf_l40s_full.json CHANGED
@@ -40,8 +40,8 @@
   "num_iterations": 1,
   "loss_type": "dapo",
   "scale_rewards": "group",
-  "logging_steps": 5,
-  "save_steps": 120,
+  "logging_steps": 1,
+  "save_steps": 30,
   "save_total_limit": 1,
   "optim": "adamw_torch_fused",
   "bf16": true,
@@ -72,8 +72,8 @@
   "num_iterations": 1,
   "loss_type": "dapo",
   "scale_rewards": "group",
-  "logging_steps": 5,
-  "save_steps": 120,
+  "logging_steps": 1,
+  "save_steps": 30,
   "save_total_limit": 1,
   "optim": "adamw_torch_fused",
   "bf16": true,
scripts/space_start.sh CHANGED
@@ -13,7 +13,37 @@ TRAIN_CONFIG_PATH="${TRAIN_SELF_PLAY_CONFIG_PATH:-config/self_play_training_hf_l
 TRAIN_OUTPUT_DIR="${TRAIN_SELF_PLAY_OUTPUT_DIR:-}"
 RUN_FLAG="${RUN_SELF_PLAY_TRAINING:-1}"
 DRY_RUN_FLAG="${RUN_SELF_PLAY_DRY_RUN:-0}"
-BACKGROUND_FLAG="${RUN_SELF_PLAY_BACKGROUND:-1}"
+SERVE_API_FLAG="${RUN_SPACE_API_SERVER:-1}"
+PORT_VALUE="${PORT:-7860}"
+UVICORN_LOG_PATH="${UVICORN_LOG_PATH:-/tmp/uvicorn.log}"
+
+UVICORN_PID=""
+
+_start_api_server_background() {
+  if ! _is_true "$SERVE_API_FLAG"; then
+    echo "[space_start] RUN_SPACE_API_SERVER disabled. Skipping API server."
+    return
+  fi
+  echo "[space_start] Starting API server in background on port ${PORT_VALUE} (logs: ${UVICORN_LOG_PATH})."
+  # API server runs in background ONLY for HF healthchecks. Training is the
+  # primary process. If HF infrastructure SIGTERMs the container we still
+  # want training to receive the signal and flush a final checkpoint, not
+  # to silently die because PID 1 (uvicorn previously) exited first.
+  uvicorn server:app --host 0.0.0.0 --port "${PORT_VALUE}" \
+    >"${UVICORN_LOG_PATH}" 2>&1 &
+  UVICORN_PID=$!
+  echo "[space_start] uvicorn pid=${UVICORN_PID}"
+}
+
+_stop_api_server() {
+  if [ -n "${UVICORN_PID}" ] && kill -0 "${UVICORN_PID}" 2>/dev/null; then
+    echo "[space_start] Stopping uvicorn pid=${UVICORN_PID}."
+    kill "${UVICORN_PID}" 2>/dev/null || true
+    wait "${UVICORN_PID}" 2>/dev/null || true
+  fi
+}
+
+trap '_stop_api_server' EXIT INT TERM
 
 _train_self_play() {
   if [ -n "${TRAIN_OUTPUT_DIR}" ]; then
@@ -27,7 +57,7 @@ _train_self_play() {
     # shellcheck disable=SC2086
     osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG} --dry-run
   else
-    echo "[space_start] Running self-play training."
+    echo "[space_start] Running self-play training (foreground)."
     # shellcheck disable=SC2086
     osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG}
   fi
@@ -47,15 +77,20 @@ if _is_true "$RUN_FLAG"; then
   if [ -n "${OSINT_HF_CHECKPOINT_REPO_ID:-}" ]; then
     echo "[space_start] HF checkpoint repo: ${OSINT_HF_CHECKPOINT_REPO_ID}"
   fi
-  if _is_true "$BACKGROUND_FLAG"; then
-    echo "[space_start] Launching self-play in background so the Space API can stay online."
-    _train_self_play &
-  else
-    _train_self_play
+  _start_api_server_background
+  # Run training in the FOREGROUND so the script (and therefore PID 1)
+  # blocks until training is finished. A graceful SIGTERM from HF will
+  # propagate to the training process via the shell's signal handling
+  # and the trap above will cleanly stop uvicorn afterwards.
+  _train_self_play
+  echo "[space_start] Training finished. Keeping API server alive for log inspection."
+  if [ -n "${UVICORN_PID}" ] && kill -0 "${UVICORN_PID}" 2>/dev/null; then
+    wait "${UVICORN_PID}"
   fi
 else
   echo "[space_start] RUN_SELF_PLAY_TRAINING disabled. Skipping self-play run."
+  _start_api_server_background
+  if [ -n "${UVICORN_PID}" ] && kill -0 "${UVICORN_PID}" 2>/dev/null; then
+    wait "${UVICORN_PID}"
+  fi
 fi
-
-echo "[space_start] Starting API server."
-exec uvicorn server:app --host 0.0.0.0 --port "${PORT:-7860}"
src/osint_env/training/self_play.py CHANGED
@@ -179,6 +179,120 @@ def _require_training_stack() -> tuple[Any, Any, Any]:
     return Dataset, GRPOConfig, GRPOTrainer
 
 
+def _latest_local_checkpoint(output_dir: Path) -> Path | None:
+    if not output_dir.exists():
+        return None
+    candidates: list[tuple[int, Path]] = []
+    for path in output_dir.glob("checkpoint-*"):
+        if not path.is_dir():
+            continue
+        suffix = path.name.split("-", 1)[-1]
+        try:
+            step = int(suffix)
+        except ValueError:
+            continue
+        candidates.append((step, path))
+    if not candidates:
+        return None
+    candidates.sort(key=lambda item: item[0])
+    return candidates[-1][1]
+
+
+def _final_model_already_present(output_dir: Path) -> bool:
+    final_dir = output_dir / "final_model"
+    if not final_dir.is_dir():
+        return False
+    safetensors = list(final_dir.glob("*.safetensors"))
+    legacy_bin = list(final_dir.glob("pytorch_model*.bin"))
+    return bool(safetensors or legacy_bin)
+
+
+def _maybe_download_phase_checkpoints_from_hf(output_dir: Path, run_dir: Path) -> Path | None:
+    """If no local checkpoint exists, try to recover the latest checkpoint
+    for this phase from the HF Hub repo we already upload to. Returns the
+    local path of the restored ``checkpoint-*`` directory, or None.
+
+    Designed to make Space restarts non-destructive: training state is
+    pushed to ``OSINT_HF_CHECKPOINT_REPO_ID`` after every phase, so on a
+    fresh container we can pull it back and resume.
+    """
+    if _latest_local_checkpoint(output_dir) is not None:
+        return _latest_local_checkpoint(output_dir)
+    repo_id = _default_hf_checkpoint_repo_id(run_dir)
+    token = _resolve_hf_upload_token()
+    if not repo_id or not token:
+        return None
+
+    try:
+        from huggingface_hub import HfApi, snapshot_download
+    except ImportError:
+        return None
+
+    repo_type = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model")).strip() or "model"
+    api = HfApi(token=token)
+    try:
+        files = list(api.list_repo_files(repo_id=repo_id, repo_type=repo_type))
+    except Exception as exc:  # noqa: BLE001
+        print(
+            f"[self_play][resume] could not list files in {repo_id}: "
+            f"{type(exc).__name__}: {exc}",
+            flush=True,
+        )
+        return None
+
+    phase_prefix = _hf_relative_repo_path(output_dir, run_dir)
+    phase_prefix_clean = phase_prefix.strip("/") + "/"
+    candidate_steps: dict[int, list[str]] = {}
+    for remote_path in files:
+        if not remote_path.startswith(phase_prefix_clean):
+            continue
+        relative = remote_path[len(phase_prefix_clean) :]
+        parts = relative.split("/", 1)
+        if len(parts) < 2 or not parts[0].startswith("checkpoint-"):
+            continue
+        try:
+            step = int(parts[0].split("-", 1)[1])
+        except ValueError:
+            continue
+        candidate_steps.setdefault(step, []).append(remote_path)
+
+    if not candidate_steps:
+        return None
+
+    best_step = max(candidate_steps.keys())
+    target_local_dir = output_dir / f"checkpoint-{best_step}"
+    target_local_dir.mkdir(parents=True, exist_ok=True)
+    print(
+        f"[self_play][resume] downloading phase checkpoint from HF Hub: "
+        f"repo={repo_id} prefix={phase_prefix_clean}checkpoint-{best_step} "
+        f"-> {target_local_dir}",
+        flush=True,
+    )
+    try:
+        snapshot_download(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            local_dir=str(output_dir),
+            allow_patterns=[f"{phase_prefix_clean}checkpoint-{best_step}/*"],
+            token=token,
+        )
+        downloaded_root = output_dir / phase_prefix_clean.rstrip("/") / f"checkpoint-{best_step}"
+        if downloaded_root.exists() and downloaded_root != target_local_dir:
+            for item in downloaded_root.iterdir():
+                dest = target_local_dir / item.name
+                if dest.exists():
+                    continue
+                item.replace(dest)
+        return target_local_dir if target_local_dir.exists() else None
+    except Exception as exc:  # noqa: BLE001
+        print(
+            f"[self_play][resume] failed to download checkpoint from HF Hub: "
+            f"{type(exc).__name__}: {exc}",
+            flush=True,
+        )
+        return None
+
+
 def _task_to_edge_json(task: TaskInstance) -> str:
     payload = [
@@ -731,6 +845,63 @@ def _train_grpo_phase(
     Dataset, GRPOConfig, GRPOTrainer = _require_training_stack()
 
     output_dir.mkdir(parents=True, exist_ok=True)
+
+    phase_label = str(run_name).strip() or str(output_dir.name)
+    reward_class_name = type(reward_function).__name__
+
+    # Output layout: <run_dir>/round_NNN/<phase_subdir>. Match the run_dir
+    # used by the corresponding HF upload helpers so resume paths line up
+    # with where checkpoints were written.
+    run_dir_for_resume = output_dir.parents[1] if len(output_dir.parents) >= 2 else output_dir.parent
+
+    if _final_model_already_present(output_dir):
+        final_dir = output_dir / "final_model"
+        print(
+            f"[self_play][resume] phase={phase_label} already has final_model at {final_dir}. "
+            f"Skipping retrain on Space restart.",
+            flush=True,
+        )
+        checkpoint_dirs = [str(path) for path in sorted(output_dir.glob("checkpoint-*")) if path.is_dir()]
+        return {
+            "model_path": str(final_dir),
+            "final_model_path": str(final_dir),
+            "phase_output_dir": str(output_dir),
+            "checkpoint_dirs": checkpoint_dirs,
+            "global_step": int(getattr(phase, "max_steps", 0) or 0),
+            "training_loss": 0.0,
+            "train_rows": len(rows),
+            "tuning_mode": str(tuning_mode).strip().lower() or "full",
+            "is_full_finetune": str(tuning_mode).strip().lower() != "lora",
+            "resumed_skipped": True,
+        }
+
+    resume_checkpoint = _latest_local_checkpoint(output_dir)
+    resume_is_full_state = bool(
+        resume_checkpoint is not None
+        and (resume_checkpoint / "optimizer.pt").exists()
+        and (resume_checkpoint / "trainer_state.json").exists()
+    )
+    if resume_checkpoint is None:
+        downloaded = _maybe_download_phase_checkpoints_from_hf(output_dir, run_dir_for_resume)
+        if downloaded is not None:
+            resume_checkpoint = downloaded
+            resume_is_full_state = bool(
+                (downloaded / "optimizer.pt").exists()
+                and (downloaded / "trainer_state.json").exists()
+            )
+
+    # If the resume checkpoint is weights-only (e.g. recovered from HF Hub
+    # which intentionally drops optimizer state to keep uploads small),
+    # warm-start the model from those weights and start a fresh trainer
+    # state. Better than restarting from the base model.
+    warm_start_from: Path | None = None
+    if resume_checkpoint is not None and not resume_is_full_state:
+        warm_start_from = resume_checkpoint
+        resume_checkpoint = None
+
+    if warm_start_from is not None:
+        model_name_or_path = str(warm_start_from)
+
     dataset = Dataset.from_list(rows)
     args = _safe_build_grpo_config(
         phase=phase,
@@ -753,8 +924,6 @@ def _train_grpo_phase(
             raise RuntimeError("Installed TRL version does not expose peft_config in GRPOTrainer.")
         trainer_kwargs["peft_config"] = _build_lora_config(lora)
 
-    phase_label = str(run_name).strip() or str(output_dir.name)
-    reward_class_name = type(reward_function).__name__
     print(
         f"[self_play] Starting phase: {phase_label} rows={len(rows)} "
         f"max_steps={phase.max_steps}",
@@ -769,6 +938,17 @@ def _train_grpo_phase(
         f"per_device_train_batch_size={phase.per_device_train_batch_size}",
         flush=True,
     )
+    if resume_checkpoint is not None:
+        print(
+            f"[self_play][resume] phase={phase_label} resuming (full state) from checkpoint={resume_checkpoint}",
+            flush=True,
+        )
+    elif warm_start_from is not None:
+        print(
+            f"[self_play][resume] phase={phase_label} warm-starting from weights only at {warm_start_from} "
+            f"(no optimizer state available)",
+            flush=True,
+        )
     strict_asserts = str(os.getenv("OSINT_TRAIN_STRICT_ASSERTS", "")).strip().lower() in {"1", "true", "yes", "on"}
     trainer = GRPOTrainer(**trainer_kwargs)
     tracked_params = [
@@ -780,7 +960,10 @@ def _train_grpo_phase(
         name: float(param.detach().float().abs().mean().item())
         for name, param in tracked_params
     }
-    train_output = trainer.train()
+    if resume_checkpoint is not None:
+        train_output = trainer.train(resume_from_checkpoint=str(resume_checkpoint))
+    else:
+        train_output = trainer.train()
 
     final_dir = output_dir / "final_model"
     trainer.save_model(str(final_dir))