Spaces:
Sleeping
Sleeping
| """Checkpoint push/pull helpers shared by SFT and GRPO. | |
| Two responsibilities: | |
| 1. **Push a local checkpoint dir to a Hugging Face Hub repo as a subfolder** | |
| (e.g. SFT writes to ``<repo>/sft/``, GRPO writes to ``<repo>/checkpoint-N/``). | |
| Returns the resulting Hub revision SHA so the caller can pin it in W&B. | |
| 2. **Discover and download the latest GRPO checkpoint** from the same repo, | |
| so a re-launched job can resume the same GRPO run rather than redoing SFT. | |
| We deliberately do NOT push raw model weights into W&B Artifacts — they live | |
| on the Hub. W&B gets a tiny **link-artifact** (one JSON metadata file with | |
| the Hub repo + revision SHA + step), which is enough for `wandb artifact | |
| get` to round-trip back to the Hub and download via `huggingface_hub`. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import tempfile | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Subfolder names on the Hub checkpoint repo. The SFT subfolder is fixed.
# GRPO checkpoint subfolders follow Trainer's "checkpoint-{step}" convention.
SFT_SUBFOLDER = "sft"
# Matches a complete path component of the form "checkpoint-<digits>";
# group 1 captures the step number as a string of digits.
GRPO_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")
@dataclass
class CheckpointHandle:
    """A pointer to a checkpoint on the Hub.

    ``revision`` is normally a commit SHA (not a branch) so the artifact is
    immutable — re-pushes to the same subfolder won't change what we
    resume from.

    Attributes:
        repo_id: Hub repository id, e.g. ``"org/model"``.
        subfolder: Folder inside the repo that holds the checkpoint files.
        revision: Revision the checkpoint is pinned to.
        step: Training step; only populated for GRPO ``checkpoint-N`` folders.
    """

    repo_id: str
    subfolder: str
    revision: str
    step: Optional[int] = None  # populated for GRPO checkpoint-N

    # Exposed as a property because callers read ``handle.hub_url`` as a plain
    # attribute (e.g. when building a JSON payload); a bound method there
    # would not be serializable.
    @property
    def hub_url(self) -> str:
        """Return the browser URL of the checkpoint folder at ``revision``."""
        return f"https://huggingface.co/{self.repo_id}/tree/{self.revision}/{self.subfolder}"
def push_checkpoint_to_hub(
    local_dir: str | os.PathLike,
    repo_id: str,
    subfolder: str,
    *,
    commit_message: str,
    token: Optional[str] = None,
) -> CheckpointHandle:
    """Push a local checkpoint directory to ``<repo_id>/<subfolder>/`` on the Hub.

    The model repo is created when it does not exist yet, then the whole
    directory is uploaded under ``subfolder``.

    Args:
        local_dir: Directory holding the checkpoint files to upload.
        repo_id: Target Hub repository id.
        subfolder: Path inside the repo that receives the files.
        commit_message: Message passed to the upload call.
        token: Optional Hub token forwarded to every Hub call.

    Returns:
        A ``CheckpointHandle`` whose ``revision`` is the ``oid`` of the upload
        result when it exposes one, otherwise ``str()`` of that result.
        ``step`` is left unset.

    Raises:
        FileNotFoundError: If ``local_dir`` is not an existing directory.
        Exception: Repo-creation and upload errors are not caught here — the
            caller decides whether to swallow them.
    """
    # Function-scope import, as in the rest of this module's Hub helpers.
    from huggingface_hub import HfApi, create_repo

    source_dir = Path(local_dir)
    if not source_dir.is_dir():
        raise FileNotFoundError(f"checkpoint dir does not exist: {source_dir}")

    hub = HfApi(token=token)
    create_repo(repo_id, exist_ok=True, repo_type="model", token=token)

    _log.info("Uploading %s -> %s/%s", source_dir, repo_id, subfolder)
    upload_result = hub.upload_folder(
        folder_path=str(source_dir),
        repo_id=repo_id,
        path_in_repo=subfolder,
        commit_message=commit_message,
        token=token,
    )

    # Use the commit id when the result exposes one; otherwise fall back to
    # the result's string form.
    try:
        pinned = upload_result.oid
    except AttributeError:
        pinned = str(upload_result)
    _log.info("Push complete; revision=%s", pinned)

    return CheckpointHandle(repo_id=repo_id, subfolder=subfolder, revision=pinned)
def find_latest_grpo_checkpoint(
    repo_id: str,
    *,
    token: Optional[str] = None,
) -> Optional[CheckpointHandle]:
    """Return the highest-step ``checkpoint-N/`` folder on the repo, or None.

    The repo's current head SHA is resolved first and the file listing is
    then taken at that exact revision, so the returned handle always points
    at the commit that was inspected, even if another push lands in between.

    Args:
        repo_id: Hub repository id to inspect.
        token: Optional Hub token forwarded to every Hub call.

    Returns:
        A ``CheckpointHandle`` for the ``checkpoint-N`` folder with the
        largest ``N``, pinned to the inspected revision (``"main"`` when the
        Hub reports no SHA). ``None`` when the repo or revision does not
        exist, when it cannot be queried (a warning is logged), or when it
        holds no ``checkpoint-N`` folder.
    """
    from huggingface_hub import HfApi
    from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError

    api = HfApi(token=token)
    try:
        # Resolve the SHA before listing so both refer to the same commit.
        info = api.repo_info(repo_id, repo_type="model", token=token)
        pinned_sha = info.sha
        files = api.list_repo_files(
            repo_id, revision=pinned_sha, repo_type="model", token=token
        )
    except (RepositoryNotFoundError, RevisionNotFoundError):
        return None
    except Exception as exc:  # noqa: BLE001
        _log.warning("Could not list %s: %s", repo_id, exc)
        return None

    best_step = -1
    best_subfolder: Optional[str] = None
    for path in files:
        # Top-level folder name is the first path component. A path without
        # a separator is a top-level *file*, which can never be a checkpoint
        # folder even if it happens to be named like one.
        head, sep, _rest = path.partition("/")
        if not sep:
            continue
        m = GRPO_CHECKPOINT_RE.match(head)
        if not m:
            continue
        step = int(m.group(1))
        if step > best_step:
            best_step = step
            best_subfolder = head

    if best_subfolder is None:
        return None

    return CheckpointHandle(
        repo_id=repo_id,
        subfolder=best_subfolder,
        revision=pinned_sha or "main",
        step=best_step,
    )
def download_checkpoint(
    handle: CheckpointHandle,
    local_dir: str | os.PathLike,
    *,
    token: Optional[str] = None,
) -> Path:
    """Fetch the checkpoint subfolder that ``handle`` points at.

    Args:
        handle: Hub location (repo, revision, subfolder) to download.
        local_dir: Directory to download into; created if missing.
        token: Optional Hub token forwarded to the download call.

    Returns:
        ``local_dir / handle.subfolder``, the directory holding the files.

    Raises:
        FileNotFoundError: If that directory does not exist after the
            download call returns.
    """
    from huggingface_hub import snapshot_download

    dest_root = Path(local_dir)
    dest_root.mkdir(parents=True, exist_ok=True)

    # Restrict the snapshot to files under the checkpoint's subfolder.
    wanted = [f"{handle.subfolder}/*"]
    snapshot_download(
        repo_id=handle.repo_id,
        revision=handle.revision,
        allow_patterns=wanted,
        local_dir=str(dest_root),
        token=token,
    )

    checkpoint_dir = dest_root / handle.subfolder
    if checkpoint_dir.is_dir():
        return checkpoint_dir
    raise FileNotFoundError(
        f"download succeeded but {checkpoint_dir} is missing — check repo layout"
    )
def log_link_artifact_to_wandb(
    handle: CheckpointHandle,
    *,
    artifact_name: str,
    extra: Optional[dict] = None,
) -> None:
    """Record a pointer-only artifact for ``handle`` on the active W&B run.

    The artifact holds one ``checkpoint.json`` file with the Hub repo,
    subfolder, revision, step and URL (plus any ``extra`` entries, which
    override the defaults on key collision). No model bytes are uploaded.

    Does nothing when ``wandb`` is not installed or no run is active. A
    failure of the final ``log_artifact`` call is logged as a warning and
    not raised.

    Args:
        handle: Checkpoint location to describe.
        artifact_name: Name given to the W&B artifact.
        extra: Optional additional metadata merged into the payload.
    """
    try:
        import wandb
    except ImportError:
        return
    if wandb.run is None:
        return

    payload = dict(
        repo_id=handle.repo_id,
        subfolder=handle.subfolder,
        revision=handle.revision,
        step=handle.step,
        hub_url=handle.hub_url,
    )
    if extra:
        payload.update(extra)

    with tempfile.TemporaryDirectory() as scratch:
        meta_path = Path(scratch) / "checkpoint.json"
        serialized = json.dumps(payload, indent=2)
        meta_path.write_text(serialized)

        # Build and log while the temp file still exists on disk.
        pointer = wandb.Artifact(
            name=artifact_name,
            type="model-pointer",
            description=f"Pointer to {handle.hub_url}",
            metadata=payload,
        )
        pointer.add_file(str(meta_path))
        try:
            wandb.run.log_artifact(pointer)
        except Exception as exc:  # noqa: BLE001
            _log.warning("W&B artifact logging failed (non-fatal): %s", exc)