File size: 5,533 Bytes
ece0bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff20f02
 
 
 
 
 
 
 
 
 
 
 
ece0bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff20f02
 
 
 
 
 
 
ece0bbe
 
 
 
 
ff20f02
 
 
 
ece0bbe
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "torch",
#   "transformers==4.56.2",
#   "trl==0.22.2",
#   "datasets",
#   "peft",
#   "accelerate",
#   "bitsandbytes",
#   "unsloth",
#   "openenv-core",
#   "fastapi",
#   "uvicorn",
#   "pydantic",
#   "huggingface_hub",
# ]
# ///
"""
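HF Jobs orchestrator for the SFT prime stage.

Entry point for an HF Jobs run of SFT prime training. Inside the job it clones
the rhythm_env HF Space, downloads the teacher trajectory JSONL files from a
HF Hub dataset repo, runs training/sft_prime.py, optionally evaluates the
result with training/inference_eval.py (set SKIP_EVAL=1 to skip), and uploads
the SFT'd model to the Hub.

Submit from local with:
    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN \\
        -e TEACHER_DATA_REPO=InosLihka/rhythm-env-teacher-trajectories \\
        -e MODEL_REPO_SUFFIX=sft-primed \\
        -e EPOCHS=2 \\
        -d scripts/sft_on_hf.py

Cost on a10g-large at $1.50/hr: budget ~$2-3 per job. The ~30-45 min of
training alone is ~$0.75-1.15; the remainder presumably covers environment
setup, dependency install and eval (confirm against actual job durations).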
"""

import json
import os
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
WORK_DIR = "/tmp/rhythm_env"
# Training outputs live inside the clone of REPO_URL (main() wipes and
# re-clones WORK_DIR, so nothing from a previous run survives there).
OUTPUT_DIR = os.path.join(WORK_DIR, "outputs", "rhythm-env-sft-primed")


def parse_csv_list(raw: str) -> list[str]:
    """Split a comma-separated setting into stripped, non-empty entries.

    >>> parse_csv_list("a.jsonl, b.jsonl,,")
    ['a.jsonl', 'b.jsonl']
    """
    return [item.strip() for item in raw.split(",") if item.strip()]


# Teacher trajectory data must be uploaded to a HF Hub dataset repo before
# this job runs (HF Jobs containers don't have access to local files). The
# repo should contain the teacher_*.jsonl files at its root.
TEACHER_DATA_REPO = os.environ.get(
    "TEACHER_DATA_REPO",
    "InosLihka/rhythm-env-teacher-trajectories",
)
# Normalised at parse time so the config banner below shows exactly the
# filenames that will be requested (no stray whitespace or empty entries).
TEACHER_FILES = parse_csv_list(
    os.environ.get(
        "TEACHER_FILES",
        "teacher_30ep_validation.jsonl,teacher_indist_30_99.jsonl,teacher_ood_10000_10049.jsonl",
    )
)

EPOCHS = int(os.environ.get("EPOCHS", "2"))
MAX_STEPS = int(os.environ.get("MAX_STEPS", "-1"))  # <= 0 (default -1) = use epochs
LORA_RANK = int(os.environ.get("LORA_RANK", "16"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "2048"))

SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "sft-primed")
DEFAULT_REPO = f"InosLihka/rhythm-env-meta-trained-{SUFFIX}"
MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)

print("=== SFT prime config ===")
print(f"  TEACHER_DATA_REPO: {TEACHER_DATA_REPO}")
print(f"  TEACHER_FILES:     {TEACHER_FILES}")
print(f"  EPOCHS={EPOCHS}, MAX_STEPS={MAX_STEPS}, LORA_RANK={LORA_RANK}")
print(f"  LR={LEARNING_RATE}, MAX_SEQ_LENGTH={MAX_SEQ_LENGTH}")
print(f"  MODEL_REPO={MODEL_REPO}")
print()


def run(cmd):
    """Echo *cmd* and execute it, raising if it exits non-zero.

    Args:
        cmd: argument list (preferred) or a single string; passed straight to
            ``subprocess.run`` without a shell.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    # shlex.join quotes arguments containing spaces/metacharacters, so the
    # echoed line is unambiguous and can be pasted into a shell to reproduce.
    shown = cmd if isinstance(cmd, str) else shlex.join(str(part) for part in cmd)
    # Flush so the echo lands before the child's output in the job log.
    print(f"\n>>> {shown}", flush=True)
    return subprocess.run(cmd, check=True)


def _clone_repo():
    """Replace WORK_DIR with a fresh clone of REPO_URL and chdir into it."""
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    # NOTE(review): these only affect imports in *this* process; the training
    # scripts below run as separate subprocesses with their own sys.path.
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))


def _download_teacher_files():
    """Download every TEACHER_FILES entry into ./data and return local paths.

    The source repo type defaults to "dataset"; set TEACHER_DATA_REPO_TYPE
    (e.g. to "model") to read the files from a different kind of Hub repo.

    Raises:
        SystemExit: if TEACHER_FILES contains no non-empty filenames.
    """
    from huggingface_hub import hf_hub_download

    filenames = [fn.strip() for fn in TEACHER_FILES if fn.strip()]
    if not filenames:
        # Fail before launching SFT with an empty --teacher_jsonls list.
        sys.exit("ERROR: TEACHER_FILES is empty; nothing to train on")
    repo_type = os.environ.get("TEACHER_DATA_REPO_TYPE", "dataset")

    Path("data").mkdir(exist_ok=True)
    local_paths = []
    for fn in filenames:
        print(f"Downloading {fn} from {TEACHER_DATA_REPO}...")
        local_paths.append(
            hf_hub_download(
                repo_id=TEACHER_DATA_REPO,
                filename=fn,
                repo_type=repo_type,
                local_dir="data",
            )
        )
    print(f"Downloaded {len(local_paths)} JSONL files")
    return local_paths


def _run_sft(teacher_paths):
    """Launch training/sft_prime.py on *teacher_paths*, writing to OUTPUT_DIR."""
    # sys.executable (not a bare "python" from PATH) guarantees the child
    # uses this job's interpreter, i.e. the environment holding the deps.
    sft_args = [
        sys.executable, "training/sft_prime.py",
        "--teacher_jsonls", *teacher_paths,
        "--output_dir", OUTPUT_DIR,
        "--lora_rank", str(LORA_RANK),
        "--learning_rate", str(LEARNING_RATE),
        "--max_seq_length", str(MAX_SEQ_LENGTH),
        "--epochs", str(EPOCHS),
    ]
    # MAX_STEPS <= 0 means "train for EPOCHS"; only forward a positive cap.
    if MAX_STEPS > 0:
        sft_args.extend(["--max_steps", str(MAX_STEPS)])
    run(sft_args)


def _run_eval():
    """Evaluate OUTPUT_DIR with training/inference_eval.py over 5 episodes.

    The eval script is asked to write its results to ./eval_results.json.
    """
    run([
        sys.executable, "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", "5",
        "--output_file", "eval_results.json",
    ])


def _upload(token, skip_eval):
    """Push OUTPUT_DIR (and eval_results.json if present) to MODEL_REPO.

    Returns:
        True if eval_results.json was uploaded as well, otherwise False.
    """
    from huggingface_hub import HfApi, login

    login(token=token)
    api = HfApi()
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")

    # A positive MAX_STEPS replaces the epoch budget (see _run_sft), so the
    # commit message should describe what actually bounded training.
    schedule = f"{MAX_STEPS} steps" if MAX_STEPS > 0 else f"{EPOCHS} epochs"
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"SFT prime ({schedule}, lora r={LORA_RANK}) on teacher trajectories",
    )
    if skip_eval or not Path("eval_results.json").exists():
        return False
    api.upload_file(
        path_or_fileobj="eval_results.json",
        path_in_repo="eval_results.json",
        repo_id=MODEL_REPO,
        repo_type="model",
    )
    return True


def main():
    """Clone, download teacher data, SFT, optionally eval, then upload.

    Environment:
        HF_TOKEN: required for the final upload; without it outputs stay local.
        SKIP_EVAL: "1" skips the embedded eval step.
        TEACHER_DATA_REPO_TYPE: Hub repo type of TEACHER_DATA_REPO
            (default "dataset").
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Warn before the expensive training run, not only after it.
        print("WARNING: HF_TOKEN not set; the trained model will not be uploaded")

    # 1. Clone repo
    _clone_repo()
    # 2. Download teacher trajectories from HF Hub
    teacher_paths = _download_teacher_files()
    # 3. Run SFT
    _run_sft(teacher_paths)

    # 4. Eval (optional — set SKIP_EVAL=1 to upload faster and run eval separately)
    skip_eval = os.environ.get("SKIP_EVAL", "0") == "1"
    if skip_eval:
        print("SKIP_EVAL=1: skipping embedded eval (run scripts/eval_on_hf.py separately)")
    else:
        _run_eval()

    # 5. Upload to HF Hub
    if not token:
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs at: {OUTPUT_DIR}")
        return
    eval_uploaded = _upload(token, skip_eval)

    print()
    print("=" * 60)
    print("DONE")
    print(f"  SFT'd model: https://huggingface.co/{MODEL_REPO}")
    if eval_uploaded:
        print(f"  Eval JSON:   https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    elif skip_eval:
        print("  Eval skipped — run scripts/eval_on_hf.py separately")
    else:
        print("  Eval ran but eval_results.json was not found; nothing uploaded")
    print("=" * 60)


if __name__ == "__main__":
    main()