File size: 11,748 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
#!/usr/bin/env python
"""Job-side training entrypoint for ForgeEnv on HF Jobs A100.

Submitted via ``scripts/submit_training_job.py``. The launcher fills in
``HF_TOKEN``, ``HF_USERNAME``, ``ENV_URL`` as Job env vars. The job:

1. Clones ``<HF_USERNAME>/forgeenv-source`` (full project tree).
2. Installs the repo with training extras.
3. Sanity-pings the live env Space.
4. Runs warm-start SFT (TRL SFTTrainer + Unsloth, 4-bit LoRA).
5. Runs GRPO repair (TRL GRPOTrainer) starting from the SFT adapter.
6. Generates plots via ``forgeenv.training.plots``.
7. Pushes the LoRA + ``repair_library.json`` + plots to
   ``<HF_USERNAME>/forgeenv-repair-agent``.

The script is linear and prints big section markers so the streaming log
is easy to follow from the launcher.
"""
from __future__ import annotations

import json
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path


def _sh(cmd: list[str], **kwargs) -> None:
    print(f"[job] $ {' '.join(cmd)}", flush=True)
    subprocess.check_call(cmd, **kwargs)


def step(label: str) -> None:
    """Print a banner announcing the start of pipeline section *label*.

    The banner is surrounded by blank lines so it stands out in the
    streamed job log.
    """
    rule = "=" * 10
    print(f"\n{rule} {label} {rule}\n", flush=True)


# --- Job configuration -------------------------------------------------------
# Only HF_TOKEN is mandatory (a missing value raises KeyError right here);
# every other setting falls back to a constant or an HF_USERNAME-derived name.
_env = os.environ.get
HF_TOKEN = os.environ["HF_TOKEN"]
HF_USERNAME = _env("HF_USERNAME", "akhiilll")
ENV_URL = _env("ENV_URL", f"https://{HF_USERNAME}-forgeenv.hf.space")
SOURCE_REPO = _env("SOURCE_REPO", f"{HF_USERNAME}/forgeenv-source")
MODEL_REPO = _env("MODEL_REPO", f"{HF_USERNAME}/forgeenv-repair-agent")
BASE_MODEL = _env("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
SFT_STEPS = int(_env("SFT_STEPS", "1000"))
GRPO_STEPS = int(_env("GRPO_STEPS", "200"))

# --- Scratch layout: /tmp/forgeenv_work/outputs/{sft,grpo,plots} -------------
# WORK, OUT and PLOTS_DIR are created up front; sft/ and grpo/ are left for
# the training steps to create.
WORK = Path("/tmp/forgeenv_work")
OUT = WORK / "outputs"
SFT_DIR = OUT / "sft"
GRPO_DIR = OUT / "grpo"
PLOTS_DIR = OUT / "plots"
for _scratch in (WORK, OUT, PLOTS_DIR):
    _scratch.mkdir(parents=True, exist_ok=True)


step("0. clone source from Hub")
src_dir = WORK / "src"
if src_dir.exists():
    # Start from a clean tree so a re-run never mixes old and new sources.
    shutil.rmtree(src_dir)
# The token is embedded in the URL, presumably because the source repo may
# be private.
_sh([
    "git", "clone",
    f"https://USER:{HF_TOKEN}@huggingface.co/{SOURCE_REPO}",
    str(src_dir),
])
# git records the clone URL (token included) as the ``origin`` remote in
# src/.git/config; replace it with the token-free URL so the credential does
# not linger on disk.
_sh([
    "git", "-C", str(src_dir), "remote", "set-url", "origin",
    f"https://huggingface.co/{SOURCE_REPO}",
])
# Belt-and-braces: prepend the source dir to sys.path so `import forgeenv`
# works even if `pip install -e` doesn't persist inside the uv-managed
# venv. We still run pip install for any setuptools side-effects.
sys.path.insert(0, str(src_dir))

step("1. pin torch (cu124) + install GPU-stable deps")
# The pip installs below run in sequence against whatever the previous one
# left installed, so their ORDER is significant.
#
# Force a CUDA 12.4 torch wheel BEFORE anything else so other packages'
# resolvers don't pull a cu130 wheel that mismatches the host driver
# (Error 802 on some HF Job flavors).  TRL 1.2+ imports ``FSDPModule`` from
# ``torch.distributed.fsdp``, which exists only in PyTorch >= 2.6 — do not
# pin to 2.5.x.
_sh([
    sys.executable, "-m", "pip", "install",
    "--index-url", "https://download.pytorch.org/whl/cu124",
    "torch==2.6.0", "torchvision==0.21.0",
])
# `--no-deps` on openenv-core: it pins a different transformers/torch
# stack that we don't want. We still need its *runtime* imports:
# ``import forgeenv`` -> ``ForgeEnvironment`` -> ``openenv.core`` pulls in
# ``fastmcp`` (and friends) from ``openenv.core.env_server``.
_sh([
    sys.executable, "-m", "pip", "install", "--no-deps",
    "openenv-core>=0.2.0",
])
# The openenv-core runtime dependencies skipped by --no-deps above.
_sh([
    sys.executable, "-m", "pip", "install",
    "fastmcp>=3.0.0",
    "gradio>=4.0.0",
    "openai>=2.7.2",
    "tomli>=2.3.0",
    "tomli-w>=1.2.0",
    "websockets>=15.0.1",
])
# Training stack: TRL pinned exactly (see the torch note above); the rest
# float to whatever pip resolves.
_sh([
    sys.executable, "-m", "pip", "install",
    "trl==1.2.0", "peft", "accelerate", "datasets",
    "bitsandbytes",
    "matplotlib", "pyyaml", "nltk", "scikit-learn",
    "fastapi", "uvicorn", "pydantic", "requests",
    "sentencepiece", "protobuf",
])
try:
    # --no-deps is critical: prevents unsloth from re-resolving torch.
    _sh([sys.executable, "-m", "pip", "install", "--no-deps", "unsloth", "unsloth-zoo"])
except subprocess.CalledProcessError:
    # Install failure is tolerated here (only a warning is printed).
    # NOTE(review): step 3 below does ``from unsloth import FastLanguageModel``
    # unconditionally, so unless unsloth is already present the job will
    # still fail there despite the "plain HF" wording — confirm whether any
    # non-Unsloth path actually exists.
    print("[job] WARN: unsloth install failed — trainer will use plain HF.", flush=True)

import torch  # noqa: E402

# Report the torch/CUDA stack, then refuse to continue on a CPU-only host.
print(f"[job] torch: {torch.__version__}", flush=True)
_cuda_ok = torch.cuda.is_available()
print(f"[job] CUDA available: {_cuda_ok}", flush=True)
if not _cuda_ok:
    raise SystemExit("[job] FATAL: no CUDA — refusing to run training on CPU.")
print(f"[job] GPU: {torch.cuda.get_device_name(0)}", flush=True)
_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"[job] VRAM: {_vram_gb:.1f} GB", flush=True)

step("2. ping live env Space + verify forgeenv import")
import requests  # noqa: E402

# Best-effort reachability check: a dead or slow Space is reported as a
# warning, never treated as fatal.
_health_url = f"{ENV_URL}/health"
try:
    _resp = requests.get(_health_url, timeout=20)
    print(f"[job] env /health -> {_resp.status_code} {_resp.text}", flush=True)
except Exception as _exc:  # noqa: BLE001
    print(f"[job] WARN: env ping failed: {_exc}", flush=True)

# Fail fast if forgeenv isn't on the path -- much cheaper to discover
# this here than after 8+ minutes of SFT.
import forgeenv  # noqa: F401, E402
from forgeenv.training.grpo_repair import run_grpo  # noqa: F401, E402

print("[job] forgeenv import OK", flush=True)

step("3. SFT: load Qwen + LoRA via Unsloth, train on warm-start pairs")
from unsloth import FastLanguageModel  # noqa: E402

# Load BASE_MODEL in 4-bit with a 2048-token max sequence length (matching
# the SFT max_length below). Binds the module-level ``model``/``tokenizer``
# that _format_chat and the trainer use.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=2048,
    load_in_4bit=True,
    dtype=None,  # presumably lets Unsloth pick the compute dtype — confirm
    token=HF_TOKEN,
)
# Wrap the model with LoRA adapters (rank 16, alpha 32, no dropout) on the
# q/k/v/o and gate/up/down projection modules, using Unsloth's own
# gradient-checkpointing mode.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
)
# Log the trainable-parameter count as a sanity check on the LoRA setup.
print(
    f"[job] trainable params: "
    f"{model.num_parameters(only_trainable=True):,}",
    flush=True,
)

import datasets as ds  # noqa: E402
from trl import SFTConfig, SFTTrainer  # noqa: E402

# Prefer the repair pairs; otherwise use the drift pairs. The fallback's
# existence is not checked here, so a missing file surfaces at load time.
_warmstart_data = src_dir / "warmstart" / "data"
_repair_pairs = _warmstart_data / "repair_pairs.jsonl"
sft_jsonl = _repair_pairs if _repair_pairs.exists() else _warmstart_data / "drift_pairs.jsonl"
print(f"[job] SFT pairs: {sft_jsonl}", flush=True)


def _format_chat(example):
    """Render one dataset row's ``messages`` into a single ``text`` string.

    Uses the module-level ``tokenizer``'s chat template, untokenized and
    without a trailing generation prompt. A row whose ``messages`` value is
    missing or empty maps to ``{"text": ""}``.
    """
    messages = example.get("messages")
    rendered = (
        tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        if messages
        else ""
    )
    return {"text": rendered}


sft_ds = ds.load_dataset("json", data_files=str(sft_jsonl), split="train")
# Replace every column with a single rendered ``text`` column.
sft_ds = sft_ds.map(_format_chat, remove_columns=sft_ds.column_names)
# _format_chat yields "" for rows without messages; drop those so the trainer
# never packs or trains on empty sequences.
_rows_before = len(sft_ds)
sft_ds = sft_ds.filter(lambda ex: bool(ex["text"]))
print(
    f"[job] SFT rows: {len(sft_ds)} "
    f"(dropped {_rows_before - len(sft_ds)} empty)",
    flush=True,
)
if len(sft_ds) == 0:
    raise SystemExit(f"[job] FATAL: no usable SFT examples in {sft_jsonl}")

# Query once; bf16 when the GPU supports it, fp16 otherwise.
_use_bf16 = torch.cuda.is_bf16_supported()
sft_trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=sft_ds,
    args=SFTConfig(
        output_dir=str(SFT_DIR),
        max_steps=SFT_STEPS,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        logging_steps=25,
        # At most ~4 checkpoints per run, never more often than every 250 steps.
        save_steps=max(250, SFT_STEPS // 4),
        bf16=_use_bf16,
        fp16=not _use_bf16,
        max_length=2048,
        packing=True,
        packing_strategy="bfd",
        report_to=[],
    ),
)
sft_trainer.train()
# Save the SFT LoRA + tokenizer; GRPO starts from this directory.
model.save_pretrained(str(SFT_DIR))
tokenizer.save_pretrained(str(SFT_DIR))

# Drop the SFT model, tokenizer and trainer, then release cached CUDA blocks
# so GRPO can load its own copy of the model from scratch.
import gc  # noqa: E402

del sft_trainer, model, tokenizer
gc.collect()
torch.cuda.empty_cache()

step("4. GRPO repair training (resumes from SFT adapter)")
# Already imported in step 2 as a fail-fast check; re-importing just rebinds
# the same function so this section reads on its own.
from forgeenv.training.grpo_repair import run_grpo  # noqa: E402

# Run GRPO from BASE_MODEL plus the SFT adapter saved to SFT_DIR above;
# outputs (and presumably checkpoints/trainer_state.json) go to GRPO_DIR.
# NOTE(review): GRPO_STEPS is passed as ``total_episodes`` — confirm whether
# run_grpo counts optimizer steps or rollout episodes. group_size=4 is
# presumably the number of completions sampled per prompt.
run_grpo(
    base_model=BASE_MODEL,
    adapter_path=str(SFT_DIR),
    output_dir=str(GRPO_DIR),
    total_episodes=GRPO_STEPS,
    group_size=4,
    learning_rate=5e-6,
)

step("5. generate plots from training logs")
# NOTE(review): only plot_reward_curve is called below. The baseline and
# per-category PNGs are copied from the source tree's artifacts instead, so
# the other two imports are currently unused.
from forgeenv.training.plots import (  # noqa: E402
    plot_baseline_vs_trained,
    plot_reward_curve,
    plot_success_rate_by_category,
)

# TRL writes trainer_state.json under each checkpoint dir, not directly
# at output_dir. Pick the latest checkpoint, fall back to output_dir.
def _find_trainer_state(grpo_dir: Path) -> Optional[Path]:  # type: ignore[name-defined]
    direct = grpo_dir / "trainer_state.json"
    if direct.exists():
        return direct
    ckpts = sorted(
        (p for p in grpo_dir.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
        key=lambda p: int(p.name.split("-")[-1]) if p.name.split("-")[-1].isdigit() else -1,
    )
    return (ckpts[-1] / "trainer_state.json") if ckpts else None


from typing import Optional  # noqa: E402

# Reward keys tried in each log row, most specific first. Any other
# ``rewards/<name>/mean`` key in the row is tried after these, in the row's
# own key order (TRL's reward key names vary by version).
_REWARD_KEYS = (
    "rewards/reward_repair_function/mean",
    "rewards/mean",
    "reward",
    "train/reward",
)


def _extract_rewards(log_history: list[dict]) -> list[float]:
    """Return one float reward per log row that carries a reward key.

    Rows with none of the recognised keys are skipped.
    """
    rewards: list[float] = []
    for row in log_history:
        extra = [
            k for k in row
            if k.startswith("rewards/") and k.endswith("/mean") and k not in _REWARD_KEYS
        ]
        key = next((k for k in (*_REWARD_KEYS, *extra) if k in row), None)
        if key is not None:
            rewards.append(float(row[key]))
    return rewards


trainer_state = _find_trainer_state(GRPO_DIR)
print(f"[job] trainer_state path: {trainer_state}", flush=True)
training_rewards: list[float] = []
if trainer_state is not None and trainer_state.exists():
    state = json.loads(trainer_state.read_text(encoding="utf-8"))
    log_history = state.get("log_history", [])
    print(f"[job] log_history rows: {len(log_history)}", flush=True)
    if log_history:
        # Union of keys across all rows: useful when the reward key changes.
        sample_keys = sorted(set().union(*(log.keys() for log in log_history)))
        print(f"[job] log keys present: {sample_keys}", flush=True)
    training_rewards = _extract_rewards(log_history)
print(f"[job] {len(training_rewards)} reward log points", flush=True)
if training_rewards:
    print(
        f"[job] reward range: {min(training_rewards):.3f}..{max(training_rewards):.3f}",
        flush=True,
    )

# A single 0.0 placeholder keeps the plot (and the upload) from failing when
# no reward was logged.
plot_reward_curve(
    training_rewards or [0.0],
    str(PLOTS_DIR / "training_reward_curve.png"),
)
# we keep the CPU artifacts for baseline_vs_trained; if you want a real
# eval pass post-training, run the rollout helper here. The artifact
# generator already produced these from the dry-run.
src_plots = src_dir / "artifacts" / "plots"
for name in ("baseline_vs_trained.png", "success_by_category.png"):
    src_p = src_plots / name
    if src_p.exists():
        shutil.copy(src_p, PLOTS_DIR / name)

step("6. push LoRA + artifacts to Hub")
final_dir = OUT / "final"
final_dir.mkdir(parents=True, exist_ok=True)
# Top-level files of GRPO_DIR that make up a loadable adapter + tokenizer.
# added_tokens.json / chat_template.jinja are included because newer
# tokenizer versions may save added tokens / the chat template to those
# standalone files instead of tokenizer_config.json.
_TOKENIZER_EXTRAS = {
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "chat_template.jinja",
}
_copied: list[str] = []
for item in GRPO_DIR.iterdir():
    if item.is_file() and (
        item.name.startswith(("adapter_", "tokenizer"))
        or item.name in _TOKENIZER_EXTRAS
    ):
        shutil.copy(item, final_dir / item.name)
        _copied.append(item.name)
print(f"[job] copied to final/: {sorted(_copied)}", flush=True)
if not any(name.startswith("adapter_") for name in _copied):
    # Not fatal: plots and repair_library.json are still worth pushing, but
    # a weightless upload must be obvious in the log.
    print(
        f"[job] WARN: no adapter_* files directly under {GRPO_DIR}; "
        "this upload will not include LoRA weights.",
        flush=True,
    )

repair_lib = src_dir / "artifacts" / "repair_library.json"
if repair_lib.exists():
    shutil.copy(repair_lib, final_dir / "repair_library.json")

from huggingface_hub import HfApi  # noqa: E402

api = HfApi()
api.create_repo(
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    exist_ok=True,
    private=False,
)
api.upload_folder(
    folder_path=str(final_dir),
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    commit_message=f"GRPO LoRA (sft={SFT_STEPS}, grpo={GRPO_STEPS})",
    ignore_patterns=["__pycache__", "*.pyc"],
)
api.upload_folder(
    folder_path=str(PLOTS_DIR),
    repo_id=MODEL_REPO,
    repo_type="model",
    token=HF_TOKEN,
    path_in_repo="plots",
    commit_message="Training plots",
)

# Final human-readable pointer plus a machine-readable JSON summary.
print(f"\n[job] DONE. Model live at https://huggingface.co/{MODEL_REPO}", flush=True)
_summary = {
    "sft_steps": SFT_STEPS,
    "grpo_steps": GRPO_STEPS,
    "rewards_logged": len(training_rewards),
    "model_repo": MODEL_REPO,
}
print(json.dumps(_summary, indent=2), flush=True)