# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "torch",
#   "transformers==4.56.2",
#   "trl==0.22.2",
#   "datasets",
#   "peft",
#   "accelerate",
#   "bitsandbytes",
#   "unsloth",
#   "openenv-core",
#   "fastapi",
#   "uvicorn",
#   "pydantic",
#   "matplotlib",
#   "huggingface_hub",
# ]
# ///
"""
End-to-end training job for HF Jobs.

Submit from local machine with:
    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN scripts/train_on_hf.py

What it does (no babysitting required):
  1. Clone rhythm_env from HF Space (gets latest meta-RL code from main)
  2. Generate dataset (continuous profiles, hint_fraction=0.15)
  3. Train Qwen 2.5-3B + LoRA via GRPO (defaults: LORA_RANK=16, MAX_STEPS=2000)
  4. Run eval on all 3 conditions (discrete, in-dist, OOD)
  5. Generate all 5 plots from log_history
  6. Upload trained model + plots + eval JSON to a new HF Hub model repo

Override defaults via env vars:
    MAX_STEPS, NUM_EPISODES, LORA_RANK, BETA, MODEL_REPO

Estimated cost on a10g-large at $1.50/hr: roughly $3 per 1500 steps (~2h), scaling
roughly linearly with MAX_STEPS.
"""

import os
import shutil
import subprocess
import sys
from pathlib import Path

# ---------------------------------------------------------------------------
# Config (overridable via env vars)
# ---------------------------------------------------------------------------
REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
WORK_DIR = "/tmp/rhythm_env"
OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythmenv_meta_trained"
PLOTS_DIR = "/tmp/rhythm_env/plots"

# FAST_MODE preset: ~10-15 min iteration on A100 large.
# Use for hyperparameter sweeps and pipeline debugging.
FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"

if FAST_MODE:
    # Iter 3 preset: 800 steps + 8 generations + LoRA 16 to escape mode collapse for real
    DEFAULTS = dict(MAX_STEPS=800, NUM_EPISODES=200, MAX_SAMPLES=2000,
                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                    LEARNING_RATE=5e-5, EVAL_EPISODES=2)
else:
    DEFAULTS = dict(MAX_STEPS=2000, NUM_EPISODES=400, MAX_SAMPLES=4000,
                    NUM_GENERATIONS=8, LORA_RANK=16, BETA=0.04,
                    LEARNING_RATE=5e-5, EVAL_EPISODES=5)

MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", str(DEFAULTS["NUM_EPISODES"])))
MAX_SAMPLES = int(os.environ.get("MAX_SAMPLES", str(DEFAULTS["MAX_SAMPLES"])))
NUM_GENERATIONS = int(os.environ.get("NUM_GENERATIONS", str(DEFAULTS["NUM_GENERATIONS"])))
LORA_RANK = int(os.environ.get("LORA_RANK", str(DEFAULTS["LORA_RANK"])))
BETA = float(os.environ.get("BETA", str(DEFAULTS["BETA"])))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", str(DEFAULTS["LEARNING_RATE"])))
EVAL_EPISODES = int(os.environ.get("EVAL_EPISODES", str(DEFAULTS["EVAL_EPISODES"])))

# Each iteration uploads to a unique repo if MODEL_REPO_SUFFIX is set
SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "")
DEFAULT_REPO = "InosLihka/rhythm-env-meta-trained" + (f"-{SUFFIX}" if SUFFIX else "")
MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)

print(f"=== Run config ===")
print(f"  FAST_MODE: {FAST_MODE}")
print(f"  MAX_STEPS={MAX_STEPS}, NUM_EPISODES={NUM_EPISODES}, MAX_SAMPLES={MAX_SAMPLES}")
print(f"  NUM_GENERATIONS={NUM_GENERATIONS}, LORA_RANK={LORA_RANK}, BETA={BETA}")
print(f"  LEARNING_RATE={LEARNING_RATE}, EVAL_EPISODES={EVAL_EPISODES}")
print(f"  MODEL_REPO={MODEL_REPO}")
print()


def run(cmd: list[str], **kw):
    """Run a subprocess with logging; raise on non-zero exit (check=True)."""
    print(f"\n>>> {' '.join(cmd)}", flush=True)
    subprocess.run(cmd, check=True, **kw)


def main():
    # ---------------------------------------------------------------
    # 1. Clone the rhythm_env repo
    # ---------------------------------------------------------------
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))

    # Verify meta-RL code is present
    dataset_py = Path("training/dataset.py").read_text()
    assert "profile_mode" in dataset_py, "Cloned repo doesn't have meta-RL code"
    print("OK: meta-RL code present in cloned repo")

    # ---------------------------------------------------------------
    # 2. Train
    # ---------------------------------------------------------------
    # MODEL_NAME env var lets us refine an existing trained model (e.g. SFT'd
    # checkpoint on HF Hub) instead of starting from the base Qwen. Default
    # is the original base model.
    base_model = os.environ.get("MODEL_NAME", "unsloth/Qwen2.5-3B-Instruct")
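    # Illustrative only: e.g. MODEL_NAME=<org>/<sft-checkpoint> (hypothetical repo id)
    # would continue GRPO training from a previously uploaded checkpoint instead.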

    train_args = [
        "python", "training/train.py",
        "--model_name", base_model,
        "--max_steps", str(MAX_STEPS),
        "--num_episodes", str(NUM_EPISODES),
        "--max_samples", str(MAX_SAMPLES),
        "--num_generations", str(NUM_GENERATIONS),
        "--lora_rank", str(LORA_RANK),
        "--beta", str(BETA),
        "--learning_rate", str(LEARNING_RATE),
        "--output_dir", OUTPUT_DIR,
    ]
    print(f"Starting from model: {base_model}")
    run(train_args)

    # ---------------------------------------------------------------
    # 3. Eval (3 conditions: discrete-3 / in-dist / OOD)
    # ---------------------------------------------------------------
    eval_args = [
        "python", "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", str(EVAL_EPISODES),
        "--output_file", "eval_results.json",
    ]
    run(eval_args)

    # ---------------------------------------------------------------
    # 4. Generate plots from saved log_history
    # ---------------------------------------------------------------
    Path(PLOTS_DIR).mkdir(exist_ok=True)
    log_path = os.path.join(OUTPUT_DIR, "log_history.json")
    if Path(log_path).exists():
        run(["python", "scripts/plot_from_log.py", "--log", log_path, "--out", PLOTS_DIR])
    else:
        print(f"WARNING: log_history.json not found at {log_path}")

    # ---------------------------------------------------------------
    # 5. Upload everything to HF Hub
    # ---------------------------------------------------------------
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs in: {OUTPUT_DIR}")
        return

    from huggingface_hub import HfApi, login
    login(token=token)
    api = HfApi()
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")

    # Upload trained model + config + log_history
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"Trained {MAX_STEPS}-step GRPO meta-RL agent",
    )

    # Upload eval JSON
    api.upload_file(
        path_or_fileobj="eval_results.json",
        path_in_repo="eval_results.json",
        repo_id=MODEL_REPO,
        repo_type="model",
    )

    # Upload plots if generated
    if Path(PLOTS_DIR).exists() and any(Path(PLOTS_DIR).iterdir()):
        api.upload_folder(
            folder_path=PLOTS_DIR,
            path_in_repo="plots",
            repo_id=MODEL_REPO,
            repo_type="model",
        )

    print()
    print("=" * 60)
    print("DONE")
    print(f"  Trained model: https://huggingface.co/{MODEL_REPO}")
    print(f"  Eval JSON:     https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    print(f"  Plots:         https://huggingface.co/{MODEL_REPO}/tree/main/plots")
    print("=" * 60)


if __name__ == "__main__":
    main()