Spaces:
Sleeping
Sleeping
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "transformers==4.56.2",
#     "trl==0.22.2",
#     "datasets",
#     "peft",
#     "accelerate",
#     "bitsandbytes",
#     "unsloth",
#     "openenv-core",
#     "fastapi",
#     "uvicorn",
#     "pydantic",
#     "matplotlib",
#     "huggingface_hub",
# ]
# ///
"""
End-to-end training job for HF Jobs.

Submit from local machine with:

    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN scripts/train_on_hf.py

What it does (no babysitting required):
  1. Clone rhythm_env from HF Space (gets latest meta-RL code from main)
  2. Generate dataset (continuous profiles, hint_fraction=0.15)
  3. Train Qwen 2.5-3B + LoRA (rank 16 by default) via GRPO (2000 steps by default)
  4. Run eval on all 3 conditions (discrete, in-dist, OOD)
  5. Generate all 5 plots from log_history
  6. Upload trained model + plots + eval JSON to a new HF Hub model repo

Override defaults via env vars:
    MAX_STEPS, NUM_EPISODES, LORA_RANK, BETA, MODEL_REPO

Estimated cost on a10g-large at $1.50/hr: ~$3 for 1500 steps (~2h).
"""
| import json | |
| import os | |
| import shutil | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
# ---------------------------------------------------------------------------
# Config (overridable via env vars)
# ---------------------------------------------------------------------------
REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
WORK_DIR = "/tmp/rhythm_env"
OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythmenv_meta_trained"
PLOTS_DIR = "/tmp/rhythm_env/plots"

# FAST_MODE preset: ~10-15 min iteration on A100 large.
# Use for hyperparameter sweeps and pipeline debugging.
FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"

if FAST_MODE:
    # Iter 3 preset: 800 steps + 8 generations + LoRA 16 to escape mode collapse for real
    DEFAULTS = dict(MAX_STEPS=800, NUM_EPISODES=200, MAX_SAMPLES=2000,
                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                    LEARNING_RATE=5e-5, EVAL_EPISODES=2)
else:
    DEFAULTS = dict(MAX_STEPS=2000, NUM_EPISODES=400, MAX_SAMPLES=4000,
                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                    LEARNING_RATE=5e-5, EVAL_EPISODES=5)


def _env(name, cast):
    """Read config *name* from the environment, falling back to DEFAULTS[name].

    `cast` (int or float) converts the string value, matching the previous
    per-variable `int(os.environ.get(...))` / `float(os.environ.get(...))`.
    """
    return cast(os.environ.get(name, str(DEFAULTS[name])))


MAX_STEPS = _env("MAX_STEPS", int)
NUM_EPISODES = _env("NUM_EPISODES", int)
MAX_SAMPLES = _env("MAX_SAMPLES", int)
NUM_GENERATIONS = _env("NUM_GENERATIONS", int)
LORA_RANK = _env("LORA_RANK", int)
BETA = _env("BETA", float)
LEARNING_RATE = _env("LEARNING_RATE", float)
EVAL_EPISODES = _env("EVAL_EPISODES", int)

# Each iteration uploads to a unique repo if MODEL_REPO_SUFFIX is set
SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "")
DEFAULT_REPO = "InosLihka/rhythm-env-meta-trained" + (f"-{SUFFIX}" if SUFFIX else "")
MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)

# Echo the effective configuration at import time so job logs are self-describing.
print("=== Run config ===")  # plain string: no placeholders, f-string was redundant
print(f" FAST_MODE: {FAST_MODE}")
print(f" MAX_STEPS={MAX_STEPS}, NUM_EPISODES={NUM_EPISODES}, MAX_SAMPLES={MAX_SAMPLES}")
print(f" NUM_GENERATIONS={NUM_GENERATIONS}, LORA_RANK={LORA_RANK}, BETA={BETA}")
print(f" LEARNING_RATE={LEARNING_RATE}, EVAL_EPISODES={EVAL_EPISODES}")
print(f" MODEL_REPO={MODEL_REPO}")
print()
def run(cmd: list[str], **kw):
    """Echo a command to stdout, then execute it; non-zero exit raises.

    Lists are joined with spaces for display only; the command itself is
    passed through to subprocess.run unchanged, with check=True so any
    failure propagates as CalledProcessError.
    """
    display = " ".join(cmd) if isinstance(cmd, list) else cmd
    print(f"\n>>> {display}", flush=True)
    subprocess.run(cmd, check=True, **kw)
def main():
    """Run the full pipeline: clone -> train -> eval -> plot -> upload.

    Every stage shells out via run(), which uses check=True, so a failing
    stage raises and aborts the job before later stages execute.
    """
    # ---------------------------------------------------------------
    # 1. Clone the rhythm_env repo
    # ---------------------------------------------------------------
    # Fresh checkout each run: wipe any leftover tree first.
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    # Make the repo root and its training/ dir importable in this process.
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))
    # Verify meta-RL code is present.
    # NOTE(review): assert is stripped under `python -O`; acceptable for a job
    # script, but a plain `raise` would be more robust.
    dataset_py = Path("training/dataset.py").read_text()
    assert "profile_mode" in dataset_py, "Cloned repo doesn't have meta-RL code"
    print("OK: meta-RL code present in cloned repo")
    # ---------------------------------------------------------------
    # 2. Train
    # ---------------------------------------------------------------
    # MODEL_NAME env var lets us refine an existing trained model (e.g. SFT'd
    # checkpoint on HF Hub) instead of starting from the base Qwen. Default
    # is the original base model.
    base_model = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct")
    train_args = [
        "python", "training/train.py",
        "--model_name", base_model,
        "--max_steps", str(MAX_STEPS),
        "--num_episodes", str(NUM_EPISODES),
        "--max_samples", str(MAX_SAMPLES),
        "--num_generations", str(NUM_GENERATIONS),
        "--lora_rank", str(LORA_RANK),
        "--beta", str(BETA),
        "--learning_rate", str(LEARNING_RATE),
        "--output_dir", OUTPUT_DIR,
    ]
    print(f"Starting from model: {base_model}")
    run(train_args)
    # ---------------------------------------------------------------
    # 3. Eval (3 conditions: discrete-3 / in-dist / OOD)
    # ---------------------------------------------------------------
    eval_args = [
        "python", "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", str(EVAL_EPISODES),
        # Relative path: lands in WORK_DIR thanks to the os.chdir above.
        "--output_file", "eval_results.json",
    ]
    run(eval_args)
    # ---------------------------------------------------------------
    # 4. Generate plots from saved log_history
    # ---------------------------------------------------------------
    Path(PLOTS_DIR).mkdir(exist_ok=True)
    log_path = os.path.join(OUTPUT_DIR, "log_history.json")
    if Path(log_path).exists():
        run(["python", "scripts/plot_from_log.py", "--log", log_path, "--out", PLOTS_DIR])
    else:
        # Plots are best-effort: a missing log warns but does not fail the job.
        print(f"WARNING: log_history.json not found at {log_path}")
    # ---------------------------------------------------------------
    # 5. Upload everything to HF Hub
    # ---------------------------------------------------------------
    token = os.environ.get("HF_TOKEN")
    if not token:
        # No token: keep outputs local instead of crashing at the last step.
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs in: {OUTPUT_DIR}")
        return
    # Local import: only needed (and only executed) on the upload path.
    from huggingface_hub import HfApi, login
    login(token=token)
    api = HfApi()
    # exist_ok=True makes re-runs into the same repo idempotent.
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")
    # Upload trained model + config + log_history
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"Trained {MAX_STEPS}-step GRPO meta-RL agent",
    )
    # Upload eval JSON (written by step 3 into the current directory)
    api.upload_file(
        path_or_fileobj="eval_results.json",
        path_in_repo="eval_results.json",
        repo_id=MODEL_REPO,
        repo_type="model",
    )
    # Upload plots if generated
    if Path(PLOTS_DIR).exists() and any(Path(PLOTS_DIR).iterdir()):
        api.upload_folder(
            folder_path=PLOTS_DIR,
            path_in_repo="plots",
            repo_id=MODEL_REPO,
            repo_type="model",
        )
    print()
    print("=" * 60)
    print("DONE")
    print(f" Trained model: https://huggingface.co/{MODEL_REPO}")
    print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    print(f" Plots: https://huggingface.co/{MODEL_REPO}/tree/main/plots")
    print("=" * 60)
# Entry-point guard: run the pipeline only when executed as a script,
# not when imported.
if __name__ == "__main__":
    main()