siddeshwar-kagatikar committed
Commit b0b586a · Parent: 2e14f6d

Upload intermediate GRPO checkpoints to HF Hub on every save


The resume-from-HF-Hub fallback in _train_grpo_phase only helps if HF
Hub actually has a fresh checkpoint when a Space restarts mid-phase.
Until now, uploads only happened at the end of each phase, so a crash at, say,
step 70/120 would leave the Hub stuck at the start-of-phase weights and force
training to restart from scratch.

Add a Trainer callback _HfHubCheckpointUploadCallback registered via
trainer_kwargs["callbacks"] in _train_grpo_phase. On every Trainer
save (driven by phase.save_steps) it locates the newest local
checkpoint-* directory and pushes it to OSINT_HF_CHECKPOINT_REPO_ID
through the existing _maybe_upload_folder_to_hf helper, reusing the
ignore patterns that strip optimizer/scheduler state to keep uploads
manageable.
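
For reference, _maybe_upload_folder_to_hf itself is untouched by this change and
its body is not part of the diff below. A minimal sketch of the kind of call it
is assumed to make, using huggingface_hub's upload_folder with illustrative
ignore patterns for optimizer/scheduler state (the function name, patterns, and
token handling here are assumptions, not the real helper):

    # Hypothetical sketch only; not the actual _maybe_upload_folder_to_hf implementation.
    import os
    from pathlib import Path

    from huggingface_hub import HfApi

    def upload_checkpoint_folder(folder: Path, repo_id: str, token: str, message: str) -> None:
        """Push one local checkpoint-* directory, skipping optimizer/scheduler state."""
        HfApi(token=token).upload_folder(
            folder_path=str(folder),
            repo_id=repo_id,
            repo_type=os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model"),
            path_in_repo=folder.name,  # e.g. "checkpoint-70"
            commit_message=message,
            # Assumed patterns: keep uploads small by dropping optimizer/scheduler/RNG state.
            ignore_patterns=["optimizer.pt", "scheduler.pt", "rng_state*.pth"],
        )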

The callback is built lazily so transformers stays an optional dependency.
Failures are caught and logged; training never stops because of an upload
error, and the upload is retried at the next save. Uploads are on by default,
can be disabled with OSINT_HF_UPLOAD_ON_SAVE=0, and only activate when both a
checkpoint repo id and an HF token are resolvable.
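
As a usage note, the toggles above could be set like this before launching
training (the variable names come from this commit; setting them via os.environ
is just one option, and the repo id and token source shown are assumptions):

    import os

    # Opt out of per-save uploads for this run; post-phase uploads are unchanged.
    os.environ["OSINT_HF_UPLOAD_ON_SAVE"] = "0"

    # To activate per-save uploads instead, a repo id and a token must both resolve.
    os.environ["OSINT_HF_CHECKPOINT_REPO_ID"] = "your-org/osint-checkpoints"  # hypothetical repo id
    os.environ["HF_TOKEN"] = "hf_..."  # assumed token source; _resolve_hf_upload_token may accept others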

This pairs with _maybe_download_phase_checkpoints_from_hf so a Space
that gets killed mid-phase can boot back up, pull the most recent
checkpoint, and warm-start instead of replaying from step 0.
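
The download half of that round trip predates this commit and is not shown
here. A minimal sketch of what it could look like, assuming huggingface_hub's
snapshot_download and the usual checkpoint-* naming (the function name is a
stand-in, not the real _maybe_download_phase_checkpoints_from_hf):

    # Hypothetical sketch: pull checkpoint-* folders from the Hub and pick the newest.
    import os
    import re
    from pathlib import Path

    from huggingface_hub import snapshot_download

    def pull_latest_checkpoint(output_dir: Path) -> Path | None:
        repo_id = os.getenv("OSINT_HF_CHECKPOINT_REPO_ID")
        if not repo_id:
            return None
        snapshot_download(
            repo_id=repo_id,
            repo_type=os.getenv("OSINT_HF_CHECKPOINT_REPO_TYPE", "model"),
            allow_patterns=["checkpoint-*/*"],
            local_dir=str(output_dir),
        )
        checkpoints = sorted(
            (p for p in output_dir.glob("checkpoint-*") if p.is_dir()),
            key=lambda p: int(re.sub(r"\D", "", p.name) or 0),
        )
        return checkpoints[-1] if checkpoints else None

    # Passing the result to trainer.train(resume_from_checkpoint=...) would then
    # warm-start from the uploaded weights (optimizer/scheduler state was stripped
    # at upload time, so the optimizer is rebuilt).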

Made-with: Cursor

Files changed (1)
  1. src/osint_env/training/self_play.py +83 -0
src/osint_env/training/self_play.py CHANGED
@@ -179,6 +179,80 @@ def _require_training_stack() -> tuple[Any, Any, Any]:
     return Dataset, GRPOConfig, GRPOTrainer
 
 
+def _build_hf_checkpoint_upload_callback(output_dir: Path, run_dir: Path) -> Any:
+    """Return a Trainer callback that uploads each fresh ``checkpoint-*`` to
+    HF Hub the moment Transformers' Trainer writes it to disk. Returns None
+    if uploads are disabled or transformers is unavailable.
+
+    This pairs with ``_maybe_download_phase_checkpoints_from_hf`` so a Space
+    that gets restarted mid-phase can pull the most recent checkpoint and
+    warm-start instead of starting from step 0. Honors the same env vars as
+    the post-phase upload helper:
+    - ``OSINT_HF_CHECKPOINT_REPO_ID`` (or auto-derived from ``SPACE_ID``)
+    - ``OSINT_HF_CHECKPOINT_REPO_TYPE`` (default ``model``)
+    - ``OSINT_HF_UPLOAD_ON_SAVE`` (default ``1``; set to 0 to disable)
+    """
+    if not _is_true_env(os.getenv("OSINT_HF_UPLOAD_ON_SAVE", "1")):
+        return None
+
+    repo_id = _default_hf_checkpoint_repo_id(run_dir)
+    token = _resolve_hf_upload_token()
+    if not repo_id or not token:
+        return None
+
+    try:
+        from transformers import TrainerCallback
+    except ImportError:
+        print(
+            "[self_play][hf_upload] transformers.TrainerCallback unavailable; "
+            "intermediate checkpoint uploads disabled.",
+            flush=True,
+        )
+        return None
+
+    captured_output_dir = output_dir
+    captured_run_dir = run_dir
+
+    class _HfHubCheckpointUploadCallback(TrainerCallback):  # type: ignore[misc]
+        """Upload the latest local ``checkpoint-*`` directory after each save."""
+
+        def __init__(self) -> None:
+            self._last_uploaded_step: int | None = None
+            self._failures = 0
+
+        def on_save(self, args: Any, state: Any, control: Any, **kwargs: Any) -> Any:  # noqa: D401
+            try:
+                latest = _latest_local_checkpoint(captured_output_dir)
+                if latest is None:
+                    return control
+                step = int(latest.name.split("-", 1)[1]) if "-" in latest.name else 0
+                if self._last_uploaded_step is not None and step <= self._last_uploaded_step:
+                    return control
+                print(
+                    f"[self_play][hf_upload] on_save uploading {latest.name} "
+                    f"to HF Hub (phase_dir={captured_output_dir.name}, step={step}).",
+                    flush=True,
+                )
+                _maybe_upload_folder_to_hf(
+                    latest,
+                    captured_run_dir,
+                    f"Intermediate checkpoint upload step={step} ({captured_output_dir.name})",
+                )
+                self._last_uploaded_step = step
+                self._failures = 0
+            except Exception as exc:  # noqa: BLE001
+                self._failures += 1
+                print(
+                    f"[self_play][hf_upload] on_save upload failed "
+                    f"({type(exc).__name__}: {exc}). failures={self._failures}. "
+                    "Continuing training; next save will retry.",
+                    flush=True,
+                )
+            return control
+
+    return _HfHubCheckpointUploadCallback()
+
+
 def _latest_local_checkpoint(output_dir: Path) -> Path | None:
     if not output_dir.exists():
         return None
@@ -918,6 +992,15 @@ def _train_grpo_phase(
         "train_dataset": dataset,
     }
 
+    upload_callback = _build_hf_checkpoint_upload_callback(output_dir, run_dir_for_resume)
+    if upload_callback is not None:
+        trainer_kwargs["callbacks"] = [upload_callback]
+        print(
+            f"[self_play][hf_upload] phase={phase_label} intermediate checkpoint "
+            f"uploads enabled (every save_steps={phase.save_steps}).",
+            flush=True,
+        )
+
     if str(tuning_mode).strip().lower() == "lora":
         trainer_signature = inspect.signature(GRPOTrainer.__init__)
         if "peft_config" not in trainer_signature.parameters: