akhiilll committed on
Commit
ceda979
·
verified ·
1 Parent(s): cdeb5bc

forgeenv source snapshot for training job

Browse files
forgeenv/training/grpo_repair.py CHANGED
@@ -164,7 +164,6 @@ def run_grpo(
164
  learning_rate=learning_rate,
165
  max_steps=total_episodes,
166
  num_generations=group_size,
167
- max_prompt_length=2048,
168
  max_completion_length=1024,
169
  logging_steps=5,
170
  save_steps=max(50, total_episodes // 4),
@@ -175,7 +174,7 @@ def run_grpo(
175
  )
176
  trainer = GRPOTrainer(
177
  model=model,
178
- tokenizer=tokenizer,
179
  args=grpo_config,
180
  train_dataset=dataset,
181
  reward_funcs=[reward_repair_function],
 
164
  learning_rate=learning_rate,
165
  max_steps=total_episodes,
166
  num_generations=group_size,
 
167
  max_completion_length=1024,
168
  logging_steps=5,
169
  save_steps=max(50, total_episodes // 4),
 
174
  )
175
  trainer = GRPOTrainer(
176
  model=model,
177
+ processing_class=tokenizer,
178
  args=grpo_config,
179
  train_dataset=dataset,
180
  reward_funcs=[reward_repair_function],
scripts/jobs/train_repair_agent.py CHANGED
@@ -68,23 +68,31 @@ _sh([
68
  # venv. We still run pip install for any setuptools side-effects.
69
  sys.path.insert(0, str(src_dir))
70
 
71
- step("1. pip install (no torch/transformers churn) + verify GPU")
72
- # IMPORTANT: do NOT run `pip install -e .[openenv]` it transitively
73
- # downgrades torch via openenv-core, which breaks CUDA on H100/H200.
74
- # We rely on sys.path (set above) for `import forgeenv`.
 
 
 
 
 
 
 
75
  _sh([
76
  sys.executable, "-m", "pip", "install", "--no-deps",
77
  "openenv-core>=0.2.0",
78
  ])
79
  _sh([
80
  sys.executable, "-m", "pip", "install",
81
- "trl", "peft", "accelerate", "datasets", "bitsandbytes",
 
82
  "matplotlib", "pyyaml", "nltk", "scikit-learn",
83
  "fastapi", "uvicorn", "pydantic", "requests",
 
84
  ])
85
  try:
86
- # --no-deps is critical: prevents unsloth from pulling in a CPU-only
87
- # torch wheel that overwrites the uv image's GPU torch.
88
  _sh([sys.executable, "-m", "pip", "install", "--no-deps", "unsloth", "unsloth-zoo"])
89
  except subprocess.CalledProcessError:
90
  print("[job] WARN: unsloth install failed — trainer will use plain HF.", flush=True)
@@ -171,7 +179,7 @@ sft_ds = sft_ds.map(_format_chat, remove_columns=sft_ds.column_names)
171
 
172
  sft_trainer = SFTTrainer(
173
  model=model,
174
- tokenizer=tokenizer,
175
  train_dataset=sft_ds,
176
  args=SFTConfig(
177
  output_dir=str(SFT_DIR),
@@ -183,7 +191,7 @@ sft_trainer = SFTTrainer(
183
  save_steps=max(250, SFT_STEPS // 4),
184
  bf16=torch.cuda.is_bf16_supported(),
185
  fp16=not torch.cuda.is_bf16_supported(),
186
- max_seq_length=2048,
187
  report_to=[],
188
  ),
189
  )
 
68
  # venv. We still run pip install for any setuptools side-effects.
69
  sys.path.insert(0, str(src_dir))
70
 
71
+ step("1. pin torch (cu124) + install GPU-stable deps")
72
+ # Force a CUDA 12.4 torch wheel BEFORE anything else so other packages'
73
+ # resolvers don't pull a cu130 wheel that mismatches the host driver
74
+ # (this is what causes "Error 802: system not yet initialized" on H200).
75
+ _sh([
76
+ sys.executable, "-m", "pip", "install",
77
+ "--index-url", "https://download.pytorch.org/whl/cu124",
78
+ "torch==2.5.1", "torchvision==0.20.1",
79
+ ])
80
+ # `--no-deps` on openenv-core: it pins a different transformers/torch
81
+ # stack that we don't want.
82
  _sh([
83
  sys.executable, "-m", "pip", "install", "--no-deps",
84
  "openenv-core>=0.2.0",
85
  ])
86
  _sh([
87
  sys.executable, "-m", "pip", "install",
88
+ "trl==1.2.0", "peft", "accelerate", "datasets",
89
+ "bitsandbytes",
90
  "matplotlib", "pyyaml", "nltk", "scikit-learn",
91
  "fastapi", "uvicorn", "pydantic", "requests",
92
+ "sentencepiece", "protobuf",
93
  ])
94
  try:
95
+ # --no-deps is critical: prevents unsloth from re-resolving torch.
 
96
  _sh([sys.executable, "-m", "pip", "install", "--no-deps", "unsloth", "unsloth-zoo"])
97
  except subprocess.CalledProcessError:
98
  print("[job] WARN: unsloth install failed — trainer will use plain HF.", flush=True)
 
179
 
180
  sft_trainer = SFTTrainer(
181
  model=model,
182
+ processing_class=tokenizer,
183
  train_dataset=sft_ds,
184
  args=SFTConfig(
185
  output_dir=str(SFT_DIR),
 
191
  save_steps=max(250, SFT_STEPS // 4),
192
  bf16=torch.cuda.is_bf16_supported(),
193
  fp16=not torch.cuda.is_bf16_supported(),
194
+ max_length=2048,
195
  report_to=[],
196
  ),
197
  )
scripts/preflight_check.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """Local preflight: validate every component the H200 training job touches
3
+ WITHOUT spending GPU time. Each test prints PASS/FAIL with a short reason.
4
+
5
+ Run::
6
+
7
+ python scripts/preflight_check.py
8
+
9
+ The script exits non-zero if any required test fails. Optional tests
10
+ (network/Hub) print SKIP if HF_TOKEN is not set or the env Space is down.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import os
16
+ import sys
17
+ import tempfile
18
+ import traceback
19
+ from pathlib import Path
20
+ from typing import Callable
21
+
22
+ REPO_ROOT = Path(__file__).resolve().parents[1]
23
+ sys.path.insert(0, str(REPO_ROOT))
24
+
25
+ PASS = "[PASS]"
26
+ FAIL = "[FAIL]"
27
+ SKIP = "[SKIP]"
28
+
29
+ _results: list[tuple[str, str, str]] = []
30
+
31
+
32
def _run(label: str, fn: Callable[[], str | None], required: bool = True) -> None:
    """Execute one preflight check, then record and print its outcome.

    Args:
        label: Human-readable name shown next to the result tag.
        fn: Zero-argument check. Its return value (or ``""`` when falsy)
            becomes the detail text on success.
        required: When ``False``, an unexpected exception is reported as
            SKIP instead of FAIL and no traceback is printed.

    A ``_Skip`` raised by ``fn`` is always reported as SKIP. Exactly one
    entry is appended to ``_results`` per call.
    """
    failure_tb = ""
    try:
        # Guard only the check itself: an error raised while *reporting*
        # (e.g. a console encoding error in print) must not be mistaken
        # for a failing check and recorded as a second result.
        detail = fn() or ""
    except _Skip as skip:
        tag, detail = SKIP, str(skip)
    except Exception as exc:  # noqa: BLE001
        tag = FAIL if required else SKIP
        detail = f"{type(exc).__name__}: {exc}"
        if required:
            # Capture now; the exception context is gone after the handler.
            failure_tb = traceback.format_exc()
    else:
        tag = PASS

    _results.append((tag, label, detail))
    print(f"{tag} {label} {detail}", flush=True)
    if failure_tb:
        # Same destination traceback.print_exc() writes to.
        print(failure_tb, file=sys.stderr, end="", flush=True)
46
+
47
+
48
+ class _Skip(Exception):
49
+ pass
50
+
51
+
52
def t1_imports() -> str:
    """Import the packages and forgeenv symbols listed below.

    Any ImportError propagates to the caller. On success, returns the
    installed trl and transformers version strings.
    """
    # Top-level packages first, in the same order as before.
    import forgeenv  # noqa: F401
    import trl
    import peft  # noqa: F401
    import datasets  # noqa: F401
    import transformers
    import accelerate  # noqa: F401

    # Project symbols; importing them is the whole check.
    from forgeenv.training.grpo_repair import reward_repair_function, run_grpo  # noqa: F401
    from forgeenv.training.plots import (  # noqa: F401
        plot_baseline_vs_trained,
        plot_reward_curve,
        plot_success_rate_by_category,
    )
    from forgeenv.env.actions import (  # noqa: F401
        BreakageAction,
        ForgeAction,
        RepairAction,
    )
    from forgeenv.env.diff_utils import (  # noqa: F401
        apply_unified_diff,
        make_unified_diff,
    )
    from forgeenv.env.forge_environment import ForgeEnvironment  # noqa: F401
    from forgeenv.roles.repair_agent import extract_diff  # noqa: F401
    from forgeenv.tasks.task_sampler import TaskSampler  # noqa: F401

    trl_version = trl.__version__
    transformers_version = transformers.__version__
    return f"trl={trl_version} transformers={transformers_version}"
76
+
77
+
78
def t2_dataset_load_and_format() -> str:
    """Load warmstart/data/repair_pairs.jsonl and sanity-check its first row.

    Raises:
        FileNotFoundError: the JSONL file does not exist.
        ValueError: fewer than 10 rows, or the first row's messages lack
            one of the system/user/assistant roles.
        KeyError: the first row has no (or an empty) ``messages`` field.

    Returns:
        A summary with the row count and the sorted role names.
    """
    import datasets as ds

    pairs_path = REPO_ROOT / "warmstart" / "data" / "repair_pairs.jsonl"
    if not pairs_path.exists():
        raise FileNotFoundError(pairs_path)

    sft_ds = ds.load_dataset("json", data_files=str(pairs_path), split="train")
    row_count = len(sft_ds)
    if row_count < 10:
        raise ValueError(f"too few rows in repair_pairs.jsonl: {row_count}")

    first_row = sft_ds[0]
    messages = first_row["messages"] if "messages" in first_row else None
    if not messages:
        raise KeyError("row missing 'messages' field")

    roles = {message["role"] for message in messages}
    required_roles = {"system", "user", "assistant"}
    if required_roles - roles:
        raise ValueError(f"unexpected role set: {roles}")
    return f"rows={row_count} roles={sorted(roles)}"
95
+
96
+
97
def t3_trl_configs_accept_our_kwargs() -> str:
    """Check that the kwargs listed here are fields of TRL's config classes.

    The dataclass fields of ``SFTConfig`` / ``GRPOConfig`` are inspected
    directly, so the name check does not depend on ``__post_init__``
    validation succeeding on the local machine.

    Afterwards a best-effort instantiation is attempted with
    ``use_cpu=True`` (and ``bf16=False`` for SFT); a failure there is
    reported in the returned detail but does not fail the check.

    Raises:
        TypeError: a kwarg name is not a field of the corresponding config.

    Returns:
        A summary with field counts and the instantiation outcome.
    """
    import dataclasses

    from trl import GRPOConfig, SFTConfig

    sft_kwargs = {
        "output_dir": "/tmp/forge_sft",
        "max_steps": 10,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-4,
        "logging_steps": 25,
        "save_steps": 250,
        "bf16": True,
        "fp16": False,
        "max_length": 2048,
        "report_to": [],
    }
    grpo_kwargs = {
        "output_dir": "/tmp/forge_grpo",
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "learning_rate": 5e-6,
        "max_steps": 5,
        "num_generations": 4,
        "max_completion_length": 1024,
        "logging_steps": 5,
        "save_steps": 50,
        "save_total_limit": 2,
        "seed": 0,
        "report_to": "none",
        "beta": 0.04,
    }

    def _field_names(cls) -> set[str]:
        """Collect dataclass field names across the class's whole MRO."""
        names: set[str] = set()
        for c in cls.__mro__:
            if dataclasses.is_dataclass(c):
                names.update(f.name for f in dataclasses.fields(c))
        return names

    sft_fields = _field_names(SFTConfig)
    missing_sft = [k for k in sft_kwargs if k not in sft_fields]
    if missing_sft:
        raise TypeError(f"SFTConfig missing fields: {missing_sft}")

    grpo_fields = _field_names(GRPOConfig)
    missing_grpo = [k for k in grpo_kwargs if k not in grpo_fields]
    if missing_grpo:
        raise TypeError(f"GRPOConfig missing fields: {missing_grpo}")

    # Best-effort: actually instantiate with CPU-safe overrides so that
    # __post_init__ runs too. The overrides are merged into the dict;
    # passing bf16=False next to **sft_kwargs (which already holds "bf16")
    # raises "got multiple values for keyword argument" before SFTConfig
    # is ever called, which silently disabled this step.
    try:
        SFTConfig(**{**sft_kwargs, "bf16": False, "use_cpu": True})
        GRPOConfig(**{**grpo_kwargs, "use_cpu": True})
        instantiated = "instantiated OK"
    except Exception as e:  # noqa: BLE001
        # Include the message so a swallowed error is still diagnosable.
        instantiated = (
            f"field-check OK; instantiation skipped ({type(e).__name__}: {e})"
        )

    return (
        f"SFT/GRPO kwargs all valid; sft_fields={len(sft_fields)} "
        f"grpo_fields={len(grpo_fields)}; {instantiated}"
    )
165
+
166
+
167
def t4_reward_function_returns_float() -> str:
    """Call ``reward_repair_function`` once with a hand-written diff.

    Raises:
        RuntimeError: ``TaskSampler().tasks`` is empty.
        ValueError: the reward list does not contain exactly one item.
        TypeError: the single reward is not a ``float``.

    Returns:
        The reward formatted to three decimals.
    """
    from forgeenv.training.grpo_repair import reward_repair_function
    from forgeenv.tasks.task_sampler import TaskSampler

    task_sampler = TaskSampler()
    if not task_sampler.tasks:
        raise RuntimeError("TaskSampler has no tasks")
    first_task_id = task_sampler.tasks[0].task_id

    broken_script = "x = 1\nprint(x)\n"
    # A minimal unified diff that changes the first line of broken_script.
    diff_lines = [
        "--- a/train.py\n",
        "+++ b/train.py\n",
        "@@ -1,2 +1,2 @@\n",
        "-x = 1\n",
        "+x = 2\n",
        " print(x)\n",
    ]
    completion = "".join(diff_lines)

    rewards = reward_repair_function(
        completions=[completion],
        prompts=[[]],
        task_id=[first_task_id],
        broken_script=[broken_script],
    )
    if len(rewards) != 1:
        raise ValueError(f"expected 1 reward got {len(rewards)}")
    reward = rewards[0]
    if not isinstance(reward, float):
        raise TypeError(f"reward not float: {type(reward)}")
    return f"reward={reward:.3f} (single fake completion)"
195
+
196
+
197
def t5_diff_utils_roundtrip() -> str:
    """Round-trip a tiny edit through make → extract → apply.

    A diff between two two-line scripts is generated, wrapped in prose and
    a fenced ``diff`` block, pulled back out with ``extract_diff`` and
    applied to the original text.

    Raises:
        ValueError: any stage yields an empty result, or the applied
            output does not contain ``x = 2``.
    """
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
    from forgeenv.roles.repair_agent import extract_diff

    before = "x = 1\nprint(x)\n"
    after = "x = 2\nprint(x)\n"

    diff_text = make_unified_diff(before, after)
    if not diff_text.strip():
        raise ValueError("make_unified_diff returned empty")

    # Wrap the diff in surrounding prose and a fenced ```diff block.
    wrapped = "".join(
        ["Some thinking...\n```diff\n", diff_text, "\n```\nmore prose"]
    )
    recovered = extract_diff(wrapped)
    if not recovered.strip():
        raise ValueError("extract_diff failed to find diff in fenced block")

    patched = apply_unified_diff(before, recovered)
    if "x = 2" in patched:
        return f"diff_len={len(diff_text)} extract+apply OK"
    raise ValueError(f"apply_unified_diff failed: {patched!r}")
214
+
215
+
216
def t6_live_env_health() -> str:
    """GET the ``/health`` endpoint of the ``<user>-forgeenv`` HF Space.

    The user comes from ``HF_USERNAME`` (default ``akhiilll``).

    Raises:
        _Skip: the request itself raised (reported as a network problem).
        RuntimeError: the endpoint answered with status >= 400.

    Returns:
        The status code followed by the repr of the first 60 body chars.
    """
    import requests

    hf_user = os.environ.get("HF_USERNAME", "akhiilll")
    url = f"https://{hf_user}-forgeenv.hf.space/health"
    try:
        response = requests.get(url, timeout=15)
    except Exception as e:  # noqa: BLE001
        raise _Skip(f"network: {e}")

    status = response.status_code
    if status < 400:
        return f"{status} {response.text[:60]!r}"
    raise RuntimeError(f"{url} -> {status} {response.text[:80]}")
228
+
229
+
230
def t7_source_repo_exists() -> str:
    """Confirm the training script is present in ``<user>/forgeenv-source``.

    Raises:
        _Skip: ``HF_TOKEN`` is not set.
        FileNotFoundError: ``scripts/jobs/train_repair_agent.py`` is not
            among the repo's files.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise _Skip("HF_TOKEN not set")

    from huggingface_hub import HfApi

    hf_user = os.environ.get("HF_USERNAME", "akhiilll")
    repo_id = f"{hf_user}/forgeenv-source"
    files = HfApi().list_repo_files(
        repo_id=repo_id, repo_type="model", token=hf_token
    )

    needed = "scripts/jobs/train_repair_agent.py"
    if needed in files:
        return f"{repo_id} has {len(files)} files incl. train_repair_agent.py"
    raise FileNotFoundError(
        f"{needed} missing from {repo_id} (files: {len(files)})"
    )
244
+
245
+
246
def t8_qwen_tokenizer_loads() -> str:
    """Load the base model's tokenizer and render a three-turn chat.

    The model id comes from ``BASE_MODEL`` (default
    ``Qwen/Qwen2.5-Coder-3B-Instruct``); ``HF_TOKEN`` is passed through
    when set.

    Raises:
        ValueError: the rendered text lacks ``<|im_start|>`` or the user
            message content.
    """
    from transformers import AutoTokenizer

    base = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-3B-Instruct")
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(
        base, token=hf_token, trust_remote_code=False
    )

    conversation = [
        {"role": role, "content": content}
        for role, content in (
            ("system", "you are a repair agent"),
            ("user", "fix this"),
            ("assistant", "--- a/train.py\n+++ b/train.py\n"),
        )
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )

    if "<|im_start|>" not in rendered:
        raise ValueError("ChatML tokens missing from rendered template")
    if "fix this" not in rendered:
        raise ValueError("user content not in rendered template")
    return f"{base} chat_template renders ChatML ({len(rendered)} chars)"
263
+
264
+
265
def t9_hfapi_auth_and_namespace() -> str:
    """Authenticate against the Hub and compare the user to ``HF_USERNAME``.

    Raises:
        _Skip: ``HF_TOKEN`` is not set.
        RuntimeError: ``whoami`` returned neither ``name`` nor ``fullname``.

    Returns:
        ``"authed as <user>"``, or a ``WARN:`` string when the token's user
        differs from ``HF_USERNAME`` (default ``akhiilll``). A mismatch is
        returned, not raised.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise _Skip("HF_TOKEN not set")

    from huggingface_hub import HfApi

    info = HfApi().whoami(token=hf_token)
    user = info.get("name") or info.get("fullname")
    if not user:
        raise RuntimeError(f"whoami returned no name: {info}")

    expected = os.environ.get("HF_USERNAME", "akhiilll")
    if user == expected:
        return f"authed as {user}"
    return f"WARN: token user={user} but HF_USERNAME={expected}"
280
+
281
+
282
def t10_find_trainer_state() -> str:
    """Exercise a local copy of the trainer-state lookup on a temp dir.

    A ``checkpoint-80/trainer_state.json`` is synthesized in a temporary
    directory and then located with logic re-implemented here: prefer
    ``<dir>/trainer_state.json``, otherwise the highest-numbered
    ``checkpoint-*`` that contains one. The training script itself is not
    imported, and ``sys.path`` is left untouched, since nothing is loaded
    from ``scripts/jobs``.

    Raises:
        RuntimeError: no import spec could be built for the training
            script, the state file was not found, or the loaded file does
            not have the two synthesized log entries.
    """
    from importlib import util as _util

    spec = _util.spec_from_file_location(
        "_train_mod", REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
    )
    if spec is None or spec.loader is None:
        raise RuntimeError("can't spec the training script")
    # Don't actually load the module (it has top-level CUDA/HF effects).
    # Re-implement the same finder here from source.

    with tempfile.TemporaryDirectory() as td:
        td_p = Path(td)
        ckpt = td_p / "checkpoint-80"
        ckpt.mkdir()
        state = {
            "log_history": [
                {"step": 5, "rewards/reward_repair_function/mean": 0.12},
                {"step": 10, "rewards/reward_repair_function/mean": 0.34},
            ]
        }
        (ckpt / "trainer_state.json").write_text(json.dumps(state))

        # Mirrors the lookup order described for the script: a top-level
        # trainer_state.json wins, else the newest checkpoint-* directory.
        direct = td_p / "trainer_state.json"
        if direct.exists():
            found = direct
        else:
            ckpts = sorted(
                (
                    p
                    for p in td_p.glob("checkpoint-*")
                    if (p / "trainer_state.json").exists()
                ),
                # Sort numerically so checkpoint-100 ranks above checkpoint-80.
                key=lambda p: int(p.name.split("-")[-1]),
            )
            found = (ckpts[-1] / "trainer_state.json") if ckpts else None

        if found is None or not found.exists():
            raise RuntimeError("finder did not locate the synthesized state")
        loaded = json.loads(found.read_text())
        if len(loaded["log_history"]) != 2:
            raise RuntimeError("finder loaded wrong file")
    return "checkpoint-N/trainer_state.json discoverable"
320
+
321
+
322
def main() -> int:
    """Run all ten preflight checks in order and print a summary.

    Returns:
        ``0`` when no result is tagged FAIL, otherwise ``1``.
    """
    print(f"\n=== ForgeEnv preflight (repo: {REPO_ROOT}) ===\n", flush=True)

    # (label, check, required) — executed top to bottom.
    checks = (
        ("01 imports", t1_imports, True),
        ("02 dataset load + format", t2_dataset_load_and_format, True),
        ("03 TRL configs (SFT/GRPO) accept kwargs", t3_trl_configs_accept_our_kwargs, True),
        ("04 reward fn returns float", t4_reward_function_returns_float, True),
        ("05 diff utils round-trip", t5_diff_utils_roundtrip, True),
        ("06 live env /health", t6_live_env_health, False),
        ("07 forgeenv-source repo on Hub", t7_source_repo_exists, False),
        ("08 Qwen tokenizer + ChatML", t8_qwen_tokenizer_loads, True),
        ("09 HfApi auth", t9_hfapi_auth_and_namespace, False),
        ("10 _find_trainer_state logic", t10_find_trainer_state, True),
    )
    for label, check, required in checks:
        _run(label, check, required=required)

    print("\n=== Summary ===")
    tally = {PASS: 0, FAIL: 0, SKIP: 0}
    for tag, label, _detail in _results:
        tally[tag] += 1
        print(f"{tag} {label}")
    print(f"\n{tally[PASS]} passed, {tally[FAIL]} failed, {tally[SKIP]} skipped")
    return 1 if tally[FAIL] else 0
343
+
344
+
345
+ if __name__ == "__main__":
346
+ sys.exit(main())