# rhythm_env/scripts/sft_on_hf.py
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "transformers==4.56.2",
# "trl==0.22.2",
# "datasets",
# "peft",
# "accelerate",
# "bitsandbytes",
# "unsloth",
# "openenv-core",
# "fastapi",
# "uvicorn",
# "pydantic",
# "huggingface_hub",
# ]
# ///
"""
HF Jobs orchestrator for SFT prime stage.
Submits the SFT prime training as an HF Jobs run. Clones the rhythm_env
HF Space, downloads the teacher trajectory JSONL files from a HF dataset
or model repo, runs training/sft_prime.py, and uploads the SFT'd model.
Submit from local with:
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN \\
-e TEACHER_DATA_REPO=InosLihka/rhythm-env-teacher-trajectories \\
-e MODEL_REPO_SUFFIX=sft-primed \\
-e EPOCHS=2 \\
-d scripts/sft_on_hf.py
Cost on a10g-large at $1.50/hr: ~$2-3 for ~30-45 min training.
"""
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
# --- Job configuration (every value overridable via environment variable) ---

# HF Space holding the rhythm_env code; cloned inside the job container.
REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
# Scratch checkout and training-output locations inside the container.
WORK_DIR = "/tmp/rhythm_env"
OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythm-env-sft-primed"
# Teacher trajectory data must be uploaded to a HF dataset/model repo before
# this job runs (HF Jobs containers don't have access to local files). The
# repo should contain the teacher_*.jsonl files at its root.
TEACHER_DATA_REPO = os.environ.get(
    "TEACHER_DATA_REPO",
    "InosLihka/rhythm-env-teacher-trajectories",
)
# Comma-separated JSONL filenames to fetch from TEACHER_DATA_REPO; empty
# entries (e.g. from a trailing comma) are filtered out later in main().
TEACHER_FILES = os.environ.get(
    "TEACHER_FILES",
    "teacher_30ep_validation.jsonl,teacher_indist_30_99.jsonl,teacher_ood_10000_10049.jsonl",
).split(",")
# Hyperparameters forwarded verbatim (stringified) to training/sft_prime.py.
EPOCHS = int(os.environ.get("EPOCHS", "2"))
MAX_STEPS = int(os.environ.get("MAX_STEPS", "-1"))  # -1 = use epochs
LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "2048"))
# Destination model repo for the SFT'd weights; SUFFIX distinguishes variants.
SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "sft-primed")
DEFAULT_REPO = f"InosLihka/rhythm-env-meta-trained-{SUFFIX}"
MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)
# Echo the resolved configuration at import time so it appears in job logs.
print("=== SFT prime config ===")
print(f" TEACHER_DATA_REPO: {TEACHER_DATA_REPO}")
print(f" TEACHER_FILES: {TEACHER_FILES}")
print(f" EPOCHS={EPOCHS}, MAX_STEPS={MAX_STEPS}, LORA_RANK={LORA_RANK}")
print(f" LR={LEARNING_RATE}, MAX_SEQ_LENGTH={MAX_SEQ_LENGTH}")
print(f" MODEL_REPO={MODEL_REPO}")
print()
def run(cmd):
    """Echo *cmd* to stdout, then execute it; raises CalledProcessError on failure."""
    shown = " ".join(cmd) if isinstance(cmd, list) else cmd
    print(f"\n>>> {shown}", flush=True)
    subprocess.run(cmd, check=True)
def _clone_repo():
    """Clone the rhythm_env Space into WORK_DIR and make its code importable."""
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)  # start from a clean checkout on reruns
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    # Make repo-local modules (training/*) importable for any in-process use.
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))


def _download_teacher_data():
    """Download the teacher trajectory JSONLs from the HF Hub into ./data.

    Returns:
        list[str]: local paths of the downloaded files. Empty entries in
        TEACHER_FILES (e.g. from a trailing comma) are skipped.
    """
    # Imported lazily: huggingface_hub is only guaranteed inside the job env.
    from huggingface_hub import hf_hub_download

    Path("data").mkdir(exist_ok=True)
    local_paths = []
    for fn in TEACHER_FILES:
        fn = fn.strip()
        if not fn:
            continue
        print(f"Downloading {fn} from {TEACHER_DATA_REPO}...")
        local = hf_hub_download(
            repo_id=TEACHER_DATA_REPO,
            filename=fn,
            repo_type="dataset",
            local_dir="data",
        )
        local_paths.append(local)
    print(f"Downloaded {len(local_paths)} JSONL files")
    return local_paths


def _run_sft(local_paths):
    """Run training/sft_prime.py as a subprocess with the configured hyperparameters."""
    sft_args = [
        # sys.executable reuses the interpreter that has the uv-declared
        # dependencies installed; bare "python" on PATH may be a different one.
        sys.executable, "training/sft_prime.py",
        "--teacher_jsonls", *local_paths,
        "--output_dir", OUTPUT_DIR,
        "--lora_rank", str(LORA_RANK),
        "--learning_rate", str(LEARNING_RATE),
        "--max_seq_length", str(MAX_SEQ_LENGTH),
        "--epochs", str(EPOCHS),
    ]
    if MAX_STEPS > 0:
        sft_args.extend(["--max_steps", str(MAX_STEPS)])
    run(sft_args)


def _run_eval():
    """Run the embedded 5-episode eval; writes eval_results.json in WORK_DIR."""
    run([
        sys.executable, "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", "5",
        "--output_file", "eval_results.json",
    ])


def _upload_outputs(skip_eval):
    """Upload the trained model (and eval results, if present) to MODEL_REPO.

    Skips the upload with a warning when HF_TOKEN is absent so local dry
    runs still complete instead of crashing.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs at: {OUTPUT_DIR}")
        return

    from huggingface_hub import HfApi, login

    login(token=token)
    api = HfApi()
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
    )
    # Relative path is correct: we os.chdir'd into WORK_DIR in _clone_repo,
    # and the eval subprocess wrote eval_results.json there.
    if not skip_eval and Path("eval_results.json").exists():
        api.upload_file(
            path_or_fileobj="eval_results.json",
            path_in_repo="eval_results.json",
            repo_id=MODEL_REPO,
            repo_type="model",
        )
    print()
    print("=" * 60)
    print("DONE")
    print(f" SFT'd model: https://huggingface.co/{MODEL_REPO}")
    if not skip_eval:
        print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    else:
        print(" Eval skipped — run scripts/eval_on_hf.py separately")
    print("=" * 60)


def main():
    """Run the full SFT-prime pipeline: clone, fetch data, train, eval, upload."""
    _clone_repo()
    local_paths = _download_teacher_data()
    _run_sft(local_paths)
    # Eval is optional — SKIP_EVAL=1 uploads faster; run eval separately later.
    skip_eval = os.environ.get("SKIP_EVAL", "0") == "1"
    if skip_eval:
        print("SKIP_EVAL=1: skipping embedded eval (run scripts/eval_on_hf.py separately)")
    else:
        _run_eval()
    _upload_outputs(skip_eval)
# Standard entry guard: run the job only when executed as a script.
if __name__ == "__main__":
    main()