Spaces:
Sleeping
Sleeping
| # /// script | |
| # requires-python = ">=3.10" | |
| # dependencies = [ | |
| # "torch", | |
| # "transformers==4.56.2", | |
| # "trl==0.22.2", | |
| # "datasets", | |
| # "peft", | |
| # "accelerate", | |
| # "bitsandbytes", | |
| # "unsloth", | |
| # "openenv-core", | |
| # "fastapi", | |
| # "uvicorn", | |
| # "pydantic", | |
| # "huggingface_hub", | |
| # ] | |
| # /// | |
| """ | |
| HF Jobs orchestrator for SFT prime stage. | |
| Submits the SFT prime training as an HF Jobs run. Clones the rhythm_env | |
| HF Space, downloads the teacher trajectory JSONL files from a HF dataset | |
| or model repo, runs training/sft_prime.py, and uploads the SFT'd model. | |
| Submit from local with: | |
| hf jobs uv run --flavor a10g-large --secrets HF_TOKEN \\ | |
| -e TEACHER_DATA_REPO=InosLihka/rhythm-env-teacher-trajectories \\ | |
| -e MODEL_REPO_SUFFIX=sft-primed \\ | |
| -e EPOCHS=2 \\ | |
| -d scripts/sft_on_hf.py | |
| Cost on a10g-large at $1.50/hr: ~$2-3 for ~30-45 min training. | |
| """ | |
| import json | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env") | |
| WORK_DIR = "/tmp/rhythm_env" | |
| OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythm-env-sft-primed" | |
| # Teacher trajectory data must be uploaded to a HF dataset/model repo before | |
| # this job runs (HF Jobs containers don't have access to local files). The | |
| # repo should contain the teacher_*.jsonl files at its root. | |
| TEACHER_DATA_REPO = os.environ.get( | |
| "TEACHER_DATA_REPO", | |
| "InosLihka/rhythm-env-teacher-trajectories", | |
| ) | |
| TEACHER_FILES = os.environ.get( | |
| "TEACHER_FILES", | |
| "teacher_30ep_validation.jsonl,teacher_indist_30_99.jsonl,teacher_ood_10000_10049.jsonl", | |
| ).split(",") | |
| EPOCHS = int(os.environ.get("EPOCHS", "2")) | |
| MAX_STEPS = int(os.environ.get("MAX_STEPS", "-1")) # -1 = use epochs | |
| LORA_RANK = int(os.environ.get("LORA_RANK", "16")) | |
| LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4")) | |
| MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "2048")) | |
| SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "sft-primed") | |
| DEFAULT_REPO = f"InosLihka/rhythm-env-meta-trained-{SUFFIX}" | |
| MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO) | |
| print("=== SFT prime config ===") | |
| print(f" TEACHER_DATA_REPO: {TEACHER_DATA_REPO}") | |
| print(f" TEACHER_FILES: {TEACHER_FILES}") | |
| print(f" EPOCHS={EPOCHS}, MAX_STEPS={MAX_STEPS}, LORA_RANK={LORA_RANK}") | |
| print(f" LR={LEARNING_RATE}, MAX_SEQ_LENGTH={MAX_SEQ_LENGTH}") | |
| print(f" MODEL_REPO={MODEL_REPO}") | |
| print() | |
def run(cmd):
    """Echo *cmd* to the job log, then execute it.

    Raises subprocess.CalledProcessError if the command exits non-zero.
    """
    if isinstance(cmd, list):
        shown = " ".join(cmd)
    else:
        shown = cmd
    print(f"\n>>> {shown}", flush=True)
    subprocess.run(cmd, check=True)
def _clone_repo():
    """Clone the Space repo into a fresh WORK_DIR and make it the cwd/import root."""
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))


def _download_teacher_data():
    """Fetch the teacher trajectory JSONLs from the Hub into ./data.

    Returns the list of local file paths. Raises RuntimeError if no files
    were resolved (blank TEACHER_FILES), instead of letting sft_prime.py
    fail later with a confusing argparse/empty-dataset error.
    """
    from huggingface_hub import hf_hub_download

    Path("data").mkdir(exist_ok=True)
    local_paths = []
    for fn in TEACHER_FILES:
        fn = fn.strip()
        if not fn:
            continue
        print(f"Downloading {fn} from {TEACHER_DATA_REPO}...")
        local = hf_hub_download(
            repo_id=TEACHER_DATA_REPO,
            filename=fn,
            repo_type="dataset",
            local_dir="data",
        )
        local_paths.append(local)
    if not local_paths:
        raise RuntimeError(
            f"No teacher JSONL files resolved from TEACHER_FILES={TEACHER_FILES!r}; "
            "nothing to train on."
        )
    print(f"Downloaded {len(local_paths)} JSONL files")
    return local_paths


def _run_sft(local_paths):
    """Launch training/sft_prime.py on the downloaded trajectories."""
    # Use the current interpreter rather than bare "python": inside a uv-run
    # job the unqualified name may resolve to an interpreter without the
    # script's declared dependencies.
    sft_args = [
        sys.executable, "training/sft_prime.py",
        "--teacher_jsonls", *local_paths,
        "--output_dir", OUTPUT_DIR,
        "--lora_rank", str(LORA_RANK),
        "--learning_rate", str(LEARNING_RATE),
        "--max_seq_length", str(MAX_SEQ_LENGTH),
        "--epochs", str(EPOCHS),
    ]
    if MAX_STEPS > 0:
        # Explicit step cap overrides epoch-based scheduling.
        sft_args.extend(["--max_steps", str(MAX_STEPS)])
    run(sft_args)


def _run_eval():
    """Run the embedded inference eval, writing eval_results.json in WORK_DIR."""
    eval_args = [
        sys.executable, "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", "5",
        "--output_file", "eval_results.json",
    ]
    run(eval_args)


def _upload(skip_eval):
    """Push OUTPUT_DIR (and eval results, if present) to MODEL_REPO on the Hub.

    A missing HF_TOKEN is a soft failure: training outputs stay on local disk.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs at: {OUTPUT_DIR}")
        return
    from huggingface_hub import HfApi, login

    login(token=token)
    api = HfApi()
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"SFT prime ({EPOCHS} epochs, lora r={LORA_RANK}) on teacher trajectories",
    )
    if not skip_eval and Path("eval_results.json").exists():
        api.upload_file(
            path_or_fileobj="eval_results.json",
            path_in_repo="eval_results.json",
            repo_id=MODEL_REPO,
            repo_type="model",
        )
    print()
    print("=" * 60)
    print("DONE")
    print(f" SFT'd model: https://huggingface.co/{MODEL_REPO}")
    if not skip_eval:
        print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    else:
        print(" Eval skipped — run scripts/eval_on_hf.py separately")
    print("=" * 60)


def main():
    """Orchestrate the SFT prime job: clone, download data, train, eval, upload."""
    # 1. Clone repo
    _clone_repo()
    # 2. Download teacher trajectories from HF Hub
    local_paths = _download_teacher_data()
    # 3. Run SFT
    _run_sft(local_paths)
    # 4. Eval (optional — set SKIP_EVAL=1 to upload faster and run eval separately)
    skip_eval = os.environ.get("SKIP_EVAL", "0") == "1"
    if not skip_eval:
        _run_eval()
    else:
        print("SKIP_EVAL=1: skipping embedded eval (run scripts/eval_on_hf.py separately)")
    # 5. Upload to HF Hub
    _upload(skip_eval)


if __name__ == "__main__":
    main()