Make self-play training resilient to HF Space restarts
The L40S full run was getting killed at ~step 84/120 of the generator
phase because uvicorn (PID 1) caught a SIGTERM from the Space and the
backgrounded training subshell died with the container, with no
recovery path. The answerer reward graph was empty because training
never reached the answerer phase.
scripts/space_start.sh: invert the lifecycle. uvicorn now runs in the
background (so HF healthchecks pass) and self-play training runs in
the foreground as the script's blocking child. A SIGTERM from the
platform now propagates to the trainer instead of cleanly tearing down
PID 1, and the EXIT trap stops uvicorn afterwards.
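
In sketch form the inverted lifecycle looks like this (condensed; flag
handling, logging, and the dry-run branch are elided, and ENV_CONFIG_PATH /
TRAIN_CONFIG_PATH are assumed to be exported):

    #!/usr/bin/env bash
    set -eo pipefail

    # Healthcheck server in the background; remember the PID so the trap can reap it.
    uvicorn server:app --host 0.0.0.0 --port "${PORT:-7860}" &
    UVICORN_PID=$!
    trap 'kill "${UVICORN_PID}" 2>/dev/null || true' EXIT INT TERM

    # Training runs in the foreground as the script's blocking child, so a
    # platform SIGTERM reaches the trainer instead of only killing uvicorn.
    osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}"

    # Keep serving after training completes so logs stay inspectable.
    wait "${UVICORN_PID}"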
src/osint_env/training/self_play.py: _train_grpo_phase now resumes
work on every restart. It (1) skips the phase entirely if final_model
is already on disk, (2) auto-detects the latest local checkpoint-* and
calls trainer.train(resume_from_checkpoint=...) when full state
(optimizer.pt + trainer_state.json) is present, (3) falls back to
downloading the latest checkpoint for this phase from
OSINT_HF_CHECKPOINT_REPO_ID via huggingface_hub.snapshot_download and
warm-starts from those weights when only inference state is available.
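
Condensed, the decision ladder is roughly the following (a sketch:
plan_resume is a hypothetical name, and the HF Hub download leg is elided;
the real helpers appear in the diff below):

    from pathlib import Path

    def plan_resume(output_dir: Path) -> tuple[str, Path | None]:
        # (1) Phase already finished on a previous boot: skip retraining.
        final_dir = output_dir / "final_model"
        if any(final_dir.glob("*.safetensors")) or any(final_dir.glob("pytorch_model*.bin")):
            return ("skip", final_dir)
        # (2) Latest local checkpoint-* by step number.
        checkpoints = sorted(
            (p for p in output_dir.glob("checkpoint-*")
             if p.is_dir() and p.name.split("-")[-1].isdigit()),
            key=lambda p: int(p.name.split("-")[-1]),
        )
        if checkpoints:
            latest = checkpoints[-1]
            # Full trainer state -> trainer.train(resume_from_checkpoint=latest).
            if (latest / "optimizer.pt").exists() and (latest / "trainer_state.json").exists():
                return ("resume", latest)
            # (3) Weights only (e.g. pulled back from the Hub): warm-start the
            # model from these weights with fresh optimizer/scheduler state.
            return ("warm_start", latest)
        return ("fresh", None)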
config/self_play_training_hf_l40s_full.json: drop logging_steps from 5
to 1 (matching the smoke run, so the W&B reward graph is dense for both
GeneratorRewardFunction and AnswererRewardFunction) and drop save_steps
from 120 to 30 so a mid-phase restart loses at most ~30 steps.
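With max_steps 120 this saves at steps 30, 60, 90, and 120, so a kill at
step 84 (as in the run above) now resumes from checkpoint-60 instead of
restarting from scratch; save_total_limit stays at 1, so only the newest
checkpoint is kept on disk.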
Made-with: Cursor
--- a/config/self_play_training_hf_l40s_full.json
+++ b/config/self_play_training_hf_l40s_full.json
@@ -40,8 +40,8 @@
     "num_iterations": 1,
     "loss_type": "dapo",
     "scale_rewards": "group",
-    "logging_steps": 5,
-    "save_steps": 120,
+    "logging_steps": 1,
+    "save_steps": 30,
     "save_total_limit": 1,
     "optim": "adamw_torch_fused",
     "bf16": true,
@@ -72,8 +72,8 @@
     "num_iterations": 1,
     "loss_type": "dapo",
     "scale_rewards": "group",
-    "logging_steps": 5,
-    "save_steps": 120,
+    "logging_steps": 1,
+    "save_steps": 30,
     "save_total_limit": 1,
     "optim": "adamw_torch_fused",
     "bf16": true,
--- a/scripts/space_start.sh
+++ b/scripts/space_start.sh
@@ -13,7 +13,37 @@ TRAIN_CONFIG_PATH="${TRAIN_SELF_PLAY_CONFIG_PATH:-config/self_play_training_hf_l40s_full.json}"
 TRAIN_OUTPUT_DIR="${TRAIN_SELF_PLAY_OUTPUT_DIR:-}"
 RUN_FLAG="${RUN_SELF_PLAY_TRAINING:-1}"
 DRY_RUN_FLAG="${RUN_SELF_PLAY_DRY_RUN:-0}"
-
+SERVE_API_FLAG="${RUN_SPACE_API_SERVER:-1}"
+PORT_VALUE="${PORT:-7860}"
+UVICORN_LOG_PATH="${UVICORN_LOG_PATH:-/tmp/uvicorn.log}"
+
+UVICORN_PID=""
+
+_start_api_server_background() {
+  if ! _is_true "$SERVE_API_FLAG"; then
+    echo "[space_start] RUN_SPACE_API_SERVER disabled. Skipping API server."
+    return
+  fi
+  echo "[space_start] Starting API server in background on port ${PORT_VALUE} (logs: ${UVICORN_LOG_PATH})."
+  # API server runs in background ONLY for HF healthchecks. Training is the
+  # primary process. If HF infrastructure SIGTERMs the container we still
+  # want training to receive the signal and flush a final checkpoint, not
+  # to silently die because PID 1 (uvicorn previously) exited first.
+  uvicorn server:app --host 0.0.0.0 --port "${PORT_VALUE}" \
+    >"${UVICORN_LOG_PATH}" 2>&1 &
+  UVICORN_PID=$!
+  echo "[space_start] uvicorn pid=${UVICORN_PID}"
+}
+
+_stop_api_server() {
+  if [ -n "${UVICORN_PID}" ] && kill -0 "${UVICORN_PID}" 2>/dev/null; then
+    echo "[space_start] Stopping uvicorn pid=${UVICORN_PID}."
+    kill "${UVICORN_PID}" 2>/dev/null || true
+    wait "${UVICORN_PID}" 2>/dev/null || true
+  fi
+}
+
+trap '_stop_api_server' EXIT INT TERM

 _train_self_play() {
   if [ -n "${TRAIN_OUTPUT_DIR}" ]; then
@@ -27,7 +57,7 @@ _train_self_play() {
     # shellcheck disable=SC2086
     osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG} --dry-run
   else
-    echo "[space_start] Running self-play training."
+    echo "[space_start] Running self-play training (foreground)."
     # shellcheck disable=SC2086
     osint-env train-self-play --config "${ENV_CONFIG_PATH}" --train-config "${TRAIN_CONFIG_PATH}" ${OUTPUT_ARG}
   fi
@@ -47,15 +77,20 @@ if _is_true "$RUN_FLAG"; then
   if [ -n "${OSINT_HF_CHECKPOINT_REPO_ID:-}" ]; then
     echo "[space_start] HF checkpoint repo: ${OSINT_HF_CHECKPOINT_REPO_ID}"
   fi
-
-
-
-
-
+  _start_api_server_background
+  # Run training in the FOREGROUND so the script (and therefore PID 1)
+  # blocks until training is finished. A graceful SIGTERM from HF will
+  # propagate to the training process via the shell's signal handling
+  # and the trap above will cleanly stop uvicorn afterwards.
+  _train_self_play
+  echo "[space_start] Training finished. Keeping API server alive for log inspection."
+  if [ -n "${UVICORN_PID}" ] && kill -0 "${UVICORN_PID}" 2>/dev/null; then
+    wait "${UVICORN_PID}"
   fi
 else
   echo "[space_start] RUN_SELF_PLAY_TRAINING disabled. Skipping self-play run."
+  _start_api_server_background
+  if [ -n "${UVICORN_PID}" ] && kill -0 "${UVICORN_PID}" 2>/dev/null; then
+    wait "${UVICORN_PID}"
+  fi
 fi
-
-echo "[space_start] Starting API server."
-exec uvicorn server:app --host 0.0.0.0 --port "${PORT:-7860}"
--- a/src/osint_env/training/self_play.py
+++ b/src/osint_env/training/self_play.py
@@ -179,6 +179,120 @@ def _require_training_stack() -> tuple[Any, Any, Any]:
     return Dataset, GRPOConfig, GRPOTrainer


+def _latest_local_checkpoint(output_dir: Path) -> Path | None:
+    if not output_dir.exists():
+        return None
+    candidates: list[tuple[int, Path]] = []
+    for path in output_dir.glob("checkpoint-*"):
+        if not path.is_dir():
+            continue
+        suffix = path.name.split("-", 1)[-1]
+        try:
+            step = int(suffix)
+        except ValueError:
+            continue
+        candidates.append((step, path))
+    if not candidates:
+        return None
+    candidates.sort(key=lambda item: item[0])
+    return candidates[-1][1]
+
+
+def _final_model_already_present(output_dir: Path) -> bool:
+    final_dir = output_dir / "final_model"
+    if not final_dir.is_dir():
+        return False
+    safetensors = list(final_dir.glob("*.safetensors"))
+    legacy_bin = list(final_dir.glob("pytorch_model*.bin"))
+    return bool(safetensors or legacy_bin)
+
+
+def _maybe_download_phase_checkpoints_from_hf(output_dir: Path, run_dir: Path) -> Path | None:
+    """If no local checkpoint exists, try to recover the latest checkpoint
+    for this phase from the HF Hub repo we already upload to. Returns the
+    local path of the restored ``checkpoint-*`` directory, or None.
+
+    Designed to make Space restarts non-destructive: training state is
+    pushed to ``OSINT_HF_CHECKPOINT_REPO_ID`` after every phase, so on a
+    fresh container we can pull it back and resume.
+    """
+    if _latest_local_checkpoint(output_dir) is not None:
+        return _latest_local_checkpoint(output_dir)
+    repo_id = _default_hf_checkpoint_repo_id(run_dir)
+    token = _resolve_hf_upload_token()
+    if not repo_id or not token:
+        return None
+
+    try:
+        from huggingface_hub import HfApi, snapshot_download
+    except ImportError:
+        return None
+
+    repo_type = str(os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model")).strip() or "model"
+    api = HfApi(token=token)
+    try:
+        files = list(api.list_repo_files(repo_id=repo_id, repo_type=repo_type))
+    except Exception as exc:  # noqa: BLE001
+        print(
+            f"[self_play][resume] could not list files in {repo_id}: "
+            f"{type(exc).__name__}: {exc}",
+            flush=True,
+        )
+        return None
+
+    phase_prefix = _hf_relative_repo_path(output_dir, run_dir)
+    phase_prefix_clean = phase_prefix.strip("/") + "/"
+    candidate_steps: dict[int, list[str]] = {}
+    for remote_path in files:
+        if not remote_path.startswith(phase_prefix_clean):
+            continue
+        relative = remote_path[len(phase_prefix_clean) :]
+        parts = relative.split("/", 1)
+        if len(parts) < 2 or not parts[0].startswith("checkpoint-"):
+            continue
+        try:
+            step = int(parts[0].split("-", 1)[1])
+        except ValueError:
+            continue
+        candidate_steps.setdefault(step, []).append(remote_path)
+
+    if not candidate_steps:
+        return None
+
+    best_step = max(candidate_steps.keys())
+    target_local_dir = output_dir / f"checkpoint-{best_step}"
+    target_local_dir.mkdir(parents=True, exist_ok=True)
+    print(
+        f"[self_play][resume] downloading phase checkpoint from HF Hub: "
+        f"repo={repo_id} prefix={phase_prefix_clean}checkpoint-{best_step} "
+        f"-> {target_local_dir}",
+        flush=True,
+    )
+    try:
+        snapshot_download(
+            repo_id=repo_id,
+            repo_type=repo_type,
+            local_dir=str(output_dir),
+            allow_patterns=[f"{phase_prefix_clean}checkpoint-{best_step}/*"],
+            token=token,
+        )
+        downloaded_root = output_dir / phase_prefix_clean.rstrip("/") / f"checkpoint-{best_step}"
+        if downloaded_root.exists() and downloaded_root != target_local_dir:
+            for item in downloaded_root.iterdir():
+                dest = target_local_dir / item.name
+                if dest.exists():
+                    continue
+                item.replace(dest)
+        return target_local_dir if target_local_dir.exists() else None
+    except Exception as exc:  # noqa: BLE001
+        print(
+            f"[self_play][resume] failed to download checkpoint from HF Hub: "
+            f"{type(exc).__name__}: {exc}",
+            flush=True,
+        )
+        return None
+
+

 def _task_to_edge_json(task: TaskInstance) -> str:
     payload = [
@@ -731,6 +845,63 @@ def _train_grpo_phase(
     Dataset, GRPOConfig, GRPOTrainer = _require_training_stack()

     output_dir.mkdir(parents=True, exist_ok=True)
+
+    phase_label = str(run_name).strip() or str(output_dir.name)
+    reward_class_name = type(reward_function).__name__
+
+    # Output layout: <run_dir>/round_NNN/<phase_subdir>. Match the run_dir
+    # used by the corresponding HF upload helpers so resume paths line up
+    # with where checkpoints were written.
+    run_dir_for_resume = output_dir.parents[1] if len(output_dir.parents) >= 2 else output_dir.parent
+
+    if _final_model_already_present(output_dir):
+        final_dir = output_dir / "final_model"
+        print(
+            f"[self_play][resume] phase={phase_label} already has final_model at {final_dir}. "
+            f"Skipping retrain on Space restart.",
+            flush=True,
+        )
+        checkpoint_dirs = [str(path) for path in sorted(output_dir.glob("checkpoint-*")) if path.is_dir()]
+        return {
+            "model_path": str(final_dir),
+            "final_model_path": str(final_dir),
+            "phase_output_dir": str(output_dir),
+            "checkpoint_dirs": checkpoint_dirs,
+            "global_step": int(getattr(phase, "max_steps", 0) or 0),
+            "training_loss": 0.0,
+            "train_rows": len(rows),
+            "tuning_mode": str(tuning_mode).strip().lower() or "full",
+            "is_full_finetune": str(tuning_mode).strip().lower() != "lora",
+            "resumed_skipped": True,
+        }
+
+    resume_checkpoint = _latest_local_checkpoint(output_dir)
+    resume_is_full_state = bool(
+        resume_checkpoint is not None
+        and (resume_checkpoint / "optimizer.pt").exists()
+        and (resume_checkpoint / "trainer_state.json").exists()
+    )
+    if resume_checkpoint is None:
+        downloaded = _maybe_download_phase_checkpoints_from_hf(output_dir, run_dir_for_resume)
+        if downloaded is not None:
+            resume_checkpoint = downloaded
+            resume_is_full_state = bool(
+                (downloaded / "optimizer.pt").exists()
+                and (downloaded / "trainer_state.json").exists()
+            )
+
+    # If the resume checkpoint is weights-only (e.g. recovered from HF Hub
+    # which intentionally drops optimizer state to keep uploads small),
+    # warm-start the model from those weights and start a fresh trainer
+    # state. Better than restarting from the base model.
+    warm_start_from: Path | None = None
+    if resume_checkpoint is not None and not resume_is_full_state:
+        warm_start_from = resume_checkpoint
+        resume_checkpoint = None
+
+    if warm_start_from is not None:
+        model_name_or_path = str(warm_start_from)
+
     dataset = Dataset.from_list(rows)
     args = _safe_build_grpo_config(
         phase=phase,
@@ -753,8 +924,6 @@
             raise RuntimeError("Installed TRL version does not expose peft_config in GRPOTrainer.")
         trainer_kwargs["peft_config"] = _build_lora_config(lora)

-    phase_label = str(run_name).strip() or str(output_dir.name)
-    reward_class_name = type(reward_function).__name__
     print(
         f"[self_play] Starting phase: {phase_label} rows={len(rows)} "
         f"max_steps={phase.max_steps}",
@@ -769,6 +938,17 @@
         f"per_device_train_batch_size={phase.per_device_train_batch_size}",
         flush=True,
     )
+    if resume_checkpoint is not None:
+        print(
+            f"[self_play][resume] phase={phase_label} resuming (full state) from checkpoint={resume_checkpoint}",
+            flush=True,
+        )
+    elif warm_start_from is not None:
+        print(
+            f"[self_play][resume] phase={phase_label} warm-starting from weights only at {warm_start_from} "
+            f"(no optimizer state available)",
+            flush=True,
+        )
     strict_asserts = str(os.getenv("OSINT_TRAIN_STRICT_ASSERTS", "")).strip().lower() in {"1", "true", "yes", "on"}
     trainer = GRPOTrainer(**trainer_kwargs)
     tracked_params = [
@@ -780,7 +960,10 @@
         name: float(param.detach().float().abs().mean().item())
         for name, param in tracked_params
     }
-    train_output = trainer.train()
+    if resume_checkpoint is not None:
+        train_output = trainer.train(resume_from_checkpoint=str(resume_checkpoint))
+    else:
+        train_output = trainer.train()

     final_dir = output_dir / "final_model"
     trainer.save_model(str(final_dir))