Upload intermediate GRPO checkpoints to HF Hub on every save
The resume-from-HF-Hub fallback in _train_grpo_phase only helps if HF
Hub actually has a fresh checkpoint when a Space restarts mid-phase.
Until now, uploads happened only at the end of each phase, so a crash
at, say, step 70/120 would leave the Hub stuck at the start-of-phase
weights and force training to restart from scratch.
Add a Trainer callback _HfHubCheckpointUploadCallback registered via
trainer_kwargs["callbacks"] in _train_grpo_phase. On every Trainer
save (driven by phase.save_steps) it locates the newest local
checkpoint-* directory and pushes it to OSINT_HF_CHECKPOINT_REPO_ID
through the existing _maybe_upload_folder_to_hf helper, reusing the
ignore patterns that strip optimizer/scheduler state to keep uploads
manageable.
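
For orientation, the push itself boils down to something like the
sketch below, assuming huggingface_hub's upload_folder API; the helper
internals and the exact ignore patterns live in
_maybe_upload_folder_to_hf, so the patterns shown are only illustrative:

    # Sketch only: the real code path goes through _maybe_upload_folder_to_hf.
    # The ignore patterns are illustrative examples of the optimizer/scheduler
    # state that gets stripped to keep uploads manageable.
    from huggingface_hub import HfApi

    api = HfApi(token=token)  # token as resolved by _resolve_hf_upload_token()
    api.upload_folder(
        folder_path=str(latest),   # newest local checkpoint-<step> directory
        path_in_repo=latest.name,
        repo_id=repo_id,           # OSINT_HF_CHECKPOINT_REPO_ID
        commit_message=f"Intermediate checkpoint upload step={step}",
        ignore_patterns=["optimizer.pt", "scheduler.pt", "rng_state*.pth"],
    )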
The callback is built lazily so transformers stays an optional
dependency. Failures are caught and logged; training never stops
because of an upload error, and the next save retries the upload.
Uploads are opt-out via OSINT_HF_UPLOAD_ON_SAVE=0 and activate only
when both a checkpoint repo id and an HF token are resolvable.
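
Concretely, activation only requires the Space to expose the repo id
(or a SPACE_ID it can be derived from) plus a write token secret for
_resolve_hf_upload_token to pick up; the kill switch is a plain Space
variable. Values here are placeholders:

    OSINT_HF_CHECKPOINT_REPO_ID=my-org/osint-grpo-checkpoints
    OSINT_HF_UPLOAD_ON_SAVE=0   # only set this to opt out of per-save uploads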
This pairs with _maybe_download_phase_checkpoints_from_hf so a Space
that gets killed mid-phase can boot back up, pull the most recent
checkpoint, and warm-start instead of replaying from step 0.
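
On the restart path, that warm-start is roughly the following sketch,
assuming huggingface_hub's snapshot_download and the Trainer's standard
resume_from_checkpoint support; in practice repo/type resolution is
handled by _maybe_download_phase_checkpoints_from_hf:

    # Sketch of the warm-start after a Space restart. Since optimizer and
    # scheduler state are stripped from uploads, the Trainer resumes weights
    # and step count but rebuilds optimizer state fresh.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=repo_id,                     # OSINT_HF_CHECKPOINT_REPO_ID
        allow_patterns=["checkpoint-*/**"],  # only the checkpoint directories
        local_dir=str(output_dir),
    )
    latest = _latest_local_checkpoint(output_dir)
    trainer.train(resume_from_checkpoint=str(latest) if latest is not None else None)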
Made-with: Cursor
@@ -179,6 +179,80 @@ def _require_training_stack() -> tuple[Any, Any, Any]:
     return Dataset, GRPOConfig, GRPOTrainer
 
 
+def _build_hf_checkpoint_upload_callback(output_dir: Path, run_dir: Path) -> Any:
+    """Return a Trainer callback that uploads each fresh ``checkpoint-*`` to
+    HF Hub the moment Transformers' Trainer writes it to disk. Returns None
+    if uploads are disabled or transformers is unavailable.
+
+    This pairs with ``_maybe_download_phase_checkpoints_from_hf`` so a Space
+    that gets restarted mid-phase can pull the most recent checkpoint and
+    warm-start instead of starting from step 0. Honors the same env vars as
+    the post-phase upload helper:
+    - ``OSINT_HF_CHECKPOINT_REPO_ID`` (or auto-derived from ``SPACE_ID``)
+    - ``OSINT_HF_CHECKPOINT_REPO_TYPE`` (default ``model``)
+    - ``OSINT_HF_UPLOAD_ON_SAVE`` (default ``1``; set to 0 to disable)
+    """
+    if not _is_true_env(os.getenv("OSINT_HF_UPLOAD_ON_SAVE", "1")):
+        return None
+
+    repo_id = _default_hf_checkpoint_repo_id(run_dir)
+    token = _resolve_hf_upload_token()
+    if not repo_id or not token:
+        return None
+
+    try:
+        from transformers import TrainerCallback
+    except ImportError:
+        print(
+            "[self_play][hf_upload] transformers.TrainerCallback unavailable; "
+            "intermediate checkpoint uploads disabled.",
+            flush=True,
+        )
+        return None
+
+    captured_output_dir = output_dir
+    captured_run_dir = run_dir
+
+    class _HfHubCheckpointUploadCallback(TrainerCallback):  # type: ignore[misc]
+        """Upload the latest local ``checkpoint-*`` directory after each save."""
+
+        def __init__(self) -> None:
+            self._last_uploaded_step: int | None = None
+            self._failures = 0
+
+        def on_save(self, args: Any, state: Any, control: Any, **kwargs: Any) -> Any:  # noqa: D401
+            try:
+                latest = _latest_local_checkpoint(captured_output_dir)
+                if latest is None:
+                    return control
+                step = int(latest.name.split("-", 1)[1]) if "-" in latest.name else 0
+                if self._last_uploaded_step is not None and step <= self._last_uploaded_step:
+                    return control
+                print(
+                    f"[self_play][hf_upload] on_save uploading {latest.name} "
+                    f"to HF Hub (phase_dir={captured_output_dir.name}, step={step}).",
+                    flush=True,
+                )
+                _maybe_upload_folder_to_hf(
+                    latest,
+                    captured_run_dir,
+                    f"Intermediate checkpoint upload step={step} ({captured_output_dir.name})",
+                )
+                self._last_uploaded_step = step
+                self._failures = 0
+            except Exception as exc:  # noqa: BLE001
+                self._failures += 1
+                print(
+                    f"[self_play][hf_upload] on_save upload failed "
+                    f"({type(exc).__name__}: {exc}). failures={self._failures}. "
+                    "Continuing training; next save will retry.",
+                    flush=True,
+                )
+            return control
+
+    return _HfHubCheckpointUploadCallback()
+
+
 def _latest_local_checkpoint(output_dir: Path) -> Path | None:
     if not output_dir.exists():
         return None
@@ -918,6 +992,15 @@ def _train_grpo_phase(
         "train_dataset": dataset,
     }
 
+    upload_callback = _build_hf_checkpoint_upload_callback(output_dir, run_dir_for_resume)
+    if upload_callback is not None:
+        trainer_kwargs["callbacks"] = [upload_callback]
+        print(
+            f"[self_play][hf_upload] phase={phase_label} intermediate checkpoint "
+            f"uploads enabled (every save_steps={phase.save_steps}).",
+            flush=True,
+        )
+
     if str(tuning_mode).strip().lower() == "lora":
         trainer_signature = inspect.signature(GRPOTrainer.__init__)
        if "peft_config" not in trainer_signature.parameters: