# /// script
# requires-python = ">=3.10"
# dependencies = [
# "unsloth",
# "trl==0.24.0",
# "transformers",
# "datasets",
# "peft",
# "accelerate",
# "bitsandbytes",
# "wandb",
# "setuptools",
# "wheel",
# "pip",
# "scipy>=1.10,<2.0",
# "sympy>=1.12,<2.0",
# "pydantic>=2.5,<3.0",
# "numpy>=1.24,<3.0",
# "openenv-core[core]>=0.2.2",
# "huggingface_hub>=0.24,<1.0",
# "matplotlib>=3.7,<4.0",
# ]
# ///
"""PhysiX RLVR single-system training job — damped_spring only.
Identical pipeline to job_train.py (SFT warm-start → GRPO) but scoped
to a single physical system (damped_spring) so the reward signal is
maximally focused and easy to observe as a clean increasing curve.
Deploy with:
hf jobs uv run job_train_single.py \
--image unsloth/unsloth:2026.3.8-pt2.9.0-vllm-0.16.0-cu12.8-studio-release \
--flavor l40sx1 \
--secrets HF_TOKEN \
--secrets WANDB_API_KEY \
-v hf://datasets/Pratyush-01/physix-live-src:/physix-live \
--timeout 2h
"""
from __future__ import annotations
import os
import shutil
import subprocess
import sys
from pathlib import Path
SYSTEM_ID = "damped_spring"
PROFILE: dict = {
    "base_model": "Qwen/Qwen2.5-3B-Instruct",
    "sft_lora_r": "32",
    "grpo_lora_r": "32",
    "sft_lr": "1.5e-5",
    "grpo_lr": "3e-6",
    "sft_epochs": "3",
    "num_steps": "200",
    "num_generations": "4",
    "max_completion": "256",
    # Separate repos so this run never touches the 3-system checkpoints.
    "hub_final_repo": "Pratyush-01/physix-3b-rl-damped",
    "hub_ckpt_repo": "Pratyush-01/physix-3b-rl-damped-ckpt",
    "sft_run_name": "physix-sft-3b-damped",
    "grpo_run_name": "physix-grpo-3b-damped",
}

# ---------------------------------------------------------------------------
# Environment hardening (same as job_train.py)
# ---------------------------------------------------------------------------
def _harden_env() -> None:
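    # Every value uses setdefault so anything set explicitly on the job
    # still wins; caches, W&B state, and HOME all point at /tmp so the
    # job only writes to scratch space.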
os.environ.setdefault("USER", "physix")
os.environ.setdefault("LOGNAME", "physix")
os.environ.setdefault("HOME", "/tmp/home")
os.environ.setdefault("HF_HOME", "/tmp/hf_cache")
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor_cache")
os.environ.setdefault("TRITON_CACHE_DIR", "/tmp/triton_cache")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/xdg-cache")
os.environ.setdefault("WANDB_DIR", "/tmp/wandb")
os.environ.setdefault("WANDB_CACHE_DIR", "/tmp/wandb-cache")
os.environ.setdefault("WANDB_DATA_DIR", "/tmp/wandb-data")
os.environ.setdefault("WANDB_ARTIFACT_DIR", "/tmp/wandb-artifacts")
os.environ.setdefault("WANDB_CONFIG_DIR", "/tmp/wandb-config")
os.environ.setdefault("WANDB_DISABLE_ARTIFACTS", "true")
os.environ.setdefault("WANDB_LOG_MODEL", "false")
os.environ.setdefault("WANDB_PROJECT", "physix-live")
os.environ.setdefault("UNSLOTH_COMPILE_DISABLE", "1")
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
os.environ.setdefault("TORCHINDUCTOR_DISABLE", "1")
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTHONUNBUFFERED", "1")
if os.environ.get("HF_TOKEN"):
os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", os.environ["HF_TOKEN"])
for d in (
os.environ["HOME"],
os.environ["HF_HOME"],
os.environ["TORCHINDUCTOR_CACHE_DIR"],
os.environ["TRITON_CACHE_DIR"],
os.environ["XDG_CACHE_HOME"],
os.environ["WANDB_DIR"],
os.environ["WANDB_CACHE_DIR"],
os.environ["WANDB_DATA_DIR"],
os.environ["WANDB_ARTIFACT_DIR"],
os.environ["WANDB_CONFIG_DIR"],
):
Path(d).mkdir(parents=True, exist_ok=True)
def _banner(msg: str) -> None:
line = "=" * 72
print(f"\n{line}\n {msg}\n{line}", flush=True)
def _run(cmd: list[str], *, env: dict | None = None) -> None:
print(f"$ {' '.join(cmd)}", flush=True)
subprocess.run(cmd, check=True, env=env or os.environ.copy())
def _require(name: str) -> str:
    val = os.environ.get(name)
    if not val:
        sys.exit(f"FATAL: required secret {name!r} is not set on the job")
    return val

def _stage_physix_live() -> Path:
src = Path("/physix-live")
if not src.exists():
sys.exit(
"FATAL: expected physix-live source mounted at /physix-live. "
"Pass `-v hf://datasets/<user>/physix-live-src:/physix-live` "
"when submitting the job."
)
dst = Path("/tmp/src/physix-live")
if dst.exists():
shutil.rmtree(dst)
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(src, dst)
return dst
def _install_physix(repo: Path) -> None:
    install_args = ["--no-cache-dir", "-e", str(repo), "--no-deps"]
    try:
        _run(["uv", "pip", "install", "--python", sys.executable, *install_args])
        return
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
        print(f"[install] uv pip path failed ({exc!r}); bootstrapping pip via ensurepip", flush=True)
    _run([sys.executable, "-m", "ensurepip", "--upgrade"])
    _run([sys.executable, "-m", "pip", "install", *install_args])

def _sanity_check_imports() -> None:
print("--- Sanity import check ---", flush=True)
code = (
"import torch, trl, transformers, datasets, wandb, unsloth, physix; "
"print(f'torch={torch.__version__} cuda={torch.cuda.is_available()} "
"device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}'); "
"print(f'unsloth={unsloth.__version__} trl={trl.__version__} "
"transformers={transformers.__version__} datasets={datasets.__version__}'); "
"print(f'physix loaded from {physix.__file__}'); "
"assert trl.__version__ == '0.24.0', f'trl must be pinned to 0.24.0, got {trl.__version__}'"
)
_run([sys.executable, "-c", code])
def _gpu_check() -> None:
print("--- GPU check ---", flush=True)
try:
subprocess.run(["nvidia-smi"], check=True)
except FileNotFoundError:
sys.exit("FATAL: nvidia-smi missing — job hardware is not GPU")
# ---------------------------------------------------------------------------
# SFT + GRPO steps, each locked to SYSTEM_ID
# ---------------------------------------------------------------------------
def _run_sft() -> None:
    p = PROFILE
    _banner(f"Step 1/2: SFT warm-start ({p['base_model']}) — system: {SYSTEM_ID}")
    _run([
        sys.executable, "-m", "physix.training.sft",
        "--model", p["base_model"],
        "--output-dir", "/tmp/physix-sft-damped",
        "--epochs", p["sft_epochs"],
        "--instances-per-system", "32",
        "--system-ids", SYSTEM_ID,
        "--lora-r", p["sft_lora_r"],
        "--learning-rate", p["sft_lr"],
        "--wandb-run-name", p["sft_run_name"],
        "--hub-checkpoint-repo-id", p["hub_ckpt_repo"],
        "--seed", "0",
    ])

def _run_grpo() -> None:
    p = PROFILE
    _banner(f"Step 2/2: GRPO RLVR ({p['num_steps']} steps) — system: {SYSTEM_ID}")
    _run([
        sys.executable, "-m", "physix.training.loop",
        "--model", p["base_model"],
        "--output-dir", "/tmp/physix-grpo-damped",
        "--num-steps", p["num_steps"],
        "--num-generations", p["num_generations"],
        "--max-completion-length", p["max_completion"],
        "--learning-rate", p["grpo_lr"],
        "--instances-per-system", "32",
        "--system-ids", SYSTEM_ID,
        "--lora-r", p["grpo_lora_r"],
        "--save-method", "merged_16bit",
        "--push-to-hub",
        "--hub-repo-id", p["hub_final_repo"],
        "--hub-checkpoint-repo-id", p["hub_ckpt_repo"],
        "--wandb-project", "physix-live",
        "--wandb-run-name", p["grpo_run_name"],
        "--sft-checkpoint", "/tmp/physix-sft-damped/merged",
        "--seed", "0",
    ])

def main() -> None:
    _harden_env()
    _banner(
        f"PhysiX RLVR single-system job\n"
        f" system: {SYSTEM_ID}\n"
        f" model: {PROFILE['base_model']}\n"
        f" steps: {PROFILE['num_steps']}\n"
        f" wandb: physix-live / {PROFILE['grpo_run_name']}"
    )
    _require("HF_TOKEN")
    _require("WANDB_API_KEY")
    _gpu_check()
    repo = _stage_physix_live()
    _install_physix(repo)
    _sanity_check_imports()
    _run_sft()
    _run_grpo()
    _banner("DONE")
    print(
        f"System trained on → {SYSTEM_ID}\n"
        f"Final model → https://huggingface.co/{PROFILE['hub_final_repo']}\n"
        f"Checkpoints → https://huggingface.co/{PROFILE['hub_ckpt_repo']}\n"
        f"W&B project → https://wandb.ai/pratyush01/physix-live\n",
        flush=True,
    )

if __name__ == "__main__":
    main()