# rhythm_env/scripts/train_on_hf.py
# (Hosted at the InosLihka/rhythm_env HF Space; last commit f0ca22d:
#  "Refactor grader to use openenv.core.rubrics.WeightedSum + Rubric subclasses")
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "transformers==4.56.2",
# "trl==0.22.2",
# "datasets",
# "peft",
# "accelerate",
# "bitsandbytes",
# "unsloth",
# "openenv-core",
# "fastapi",
# "uvicorn",
# "pydantic",
# "matplotlib",
# "huggingface_hub",
# ]
# ///
"""
End-to-end training job for HF Jobs.
Submit from local machine with:
hf jobs uv run --flavor a10g-large --secrets HF_TOKEN scripts/train_on_hf.py
What it does (no babysitting required):
1. Clone rhythm_env from HF Space (gets latest meta-RL code from main)
2. Generate dataset (continuous profiles, hint_fraction=0.15)
3. Train Qwen 2.5-3B + LoRA (rank 16 default) via GRPO (2000 steps default)
4. Run eval on all 3 conditions (discrete, in-dist, OOD)
5. Generate all 5 plots from log_history
6. Upload trained model + plots + eval JSON to a new HF Hub model repo
Override defaults via env vars:
MAX_STEPS, NUM_EPISODES, LORA_RANK, BETA, MODEL_REPO
Estimated cost on a10g-large at $1.50/hr: roughly $3 per 1500 steps (~2h).
"""
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
# ---------------------------------------------------------------------------
# Config (overridable via env vars)
# ---------------------------------------------------------------------------
REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
WORK_DIR = "/tmp/rhythm_env"
OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythmenv_meta_trained"
PLOTS_DIR = "/tmp/rhythm_env/plots"
# FAST_MODE preset: ~10-15 min iteration on A100 large.
# Use for hyperparameter sweeps and pipeline debugging.
FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
if FAST_MODE:
# Iter 3 preset: 800 steps + 8 generations + LoRA 16 to escape mode collapse for real
DEFAULTS = dict(MAX_STEPS=800, NUM_EPISODES=200, MAX_SAMPLES=2000,
NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
LEARNING_RATE=5e-5, EVAL_EPISODES=2)
else:
DEFAULTS = dict(MAX_STEPS=2000, NUM_EPISODES=400, MAX_SAMPLES=4000,
NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
LEARNING_RATE=5e-5, EVAL_EPISODES=5)
MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", str(DEFAULTS["NUM_EPISODES"])))
MAX_SAMPLES = int(os.environ.get("MAX_SAMPLES", str(DEFAULTS["MAX_SAMPLES"])))
NUM_GENERATIONS = int(os.environ.get("NUM_GENERATIONS", str(DEFAULTS["NUM_GENERATIONS"])))
LORA_RANK = int(os.environ.get("LORA_RANK", str(DEFAULTS["LORA_RANK"])))
BETA = float(os.environ.get("BETA", str(DEFAULTS["BETA"])))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", str(DEFAULTS["LEARNING_RATE"])))
EVAL_EPISODES = int(os.environ.get("EVAL_EPISODES", str(DEFAULTS["EVAL_EPISODES"])))
# Each iteration uploads to a unique repo if MODEL_REPO_SUFFIX is set
SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "")
DEFAULT_REPO = "InosLihka/rhythm-env-meta-trained" + (f"-{SUFFIX}" if SUFFIX else "")
MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)
print(f"=== Run config ===")
print(f" FAST_MODE: {FAST_MODE}")
print(f" MAX_STEPS={MAX_STEPS}, NUM_EPISODES={NUM_EPISODES}, MAX_SAMPLES={MAX_SAMPLES}")
print(f" NUM_GENERATIONS={NUM_GENERATIONS}, LORA_RANK={LORA_RANK}, BETA={BETA}")
print(f" LEARNING_RATE={LEARNING_RATE}, EVAL_EPISODES={EVAL_EPISODES}")
print(f" MODEL_REPO={MODEL_REPO}")
print()
def run(cmd: list[str], **kw):
    """Echo *cmd* to stdout, then execute it, raising CalledProcessError on failure."""
    shown = " ".join(cmd) if isinstance(cmd, list) else cmd
    print(f"\n>>> {shown}", flush=True)
    subprocess.run(cmd, check=True, **kw)
def main():
    """End-to-end job: clone env repo, train via GRPO, eval, plot, upload to HF Hub."""
    # ---------------------------------------------------------------
    # 1. Clone the rhythm_env repo
    # ---------------------------------------------------------------
    # Fresh clone each run so the job always trains against latest main.
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    # Make both the repo root and training/ importable for the subprocesses' parent.
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))
    # Verify meta-RL code is present
    # NOTE(review): `assert` is stripped under `python -O`; acceptable for a job script.
    dataset_py = Path("training/dataset.py").read_text()
    assert "profile_mode" in dataset_py, "Cloned repo doesn't have meta-RL code"
    print("OK: meta-RL code present in cloned repo")
    # ---------------------------------------------------------------
    # 2. Train
    # ---------------------------------------------------------------
    # MODEL_NAME env var lets us refine an existing trained model (e.g. SFT'd
    # checkpoint on HF Hub) instead of starting from the base Qwen. Default
    # is the original base model.
    base_model = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct")
    # Flags below are assumed to match training/train.py's CLI in the cloned
    # repo — not visible from this file; verify there if a flag changes.
    train_args = [
        "python", "training/train.py",
        "--model_name", base_model,
        "--max_steps", str(MAX_STEPS),
        "--num_episodes", str(NUM_EPISODES),
        "--max_samples", str(MAX_SAMPLES),
        "--num_generations", str(NUM_GENERATIONS),
        "--lora_rank", str(LORA_RANK),
        "--beta", str(BETA),
        "--learning_rate", str(LEARNING_RATE),
        "--output_dir", OUTPUT_DIR,
    ]
    print(f"Starting from model: {base_model}")
    run(train_args)
    # ---------------------------------------------------------------
    # 3. Eval (3 conditions: discrete-3 / in-dist / OOD)
    # ---------------------------------------------------------------
    # Writes eval_results.json into the CWD (WORK_DIR, set by os.chdir above).
    eval_args = [
        "python", "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", str(EVAL_EPISODES),
        "--output_file", "eval_results.json",
    ]
    run(eval_args)
    # ---------------------------------------------------------------
    # 4. Generate plots from saved log_history
    # ---------------------------------------------------------------
    Path(PLOTS_DIR).mkdir(exist_ok=True)
    log_path = os.path.join(OUTPUT_DIR, "log_history.json")
    # Plotting is best-effort: a missing log (e.g. crashed trainer that still
    # saved a checkpoint) only warns, so the upload step below still runs.
    if Path(log_path).exists():
        run(["python", "scripts/plot_from_log.py", "--log", log_path, "--out", PLOTS_DIR])
    else:
        print(f"WARNING: log_history.json not found at {log_path}")
    # ---------------------------------------------------------------
    # 5. Upload everything to HF Hub
    # ---------------------------------------------------------------
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Without a token we can't push; leave artifacts on local disk and exit cleanly.
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs in: {OUTPUT_DIR}")
        return
    # Imported lazily: huggingface_hub is only needed on the upload path.
    from huggingface_hub import HfApi, login
    login(token=token)
    api = HfApi()
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")
    # Upload trained model + config + log_history
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"Trained {MAX_STEPS}-step GRPO meta-RL agent",
    )
    # Upload eval JSON (relative path: resolved against WORK_DIR, see os.chdir above)
    api.upload_file(
        path_or_fileobj="eval_results.json",
        path_in_repo="eval_results.json",
        repo_id=MODEL_REPO,
        repo_type="model",
    )
    # Upload plots if generated
    if Path(PLOTS_DIR).exists() and any(Path(PLOTS_DIR).iterdir()):
        api.upload_folder(
            folder_path=PLOTS_DIR,
            path_in_repo="plots",
            repo_id=MODEL_REPO,
            repo_type="model",
        )
    print()
    print("=" * 60)
    print("DONE")
    print(f"  Trained model:  https://huggingface.co/{MODEL_REPO}")
    print(f"  Eval JSON:      https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    print(f"  Plots:          https://huggingface.co/{MODEL_REPO}/tree/main/plots")
    print("=" * 60)


if __name__ == "__main__":
    main()