# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "unsloth",
#     "trl==0.24.0",
#     "transformers",
#     "datasets",
#     "peft",
#     "accelerate",
#     "bitsandbytes",
#     "wandb",
#     # setuptools/wheel/pip aren't ML deps, but torch._inductor.cpp_builder
#     # imports them at runtime when probing the CPU SIMD ISA inside the very
#     # first GRPO training step. Missing them => "ModuleNotFoundError:
#     # No module named 'setuptools'" deep in compile_fx. Add them defensively
#     # alongside the env-level torch.compile disable below.
#     "setuptools",
#     "wheel",
#     "pip",
#     "scipy>=1.10,<2.0",
#     "sympy>=1.12,<2.0",
#     "pydantic>=2.5,<3.0",
#     "numpy>=1.24,<3.0",
#     "openenv-core[core]>=0.2.2",
#     "huggingface_hub>=0.24,<1.0",
#     "matplotlib>=3.7,<4.0",
# ]
# ///
| """PhysiX RLVR training driver for HF Jobs. | |
| Source is mounted from a dataset at /physix-live. | |
| Pins trl==0.24.0 (newer versions break Unsloth's PatchFastRL). | |
| """ | |

from __future__ import annotations

import os
import shutil
import subprocess
import sys
from pathlib import Path


def _harden_env() -> None:
    os.environ.setdefault("USER", "physix")
    os.environ.setdefault("LOGNAME", "physix")
    os.environ.setdefault("HOME", "/tmp/home")
    os.environ.setdefault("HF_HOME", "/tmp/hf_cache")
    os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor_cache")
    os.environ.setdefault("TRITON_CACHE_DIR", "/tmp/triton_cache")
    os.environ.setdefault("XDG_CACHE_HOME", "/tmp/xdg-cache")
    os.environ.setdefault("WANDB_DIR", "/tmp/wandb")
    os.environ.setdefault("WANDB_CACHE_DIR", "/tmp/wandb-cache")
    os.environ.setdefault("WANDB_DATA_DIR", "/tmp/wandb-data")
    os.environ.setdefault("WANDB_ARTIFACT_DIR", "/tmp/wandb-artifacts")
    os.environ.setdefault("WANDB_CONFIG_DIR", "/tmp/wandb-config")
    os.environ.setdefault("WANDB_DISABLE_ARTIFACTS", "true")
    os.environ.setdefault("WANDB_LOG_MODEL", "false")
    os.environ.setdefault("WANDB_PROJECT", "physix-live")
    # Disable torch.compile — not needed for 3B-LoRA and breaks inductor in
    # some containers.
    os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")
    os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
    os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
    os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    os.environ.setdefault("PYTHONUNBUFFERED", "1")
    if os.environ.get("HF_TOKEN"):
        os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", os.environ["HF_TOKEN"])
    for d in (
        os.environ["HOME"],
        os.environ["HF_HOME"],
        os.environ["TORCHINDUCTOR_CACHE_DIR"],
        os.environ["TRITON_CACHE_DIR"],
        os.environ["XDG_CACHE_HOME"],
        os.environ["WANDB_DIR"],
        os.environ["WANDB_CACHE_DIR"],
        os.environ["WANDB_DATA_DIR"],
        os.environ["WANDB_ARTIFACT_DIR"],
        os.environ["WANDB_CONFIG_DIR"],
    ):
        Path(d).mkdir(parents=True, exist_ok=True)


def _banner(msg: str) -> None:
    line = "=" * 72
    print(f"\n{line}\n {msg}\n{line}", flush=True)


def _run(cmd: list[str], *, env: dict | None = None) -> None:
    print(f"$ {' '.join(cmd)}", flush=True)
    subprocess.run(cmd, check=True, env=env or os.environ.copy())


def _require(name: str) -> str:
    val = os.environ.get(name)
    if not val:
        sys.exit(f"FATAL: required secret {name!r} is not set on the job")
    return val


def _stage_physix_live() -> Path:
    """The dataset is mounted read-only at /physix-live. ``pip install -e``
    needs a writable tree (it creates an .egg-info), so copy to /tmp/src
    and install from there."""
    src = Path("/physix-live")
    if not src.exists():
        sys.exit(
            "FATAL: expected physix-live source mounted at /physix-live. "
            "Pass `-v hf://datasets/<user>/physix-live-src:/physix-live` "
            "when submitting the job."
        )
    dst = Path("/tmp/src/physix-live")
    if dst.exists():
        shutil.rmtree(dst)
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src, dst)
    return dst


def _install_physix(repo: Path) -> None:
    # The base image already pins torch / transformers / unsloth / trl etc.
    # --no-deps prevents pip from upgrading any of them.
    #
    # The Unsloth uv-managed environment does NOT ship pip-the-module by
    # default (`python -m pip` raises "No module named pip"). Try the `uv
    # pip` shim first (uv is guaranteed to be on PATH under `hf jobs uv
    # run`); if that fails for any reason, bootstrap pip via ensurepip and
    # fall back. Either path uses --no-deps so the carefully pinned
    # torch/transformers/unsloth/trl in the base image stay untouched.
    install_args = ["--no-cache-dir", "-e", str(repo), "--no-deps"]
    try:
        _run(["uv", "pip", "install", "--python", sys.executable, *install_args])
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
        print(f"[install] uv pip failed ({exc!r}); falling back to ensurepip", flush=True)
        _run([sys.executable, "-m", "ensurepip", "--upgrade"])
        _run([sys.executable, "-m", "pip", "install", *install_args])
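
# Equivalent shell form of the two install paths above (illustrative; the
# editable path is wherever _stage_physix_live() copied the source):
#
#   uv pip install --python "$(command -v python)" --no-cache-dir -e /tmp/src/physix-live --no-deps
#   # fallback when the uv shim is unavailable:
#   python -m ensurepip --upgrade
#   python -m pip install --no-cache-dir -e /tmp/src/physix-live --no-deps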


def _sanity_check_imports() -> None:
    print("--- Sanity import check ---", flush=True)
    code = (
        "import torch, trl, transformers, datasets, wandb, unsloth, physix; "
        "print(f'torch={torch.__version__} cuda={torch.cuda.is_available()} "
        "device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}'); "
        "print(f'unsloth={unsloth.__version__} trl={trl.__version__} "
        "transformers={transformers.__version__} datasets={datasets.__version__}'); "
        "print(f'physix loaded from {physix.__file__}'); "
        "assert trl.__version__ == '0.24.0', f'trl must be pinned to 0.24.0, got {trl.__version__}'"
    )
    _run([sys.executable, "-c", code])


def _gpu_check() -> None:
    print("--- GPU check ---", flush=True)
    try:
        subprocess.run(["nvidia-smi"], check=True)
    except FileNotFoundError:
        sys.exit("FATAL: nvidia-smi missing — job hardware is not GPU")


PROFILES: dict[str, dict] = {
    "1.5b": {
        "base_model": "Qwen/Qwen2.5-1.5B-Instruct",
        "sft_lora_r": "32",
        "grpo_lora_r": "32",
        "sft_lr": "2e-5",
        "grpo_lr": "5e-6",
        "sft_epochs": "3",
        "num_steps": "300",
        "num_generations": "4",
        "max_completion": "256",
        "hub_final_repo": "Pratyush-01/physix-1.5b-rl",
        "hub_ckpt_repo": "Pratyush-01/physix-1.5b-rl-ckpt",
        "sft_run_name": "physix-sft-1.5b",
        "grpo_run_name": "physix-grpo-1.5b",
    },
    "3b": {
        "base_model": "Qwen/Qwen2.5-3B-Instruct",
        "sft_lora_r": "32",
        "grpo_lora_r": "32",
        "sft_lr": "1.5e-5",
        "grpo_lr": "1e-5",  # 3e-6=flat, 3e-5=too fast, 1e-5=smooth
        "sft_epochs": "4",
        "num_steps": "200",
        "num_generations": "4",
        "max_completion": "384",
        "hub_final_repo": "Pratyush-01/physix-3b-rl",
        "hub_ckpt_repo": "Pratyush-01/physix-3b-rl-ckpt",
        "sft_run_name": "physix-sft-3b-final",
        "grpo_run_name": "physix-grpo-3b-final",
    },
    "7b": {
        "base_model": "Qwen/Qwen2.5-7B-Instruct",
        # Smaller LoRA rank: 7B has ~4.6× more params than 1.5B, so even
        # at r=16 the trainable count (~40M) is comparable to 1.5B at r=32
        # (see the back-of-envelope sketch after this dict).
        "sft_lora_r": "16",
        "grpo_lora_r": "16",
        # Lower LR for the bigger base.
        "sft_lr": "1e-5",
        "grpo_lr": "2e-6",
        "sft_epochs": "3",
        "num_steps": "200",
        "num_generations": "4",
        "max_completion": "256",
        "hub_final_repo": "Pratyush-01/physix-7b-rl",
        "hub_ckpt_repo": "Pratyush-01/physix-7b-rl-ckpt",
        "sft_run_name": "physix-sft-7b",
        "grpo_run_name": "physix-grpo-7b",
    },
}
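
# Back-of-envelope behind the r=16 vs r=32 comment in the "7b" profile (a
# hedged sketch; exact counts depend on which modules Unsloth wraps with
# LoRA). Each adapted weight W of shape (d_out, d_in) gains two factors,
# A of shape (r, d_in) and B of shape (d_out, r), i.e. r * (d_in + d_out)
# trainable params per module:
#
#   def lora_trainable_params(r: int, shapes: list[tuple[int, int]]) -> int:
#       # shapes: (d_out, d_in) of every LoRA-targeted linear layer
#       return sum(r * (d_out + d_in) for d_out, d_in in shapes)
#
# The count is linear in r and in layer width, so halving r on a base with
# roughly double the hidden width lands near the same total (the ~40M
# figure cited above).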


#: Active profile. ``3b`` chosen for the fast-iteration run — best
#: capacity/wall-clock tradeoff for the PhysiX 3-system POC.
ACTIVE_PROFILE: str = "3b"


def _profile() -> dict:
    return PROFILES[ACTIVE_PROFILE]


def _run_sft() -> None:
    p = _profile()
    _banner(f"Step 1/2: SFT warm-start ({p['base_model']})")
    _run([
        sys.executable, "-m", "physix.training.sft",
        "--model", p["base_model"],
        "--output-dir", "/tmp/physix-sft",
        "--epochs", p["sft_epochs"],
        "--instances-per-system", "64",
        "--lora-r", p["sft_lora_r"],
        "--learning-rate", p["sft_lr"],
        "--wandb-run-name", p["sft_run_name"],
        # Push the merged SFT model to the same checkpoint repo GRPO uses,
        # under <repo>/sft. Lets a future restart skip SFT and reuse it.
        "--hub-checkpoint-repo-id", p["hub_ckpt_repo"],
        "--seed", "0",
    ])
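
# A hedged sketch of how a future restart could check for that published SFT
# model before re-running SFT (hypothetical helper, not wired up here; the
# <repo>/sft layout is assumed from the comment above):
#
#   from huggingface_hub import file_exists
#
#   def _sft_already_published(ckpt_repo: str) -> bool:
#       # config.json is a reasonable existence probe for a merged model dir
#       return file_exists(ckpt_repo, "sft/config.json", repo_type="model")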


def _try_resume_from_grpo_checkpoint() -> tuple[Path | None, str | None]:
    """Look for a prior GRPO checkpoint in the Hub repo for this profile.

    Returns ``(local_path, wandb_run_id)`` if a checkpoint was found and
    successfully downloaded, else ``(None, None)``. The downloaded
    directory is what gets passed to ``--resume-from-checkpoint``; the
    run id (when present) is set as ``WANDB_RUN_ID`` so the GRPO chart
    continues on the same timeline rather than starting fresh.
    """
    p = _profile()
    repo_id = p["hub_ckpt_repo"]
    try:
        from physix.training.checkpoints import (
            download_checkpoint,
            find_latest_grpo_checkpoint,
        )
    except ImportError as exc:
        print(f"[resume] checkpoints helper not importable yet: {exc}", flush=True)
        return None, None
    token = os.environ.get("HF_TOKEN")
    handle = find_latest_grpo_checkpoint(repo_id, token=token)
    if handle is None:
        print(f"[resume] No prior GRPO checkpoint in {repo_id}; cold start.", flush=True)
        return None, None
    print(
        f"[resume] Found prior GRPO checkpoint at {handle.hub_url} (step={handle.step}). "
        f"Downloading to /tmp/physix-grpo-resume ...",
        flush=True,
    )
    local = download_checkpoint(handle, "/tmp/physix-grpo-resume", token=token)
    # Look up the W&B run id stashed at repo root by the on_train_begin
    # callback. If present, we'll pass it through so wandb.init resumes
    # the same run and the loss/reward charts stay continuous.
    run_id: str | None = None
    try:
        from huggingface_hub import hf_hub_download

        run_id_path = hf_hub_download(
            repo_id=repo_id,
            filename="wandb_run_id.txt",
            repo_type="model",
            token=token,
        )
        run_id = Path(run_id_path).read_text().strip() or None
        if run_id:
            print(f"[resume] W&B run id {run_id} — chart will continue on the same timeline.", flush=True)
    except Exception as exc:  # noqa: BLE001
        print(f"[resume] No wandb_run_id.txt on repo (will start fresh W&B run): {exc}", flush=True)
    return local, run_id
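
# The publishing side lives in physix.training.loop's on_train_begin
# callback, which is not shown here. A hedged sketch of what it presumably
# does (only the wandb_run_id.txt name and repo-root location are confirmed
# by the reader above; everything else is illustrative):
#
#   import io
#   import wandb
#   from huggingface_hub import HfApi
#
#   def _publish_wandb_run_id(repo_id: str, token: str | None) -> None:
#       # Stash the live run id at the repo root so a restarted job can
#       # resume the same W&B timeline.
#       HfApi(token=token).upload_file(
#           path_or_fileobj=io.BytesIO(wandb.run.id.encode()),
#           path_in_repo="wandb_run_id.txt",
#           repo_id=repo_id,
#           repo_type="model",
#       )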


def _run_grpo(
    *,
    lora_adapter_repo: str | None = None,
    resume_from_checkpoint: Path | None = None,
) -> None:
    """Run the GRPO step.

    Three modes (mutually exclusive):
    - Cold start (default): warm from /tmp/physix-sft/merged.
    - From an existing Hub LoRA adapter: ``lora_adapter_repo`` set.
    - Resume from a prior in-flight ckpt: ``resume_from_checkpoint`` set
      (continues the SAME wandb run id when one is published on the repo).

    Reward set (physix.training.reward_fns):
        match, match_dense, correctness, simplicity, format

    Anti-hack invariants (RCA from 5kuqns9x):
    - ``progress`` removed (duplicated ``match`` in single-turn).
    - ``simplicity`` gated on R² ≥ 0.10 (see the sketch after this function).
    - ``format`` requires simulation success, not just parse success.
    - Three correctness-shaped signals dominate the GRPO advantage.
    """
    p = _profile()
    num_steps = int(p["num_steps"])
    _banner(f"GRPO RLVR ({num_steps} steps on {p['base_model']})")
    cmd = [
        sys.executable, "-m", "physix.training.loop",
        "--model", p["base_model"],
        "--output-dir", "/tmp/physix-grpo",
        "--num-steps", str(num_steps),
        "--num-generations", p["num_generations"],
        "--max-completion-length", p["max_completion"],
        "--learning-rate", p["grpo_lr"],
        "--instances-per-system", "64",
        "--lora-r", p["grpo_lora_r"],
        "--save-method", "merged_16bit",
        "--push-to-hub",
        "--hub-repo-id", p["hub_final_repo"],
        "--hub-checkpoint-repo-id", p["hub_ckpt_repo"],
        "--wandb-project", "physix-live",
        "--wandb-run-name", p["grpo_run_name"],
        "--early-stop-patience", "50",
        "--seed", "0",
    ]
    if resume_from_checkpoint is not None:
        cmd += ["--resume-from-checkpoint", str(resume_from_checkpoint)]
    elif lora_adapter_repo:
        cmd += ["--lora-adapter-repo", lora_adapter_repo]
    else:
        cmd += ["--sft-checkpoint", "/tmp/physix-sft/merged"]
    _run(cmd)
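
# A hedged sketch of the R²-gated ``simplicity`` reward referenced in the
# docstring above (hypothetical shape; the real implementation lives in
# physix.training.reward_fns and is not shown here):
#
#   R2_GATE = 0.10
#
#   def simplicity_reward(expr_complexity: float, r2: float) -> float:
#       # Pay the simplicity bonus only once the candidate actually fits
#       # the data, so "simpler" cannot be gamed by fitting nothing.
#       if r2 < R2_GATE:
#           return 0.0
#       return 1.0 / (1.0 + expr_complexity)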


# ---------------------------------------------------------------------------
# Resume configuration (baked in deliberately).
#
# We ship resume parameters as module-level constants instead of `-e` env
# flags because `hf jobs uv run -e KEY=VAL` was observed to silently drop
# env entries on submission (the job spec's `environment` dict ends up
# containing only the auto-injected LOCAL_FILES_ENCODED). The script
# encoding is reliable, so embedding the constants here is the
# fail-safe path.
#
# To do a fresh run instead, set RESUME_LORA_REPO to None.
#
# Note: we deliberately do NOT resume into the SAME W&B run id this time
# (no WANDB_RUN_ID is set on this path; RESUME_FROM_WANDB_RUN below is
# informational only). The previous run 5kuqns9x logged 4 reward
# components; this one logs 3 (``progress`` removed). Continuing the same
# chart would mix two different reward setups on one timeline, which is
# misleading. Instead we start a fresh run and link back to the source
# run via wandb config + summary.
# ---------------------------------------------------------------------------

#: When set, skip SFT and warm-start GRPO from this Hub LoRA adapter.
#: Must be ``None`` when switching base models — a 1.5B adapter cannot
#: be loaded onto a 7B base. Only set this to resume the *same* model
#: family from a prior interrupted run.
RESUME_LORA_REPO: str | None = None
RESUME_FROM_WANDB_RUN: str | None = None  # informational only (link)


def main() -> None:
    _harden_env()
    if RESUME_FROM_WANDB_RUN:
        # Pin the source run as W&B config so the new run's Overview tab
        # shows the lineage. We do NOT set WANDB_RUN_ID here.
        os.environ["WANDB_RESUMED_FROM"] = RESUME_FROM_WANDB_RUN
        print(
            f"[resume] Warm-starting from W&B run {RESUME_FROM_WANDB_RUN} "
            f"(https://wandb.ai/pratyush01/physix-live/runs/{RESUME_FROM_WANDB_RUN})",
            flush=True,
        )
    resume_lora = RESUME_LORA_REPO
    p = _profile()
    if resume_lora:
        _banner(
            f"PhysiX RLVR RESUME job ({ACTIVE_PROFILE} on A100-large)\n"
            f" adapter: {resume_lora}\n"
            f" steps: {p['num_steps']}\n"
            f" wandb: {os.environ.get('WANDB_RUN_ID', '<new>')}"
        )
    else:
        _banner(
            f"PhysiX RLVR training job ({ACTIVE_PROFILE} / {p['base_model']} on A100-large)"
        )
    _require("HF_TOKEN")
    _require("WANDB_API_KEY")
    _gpu_check()
    repo = _stage_physix_live()
    _install_physix(repo)
    _sanity_check_imports()
    if resume_lora:
        # Forced LoRA resume (RESUME_LORA_REPO set above) — skip SFT and
        # warm-start GRPO from a specific Hub adapter, fresh wandb run.
        _run_grpo(lora_adapter_repo=resume_lora)
    else:
        # Auto-resume: if a prior GRPO checkpoint already exists in the
        # checkpoint repo (e.g. previous job died at step 87), pick up
        # where it left off and continue the SAME wandb run id so the
        # loss/reward chart is one continuous line. If nothing's there,
        # do the normal SFT -> GRPO cold start.
        ckpt_local, prior_run_id = _try_resume_from_grpo_checkpoint()
        if ckpt_local is not None:
            if prior_run_id:
                # wandb.init(resume="allow") inside loop.py picks this up.
                os.environ["WANDB_RUN_ID"] = prior_run_id
                os.environ["WANDB_RESUME"] = "allow"
            _run_grpo(resume_from_checkpoint=ckpt_local)
        else:
            _run_sft()
            _run_grpo()
    _banner("DONE")
    print(
        f"Final model → https://huggingface.co/{p['hub_final_repo']}\n"
        f"Checkpoints → https://huggingface.co/{p['hub_ckpt_repo']}\n"
        f"W&B project → https://wandb.ai/pratyush01/physix-live\n",
        flush=True,
    )


if __name__ == "__main__":
    main()