akhiilll committed
Commit 2cf3915 · verified · 1 Parent(s): 9c7f0da

forgeenv source snapshot for training job

demo-space/app.py CHANGED
@@ -1,22 +1,32 @@
 """Gradio demo Space for the ForgeEnv Repair Agent.
 
-Loads the trained LoRA adapter from the Hub and exposes a 2-input form:
-broken script + error trace. Output is a unified diff. Inference runs on
-ZeroGPU (`@spaces.GPU`) so we don't pay for idle GPU time.
+Three-tier repair pipeline so the demo always returns a useful diff:
 
-If the trained adapter isn't yet uploaded, the demo falls back to the
-deterministic ``BaselineRepairAgent`` so the Space still works end-to-end.
+1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model
+   emits a diff that, when applied, actually changes the broken script,
+   we use it.
+2. **Error-trace heuristic** — extracts the fix signal from the Python
+   traceback (Did you mean / unexpected kwarg / No module named) and
+   emits a clean canonical diff. Handles the most common drift patterns.
+3. **Model reasoning hint** — if the heuristic fails, surface the model's
+   natural-language reasoning (it usually explains the bug correctly even
+   when its diff syntax is broken) alongside a "no patch produced" note.
+
+This separation means the demo is robust regardless of how well the
+LoRA generalises on a given input — and it's honest about what each
+component contributed.
 """
 from __future__ import annotations
 
 import json
 import os
+import re
 import traceback
 from typing import Optional
 
 import gradio as gr
 
-BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct")
+BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
 ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent")
 
 _TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift"
@@ -25,7 +35,9 @@ _DESCRIPTION = (
     "produced. The Repair Agent returns a minimal unified diff. The model "
     "was trained inside [ForgeEnv](https://huggingface.co/spaces/"
     "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style "
-    "Challenger / Solver co-evolution."
+    "Challenger / Solver co-evolution. The agent is backed by a heuristic "
+    "fallback that parses error traces directly when the LoRA's diff is "
+    "malformed, keeping the demo robust on out-of-distribution inputs."
 )
 
 _EXAMPLES = [
@@ -80,6 +92,29 @@ _tokenizer = None
 _load_error: Optional[str] = None
 
 
+# ----------------------------------------------------------------- model io
+def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool:
+    """Cheap pre-check: pull adapter_config.json and compare base_model_name."""
+    try:
+        from huggingface_hub import hf_hub_download
+
+        cfg_path = hf_hub_download(
+            repo_id=adapter_repo,
+            filename="adapter_config.json",
+            token=os.environ.get("HF_TOKEN"),
+        )
+        with open(cfg_path) as f:
+            cfg = json.load(f)
+        adapter_base = (cfg.get("base_model_name_or_path") or "").lower()
+        # Match by family substring -- the base name's family slug (e.g.
+        # "qwen2.5-coder-7b") must appear in the adapter's recorded base,
+        # otherwise the adapter targets a different arch.
+        family = base_name.split("/")[-1].lower().replace("-instruct", "")
+        return family in adapter_base
+    except Exception as e:  # noqa: BLE001
+        print(f"[demo] adapter_config check failed ({e}); attempting load anyway")
+        return True
+
+
 def _load_model() -> None:
     """Lazy-load the trained LoRA on first GPU invocation."""
     global _model, _tokenizer, _load_error
@@ -96,10 +131,18 @@ def _load_model() -> None:
             torch_dtype=torch.float16,
             device_map="auto",
         )
-        try:
-            model = PeftModel.from_pretrained(base, ADAPTER_REPO)
-        except Exception as e:  # noqa: BLE001
-            print(f"[demo] adapter not found ({e}); using base model")
+        if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
+            try:
+                model = PeftModel.from_pretrained(base, ADAPTER_REPO)
+                print(f"[demo] LoRA attached: {ADAPTER_REPO}")
+            except Exception as e:  # noqa: BLE001
+                print(f"[demo] adapter load failed ({e}); using base model")
+                model = base
+        else:
+            print(
+                f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
+                f"base; using {BASE_MODEL} alone until a matching adapter ships"
+            )
             model = base
         _model = model.eval()
         _tokenizer = tokenizer
@@ -107,36 +150,34 @@ def _load_model() -> None:
         _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
 
 
-def _baseline_fallback(script: str, error_trace: str) -> str:
-    """Deterministic repair if the trained model isn't available.
-
-    Uses the in-repo BaselineRepairAgent if the package is installed; else
-    just returns an explanatory message.
-    """
-    try:
-        from forgeenv.roles.repair_agent import BaselineRepairAgent
-
-        agent = BaselineRepairAgent()
-        return agent.repair(script, breakage_spec=None, original_script=None)
-    except Exception:  # noqa: BLE001
-        return (
-            "# (Fallback) Trained adapter unavailable in this Space.\n"
-            "# Likely fix based on the error trace:\n"
-            f"# {error_trace.splitlines()[0] if error_trace else ''}\n"
-        )
+_SYSTEM_PROMPT = (
+    "You are an expert ML engineer who fixes broken HuggingFace training "
+    "scripts caused by library version drift. Output ONLY a unified diff."
+)
 
 
-def _generate_with_model(prompt: str, max_new_tokens: int = 512) -> str:
+def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
+    """Greedy decode using the base model's chat template (Qwen ChatML)."""
    import torch
 
-    inputs = _tokenizer(prompt, return_tensors="pt").to(_model.device)
+    messages = [
+        {"role": "system", "content": _SYSTEM_PROMPT},
+        {"role": "user", "content": prompt},
+    ]
+    try:
+        text = _tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    except Exception:  # noqa: BLE001
+        text = prompt
+    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            temperature=0.3,
-            top_p=0.9,
+            do_sample=False,
+            temperature=0.0,
+            repetition_penalty=1.15,
            pad_token_id=_tokenizer.eos_token_id,
        )
    completion = _tokenizer.decode(
@@ -145,8 +186,160 @@ def _generate_with_model(prompt: str, max_new_tokens: int = 512) -> str:
     return completion.strip()
 
 
-# Wrap inference in a `@spaces.GPU` decorator if available so we get a free
-# ZeroGPU slice. Outside ZeroGPU it's a no-op.
+# -------------------------------------------------------- diff extraction
+_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
+_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
+
+
+def _extract_diff_block(raw: str) -> str:
+    """Pull the *first* fenced diff out of the model's raw output."""
+    if not raw:
+        return ""
+    m = _FENCE_RE.search(raw)
+    if m:
+        return m.group(1).strip()
+    # otherwise grab from the first '---' / '+++' / '@@' onwards
+    for marker in ("--- ", "+++ ", "@@"):
+        idx = raw.find(marker)
+        if idx >= 0:
+            return raw[idx:].strip()
+    return ""
+
+
+def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
+    """Try to apply the diff. Returns True iff the result differs from input."""
+    if not diff_text:
+        return False
+    try:
+        from forgeenv.env.diff_utils import apply_unified_diff
+
+        repaired = apply_unified_diff(broken, diff_text)
+        return bool(repaired) and repaired.strip() != broken.strip()
+    except Exception:  # noqa: BLE001
+        return False
+
+
+def _canonicalise(broken: str, diff_text: str) -> str:
+    """Apply diff -> rebuild a clean canonical unified diff."""
+    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
+
+    repaired = apply_unified_diff(broken, diff_text)
+    if not repaired or repaired.strip() == broken.strip():
+        return ""
+    return make_unified_diff(broken, repaired)
+
+
+def _extract_model_reasoning(raw: str) -> str:
+    """Pull the natural-language reasoning out of the model's output (if any)."""
+    if not raw:
+        return ""
+    text = re.sub(_FENCE_RE, "", raw).strip()
+    text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip()
+    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
+    sentences: list[str] = []
+    for ln in lines:
+        if ln.startswith(("---", "+++", "@@", "-", "+")):
+            continue
+        if len(ln) < 10:
+            continue
+        sentences.append(ln)
+        if len(sentences) >= 3:
+            break
+    return " ".join(sentences)
+
+
+# ---------------------------------------------------- error-trace heuristic
+_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE)
+_NO_ATTR_RE = re.compile(
+    r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE
+)
+_NO_MODULE_RE = re.compile(
+    r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE
+)
+_BAD_KWARG_RE = re.compile(
+    r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE
+)
+_USE_INSTEAD_RE = re.compile(
+    r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE
+)
+
+
+def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
+    """Produce a (repaired_script, fix_description) pair from the trace.
+
+    Patterns covered:
+      * AttributeError + "Did you mean: 'X'?"  -> rename method
+      * AttributeError without a hint          -> left unchanged
+      * ModuleNotFoundError 'X.Y'              -> drop the .Y submodule
+      * TypeError unexpected kwarg + 'use Y'   -> swap kwarg
+      * TypeError unexpected kwarg, no hint    -> drop the kwarg
+    """
+    if not error_trace:
+        return broken, ""
+    trace = error_trace.strip()
+    repaired = broken
+    description = ""
+
+    # 1. AttributeError 'X' + Did you mean 'Y'
+    if "AttributeError" in trace or "has no attribute" in trace:
+        old = _NO_ATTR_RE.search(trace)
+        new = _DID_YOU_MEAN_RE.search(trace)
+        if old and new and old.group(1) != new.group(1):
+            old_name, new_name = old.group(1), new.group(1)
+            pattern = re.compile(rf"\b{re.escape(old_name)}\b")
+            if pattern.search(repaired):
+                repaired = pattern.sub(new_name, repaired)
+                description = (
+                    f"`{old_name}` is no longer an attribute on this object; "
+                    f"renamed call to `{new_name}` per the traceback hint."
+                )
+
+    # 2. ModuleNotFoundError 'X.Y' (or 'X')
+    if not description and "No module named" in trace:
+        m = _NO_MODULE_RE.search(trace)
+        if m:
+            mod = m.group(1)
+            if "." in mod:
+                parent, child = mod.rsplit(".", 1)
+                pat_full = re.compile(rf"\b{re.escape(mod)}\b")
+                if pat_full.search(repaired):
+                    repaired = pat_full.sub(parent, repaired)
+                    description = (
+                        f"`{mod}` was removed; replaced with parent module "
+                        f"`{parent}`."
+                    )
+
+    # 3. TypeError unexpected kwarg
+    if not description and "unexpected keyword argument" in trace:
+        bad = _BAD_KWARG_RE.search(trace)
+        good = _USE_INSTEAD_RE.search(trace)
+        if bad:
+            bad_kw = bad.group(1)
+            if good:
+                good_kw = good.group(1)
+                pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
+                if pat.search(repaired):
+                    repaired = pat.sub(f"{good_kw}=", repaired)
+                    # if the old kwarg was boolean-ish, keeping the value is
+                    # fine (pad_to_max_length=True -> padding=True)
+                    description = (
+                        f"`{bad_kw}` was renamed to `{good_kw}`; updated "
+                        f"keyword to match the new API."
+                    )
+            else:
+                # remove the kwarg entirely (best-effort)
+                pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
+                if pat.search(repaired):
+                    repaired = pat.sub("", repaired)
+                    description = (
+                        f"`{bad_kw}` is no longer accepted; removed the "
+                        f"keyword argument."
+                    )
+
+    return repaired, description
+
+
+# ------------------------------------------------------------- entry point
 try:
     import spaces  # type: ignore
 
@@ -161,22 +354,66 @@ def repair_script(script: str, error_trace: str) -> str:
     if not script.strip():
         return "# Paste a broken script first."
 
+    # Tier 1: trained LoRA
+    model_raw = ""
+    model_diff_canonical = ""
+    model_reasoning = ""
+
     _load_model()
-    if _model is None:
-        return _baseline_fallback(script, error_trace)
+    if _model is not None:
+        try:
+            versions = json.dumps(
+                {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
+            )
+            prompt = _PROMPT_TEMPLATE.format(
+                versions=versions,
+                script=script,
+                trace=error_trace or "(no trace)",
+            )
+            model_raw = _generate_with_model(prompt)
+            model_diff_text = _extract_diff_block(model_raw)
+            if _diff_actually_changes_script(script, model_diff_text):
+                model_diff_canonical = _canonicalise(script, model_diff_text)
+            model_reasoning = _extract_model_reasoning(model_raw)
+        except Exception as e:  # noqa: BLE001
+            print(f"[demo] model generation failed: {e}")
 
-    versions = json.dumps(
-        {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
-    )
-    prompt = _PROMPT_TEMPLATE.format(
-        versions=versions, script=script, trace=error_trace or "(no trace)"
+    if model_diff_canonical:
+        header = (
+            "# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
+            "# The model produced a valid diff that successfully patches the script.\n"
+        )
+        return header + "\n" + model_diff_canonical
+
+    # Tier 2: error-trace heuristic
+    repaired, description = _heuristic_repair(script, error_trace)
+    if description and repaired != script:
+        from forgeenv.env.diff_utils import make_unified_diff
+
+        diff = make_unified_diff(script, repaired)
+        header_lines = [
+            "# Source: error-trace heuristic (LoRA diff was malformed; "
+            "fell back to deterministic repair).",
+            f"# Fix: {description}",
+        ]
+        if model_reasoning:
+            header_lines.append(f"# Trained model said: {model_reasoning}")
+        return "\n".join(header_lines) + "\n\n" + diff
+
+    # Tier 3: nothing worked -- surface what we know
+    msg_lines = ["# Could not produce a confident patch."]
+    if model_reasoning:
+        msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
+    if error_trace:
+        msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}")
+    msg_lines.append(
+        "# Try a more specific error trace (the heuristic looks for "
+        "'Did you mean', 'No module named', or 'unexpected keyword argument')."
     )
-    try:
-        return _generate_with_model(prompt)
-    except Exception as e:  # noqa: BLE001
-        return f"# generation failed: {e}\n" + _baseline_fallback(script, error_trace)
+    return "\n".join(msg_lines)
 
 
+# ----------------------------------------------------------------- gradio
 with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
     gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
     with gr.Row():
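
The canonicalisation path above round-trips the model's hunks through `apply_unified_diff` and `make_unified_diff` from `forgeenv.env.diff_utils`, which this commit does not include. Below is a minimal stand-in built on the standard library's difflib; the bodies are an assumption that matches only the call sites above, not the repo's actual implementation.

# Hypothetical stand-ins for forgeenv.env.diff_utils (module not shown in
# this commit); they mirror only the signatures used by _canonicalise()
# and _diff_actually_changes_script() above.
import difflib
import re

_HUNK = re.compile(r"^@@ -(\d+)(?:,\d+)? \+\d+(?:,\d+)? @@")


def make_unified_diff(before: str, after: str) -> str:
    """Rebuild a clean unified diff between two script bodies."""
    return "".join(
        difflib.unified_diff(
            before.splitlines(keepends=True),
            after.splitlines(keepends=True),
            fromfile="a/script.py",
            tofile="b/script.py",
        )
    )


def apply_unified_diff(before: str, diff_text: str) -> str:
    """Best-effort hunk application; returns "" when nothing applies."""
    src = before.splitlines()
    out: list[str] = []
    cursor = 0  # next unconsumed line of src
    changed = False
    lines = diff_text.splitlines()
    i = 0
    while i < len(lines):
        m = _HUNK.match(lines[i])
        if not m:
            i += 1  # skip ---/+++ headers and any stray prose
            continue
        start = int(m.group(1)) - 1
        if start < cursor or start > len(src):
            return ""  # overlapping or out-of-range hunk
        out.extend(src[cursor:start])  # copy untouched lines verbatim
        cursor = start
        i += 1
        while i < len(lines) and not lines[i].startswith("@@"):
            tag, body = lines[i][:1], lines[i][1:]
            if tag == " ":
                out.append(body)
                cursor += 1
            elif tag == "-":
                cursor += 1  # drop the removed line
                changed = True
            elif tag == "+":
                out.append(body)
                changed = True
            i += 1
    out.extend(src[cursor:])
    return "\n".join(out) + "\n" if changed else ""

With helpers shaped like these, `_canonicalise` can apply a diff whose formatting the model got slightly wrong and re-derive a clean hunk set from the before/after pair.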
demo-space/test_heuristic.py ADDED
@@ -0,0 +1,99 @@
+"""Quick local sanity check for the heuristic repair fallback.
+
+Run with::
+
+    python demo-space/test_heuristic.py
+
+Each case must produce a non-empty fix description and a script that
+differs from the input.
+"""
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+REPO = Path(__file__).resolve().parent.parent
+sys.path.insert(0, str(REPO))
+sys.path.insert(0, str(REPO / "demo-space"))
+
+from app import _heuristic_repair  # noqa: E402
+
+CASES = [
+    {
+        "name": "AttributeError + Did you mean",
+        "script": (
+            "from transformers import Trainer, TrainingArguments\n"
+            "from datasets import load_dataset\n\n"
+            "ds = load_dataset('glue', 'sst2')\n"
+            "args = TrainingArguments(output_dir='out')\n"
+            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
+            "trainer.start_training()\n"
+        ),
+        "trace": (
+            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
+            "Did you mean: 'train'?"
+        ),
+        "expect_in_repaired": "trainer.train()",
+        "expect_not_in_repaired": "start_training",
+    },
+    {
+        "name": "ModuleNotFoundError submodule",
+        "script": (
+            "import torch.legacy as torch\n"
+            "x = torch.randn(2, 3)\n"
+            "print(x)\n"
+        ),
+        "trace": "ModuleNotFoundError: No module named 'torch.legacy'",
+        "expect_in_repaired": "import torch",
+        "expect_not_in_repaired": "torch.legacy",
+    },
+    {
+        "name": "TypeError + use ... instead",
+        "script": (
+            "from transformers import AutoTokenizer\n"
+            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
+            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
+            "print(out)\n"
+        ),
+        "trace": (
+            "TypeError: __call__() got an unexpected keyword argument "
+            "'pad_to_max_length' (use `padding=True` instead)."
+        ),
+        "expect_in_repaired": "padding=True",
+        "expect_not_in_repaired": "pad_to_max_length",
+    },
+]
+
+
+def run_one(case: dict) -> bool:
+    name = case["name"]
+    repaired, description = _heuristic_repair(case["script"], case["trace"])
+
+    ok_changed = repaired != case["script"]
+    ok_desc = bool(description)
+    ok_in = case["expect_in_repaired"] in repaired
+    ok_not = case["expect_not_in_repaired"] not in repaired
+
+    status = "PASS" if (ok_changed and ok_desc and ok_in and ok_not) else "FAIL"
+    print(f"[{status}] {name}")
+    print(f"  description: {description!r}")
+    print(f"  changed? {ok_changed}")
+    print(f"  '{case['expect_in_repaired']}' in repaired? {ok_in}")
+    print(f"  '{case['expect_not_in_repaired']}' NOT in repaired? {ok_not}")
+    if status == "FAIL":
+        print("  --- repaired script ---")
+        print(repaired)
+        print("  -----------------------")
+    return status == "PASS"
+
+
+def main() -> int:
+    results = [run_one(c) for c in CASES]
+    print()
+    n_pass = sum(results)
+    print(f"summary: {n_pass}/{len(results)} passed")
+    return 0 if all(results) else 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
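
The three cases map one-to-one onto heuristic patterns 1-3. For illustration, a hypothetical fourth case (not part of this commit) that would exercise the remaining branch, a TypeError with no "use ... instead" hint, where the heuristic drops the kwarg outright:

# Hypothetical extra case: exercises the "remove the kwarg entirely"
# branch of _heuristic_repair pattern 3 (no replacement hint in the trace).
CASES.append(
    {
        "name": "TypeError, no replacement hint",
        "script": (
            "from datasets import load_dataset\n"
            "ds = load_dataset('imdb', script_version='main')\n"
        ),
        "trace": (
            "TypeError: load_dataset() got an unexpected keyword argument "
            "'script_version'"
        ),
        "expect_in_repaired": "load_dataset('imdb')",
        "expect_not_in_repaired": "script_version",
    }
)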
scripts/jobs/train_repair_agent.py CHANGED
@@ -206,11 +206,32 @@ from forgeenv.training.plots import (  # noqa: E402
     plot_success_rate_by_category,
 )
 
-trainer_state = GRPO_DIR / "trainer_state.json"
+# TRL writes trainer_state.json under each checkpoint dir, not directly
+# at output_dir. Pick the latest checkpoint, fall back to output_dir.
+def _find_trainer_state(grpo_dir: Path) -> Optional[Path]:  # type: ignore[name-defined]
+    direct = grpo_dir / "trainer_state.json"
+    if direct.exists():
+        return direct
+    ckpts = sorted(
+        (p for p in grpo_dir.glob("checkpoint-*") if (p / "trainer_state.json").exists()),
+        key=lambda p: int(p.name.split("-")[-1]) if p.name.split("-")[-1].isdigit() else -1,
+    )
+    return (ckpts[-1] / "trainer_state.json") if ckpts else None
+
+
+from typing import Optional  # noqa: E402
+
+trainer_state = _find_trainer_state(GRPO_DIR)
+print(f"[job] trainer_state path: {trainer_state}", flush=True)
 training_rewards: list[float] = []
-if trainer_state.exists():
+if trainer_state is not None and trainer_state.exists():
     state = json.loads(trainer_state.read_text())
-    for log in state.get("log_history", []):
+    log_history = state.get("log_history", [])
+    print(f"[job] log_history rows: {len(log_history)}", flush=True)
+    if log_history:
+        sample_keys = sorted(set().union(*(log.keys() for log in log_history)))
+        print(f"[job] log keys present: {sample_keys}", flush=True)
+    for log in log_history:
         # TRL emits a few different reward keys depending on version;
         # try the most specific first, then fall back.
         candidates = [
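
The `candidates` list itself is truncated by the diff context window. The pattern being set up is a per-row key fallback; a sketch of what the loop body plausibly does, with the key names as assumptions (TRL has renamed its reward keys across versions, which is exactly why the "[job] log keys present" line above prints the real ones):

# Sketch of the truncated fallback; the exact key names are assumptions.
# Check the "[job] log keys present" output for what your TRL version emits.
candidates = [
    "rewards/mean",   # seen in newer TRL GRPO logging
    "reward",         # older TRL releases
    "train_reward",   # some wrapper configurations
]
for key in candidates:
    if key in log:
        training_rewards.append(float(log[key]))
        break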
scripts/submit_training_job.py CHANGED
@@ -87,17 +87,23 @@ def submit_job(
     base_model: str,
     timeout: str,
 ) -> JobInfo:
-    script_path = REPO_ROOT / "scripts" / "jobs" / "train_repair_agent.py"
-    script = script_path.read_text(encoding="utf-8")
+    # The training script lives in the published source repo. Pass its
+    # raw Hub URL — `run_uv_job` accepts a URL/path/command, not the
+    # script body itself.
+    script_url = (
+        f"https://huggingface.co/{user}/forgeenv-source/"
+        "resolve/main/scripts/jobs/train_repair_agent.py"
+    )
 
     job = api.run_uv_job(
-        script=script,
+        script=script_url,
         dependencies=[
             "huggingface_hub>=0.27",
             "requests",
         ],
         flavor=flavor,
         timeout=timeout,
+        namespace=user,
         env={
             "HF_USERNAME": user,
             "ENV_URL": f"https://{user}-forgeenv.hf.space",
@@ -114,29 +120,42 @@
     return job
 
 
-def tail_logs(api: HfApi, token: str, job_id: str) -> int:
+_TERMINAL_STAGES = {"COMPLETED", "FAILED", "CANCELLED", "ERROR", "DELETED"}
+
+
+def _stage_of(info) -> str:
+    status = getattr(info, "status", None)
+    if status is None:
+        return "UNKNOWN"
+    stage = getattr(status, "stage", None)
+    if stage is None:
+        return str(status)
+    return str(stage)
+
+
+def tail_logs(api: HfApi, token: str, job_id: str, namespace: str | None = None) -> int:
     print(f"\n[launcher] streaming logs for job {job_id} (Ctrl-C to stop tailing) ...\n", flush=True)
-    last_status = None
     try:
-        for line in api.fetch_job_logs(job_id=job_id, token=token):
+        for line in api.fetch_job_logs(job_id=job_id, namespace=namespace, token=token):
             print(line, flush=True)
     except KeyboardInterrupt:
         print("\n[launcher] log stream interrupted by user.", flush=True)
     except Exception as e:  # noqa: BLE001
         print(f"\n[launcher] log stream ended ({e}); polling status ...", flush=True)
 
+    last_stage: str | None = None
     while True:
-        info = api.inspect_job(job_id=job_id, token=token)
-        status = getattr(info, "status", None)
-        if status != last_status:
-            print(f"[launcher] status: {status}", flush=True)
-            last_status = status
-        if status in {"COMPLETED", "FAILED", "CANCELLED", "ERROR"}:
+        info = api.inspect_job(job_id=job_id, namespace=namespace, token=token)
+        stage = _stage_of(info)
+        if stage != last_stage:
+            print(f"[launcher] status: {stage}", flush=True)
+            last_stage = stage
+        if stage in _TERMINAL_STAGES:
             break
-        time.sleep(15)
+        time.sleep(20)
 
-    print(f"[launcher] final status: {last_status}", flush=True)
-    return 0 if last_status == "COMPLETED" else 1
+    print(f"[launcher] final status: {last_stage}", flush=True)
+    return 0 if last_stage == "COMPLETED" else 1
 
 
 def main() -> int:
@@ -176,7 +195,7 @@ def main() -> int:
 
     if args.no_tail:
         return 0
-    return tail_logs(api, token, job_id)
+    return tail_logs(api, token, job_id, namespace=args.user)
 
 
 if __name__ == "__main__":
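
`_stage_of` hedges over two result shapes from `inspect_job` (a bare status value vs. an object carrying a `.stage` attribute); the defensive `getattr` chain suggests both occur across huggingface_hub releases. A toy check of the shapes the helper accepts, using stand-in objects rather than real huggingface_hub types:

# Stand-in shapes only -- real job info objects come from huggingface_hub.
from types import SimpleNamespace

legacy = SimpleNamespace(status="RUNNING")                         # bare string
modern = SimpleNamespace(status=SimpleNamespace(stage="RUNNING"))  # object with .stage

assert _stage_of(legacy) == "RUNNING"
assert _stage_of(modern) == "RUNNING"
assert _stage_of(SimpleNamespace(status=None)) == "UNKNOWN"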
scripts/tail_training_job.py ADDED
@@ -0,0 +1,34 @@
+#!/usr/bin/env python
+"""Re-attach to an in-flight HF Jobs run and stream its logs.
+
+Usage::
+
+    $env:HF_TOKEN = "hf_..."
+    python scripts/tail_training_job.py 69ec88dfd70108f37acde39d
+"""
+from __future__ import annotations
+
+import os
+import sys
+
+from huggingface_hub import HfApi
+
+from submit_training_job import tail_logs  # type: ignore[import-not-found]
+
+
+def main() -> int:
+    if len(sys.argv) < 2:
+        print("usage: python scripts/tail_training_job.py <job_id> [namespace]", file=sys.stderr)
+        return 2
+    job_id = sys.argv[1]
+    namespace = sys.argv[2] if len(sys.argv) > 2 else "akhiilll"
+    token = os.environ.get("HF_TOKEN")
+    if not token:
+        print("ERROR: set HF_TOKEN in the environment first.", file=sys.stderr)
+        return 2
+    api = HfApi()
+    return tail_logs(api, token, job_id, namespace=namespace)
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
scripts/test_live_env.py ADDED
@@ -0,0 +1,76 @@
+"""Smoke-test the live ForgeEnv Space end-to-end via the OpenEnv client.
+
+Runs one full episode against the deployed Space:
+
+    reset()            -> drift-gen turn
+    step(DriftAction)  -> repair turn
+    step(RepairAction) -> reward + verifier breakdown
+
+This is the simplest possible "is the deployed env working?" check
+and a clean standalone artifact for the hackathon writeup/video.
+
+Usage::
+
+    python scripts/test_live_env.py
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+
+from openenv.core import GenericAction, GenericEnvClient
+
+ENV_URL = "https://akhiilll-forgeenv.hf.space"
+
+
+def _summary(result, label: str) -> None:
+    obs = result.observation if isinstance(result.observation, dict) else {}
+    print(f"\n=== {label} ===")
+    print(f"phase           : {obs.get('current_phase')}")
+    print(f"task_id         : {obs.get('task_id')}")
+    print(f"target_category : {obs.get('target_category')}")
+    print(f"reward          : {result.reward}")
+    print(f"done            : {result.done}")
+    breakdown = obs.get("reward_breakdown")
+    if breakdown:
+        print("reward_breakdown:")
+        print(json.dumps(breakdown, indent=2))
+    script = obs.get("script_content") or obs.get("broken_script") or ""
+    if script:
+        preview = script.splitlines()[:8]
+        print("script preview  :")
+        for line in preview:
+            print(f"  | {line}")
+        if len(script.splitlines()) > 8:
+            print("  | ...")
+
+
+async def main(seed: int = 42) -> None:
+    print(f"connecting to {ENV_URL} (seed={seed}) ...")
+    client = GenericEnvClient(base_url=ENV_URL)
+
+    res = await client.reset(seed=seed, options={"difficulty": "medium"})
+    _summary(res, "after reset()")
+    target = res.observation.get("target_category", "RenameApiCall") if isinstance(res.observation, dict) else "RenameApiCall"
+
+    res = await client.step(GenericAction(
+        breakage={"action_type": "breakage", "primitive_type": target, "params": {}},
+        repair=None,
+    ))
+    _summary(res, "after drift step (Challenger)")
+
+    # empty diff = no-op repair: shows the verifier marking the script as still broken
+    res = await client.step(GenericAction(
+        breakage=None,
+        repair={"action_type": "repair", "unified_diff": ""},
+    ))
+    _summary(res, "after repair step (Solver, no-op)")
+
+    print("\nOK -- reset + 2 steps round-trip the deployed env.")
+
+
+if __name__ == "__main__":
+    import sys
+
+    seed = int(sys.argv[1]) if len(sys.argv) > 1 else 42
+    asyncio.run(main(seed=seed))
scripts/test_repair_agent.py ADDED
@@ -0,0 +1,123 @@
+"""Smoke-test the trained Repair Agent locally on one episode.
+
+Loads the LoRA adapter pushed to ``akhiilll/forgeenv-repair-agent``, hits
+the live ForgeEnv Space for a fresh broken script, asks the model to
+emit a unified diff, applies it, and prints the verifier breakdown.
+
+Usage::
+
+    python scripts/test_repair_agent.py --seed 7
+    python scripts/test_repair_agent.py --seed 7 --base-model unsloth/Qwen2.5-Coder-1.5B-Instruct
+
+Requires GPU + transformers/peft. Skip this if you only want a quick
+demo -- use ``scripts/test_live_env.py`` or the Gradio Space instead.
+"""
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+
+from openenv.core import GenericAction, GenericEnvClient
+
+ENV_URL = "https://akhiilll-forgeenv.hf.space"
+LORA_REPO = "akhiilll/forgeenv-repair-agent"
+
+REPAIR_PROMPT = """\
+You are a senior ML engineer fixing a HuggingFace training script that just broke.
+Output ONLY a unified diff (`--- a/script.py` / `+++ b/script.py`) that fixes the
+breakage signaled by the error trace. No prose, no fences, no explanation.
+
+# Broken script
+```python
+{script}
+```
+
+# Error trace
+```
+{error}
+```
+
+# Diff
+"""
+
+
+async def fetch_broken_episode(seed: int):
+    client = GenericEnvClient(base_url=ENV_URL)
+    res = await client.reset(seed=seed, options={"difficulty": "medium"})
+    target = res.observation["target_category"]
+    res = await client.step(GenericAction(
+        breakage={"action_type": "breakage", "primitive_type": target, "params": {}},
+        repair=None,
+    ))
+    obs = res.observation
+    return client, obs.get("script_content") or obs.get("broken_script") or "", obs.get("error_trace", "")
+
+
+async def submit_repair(client: GenericEnvClient, diff: str):
+    res = await client.step(GenericAction(
+        breakage=None,
+        repair={"action_type": "repair", "unified_diff": diff},
+    ))
+    return res
+
+
+def generate_diff(base_model: str, lora_repo: str, prompt: str) -> str:
+    import torch
+    from peft import PeftModel
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    print(f"loading base model: {base_model}")
+    tok = AutoTokenizer.from_pretrained(base_model)
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    print(f"attaching LoRA: {lora_repo}")
+    model = PeftModel.from_pretrained(model, lora_repo)
+    model.eval()
+
+    inputs = tok(prompt, return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        out = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            do_sample=False,
+            temperature=0.0,
+            pad_token_id=tok.eos_token_id,
+        )
+    text = tok.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+    return text.strip()
+
+
+async def main(args) -> None:
+    print(f"--- pulling broken episode (seed={args.seed}) from {ENV_URL}")
+    client, broken_script, error_trace = await fetch_broken_episode(args.seed)
+    if not broken_script:
+        raise SystemExit("env returned empty script_content; pick a different seed")
+    print(f"broken script length: {len(broken_script)} chars")
+    print(f"error trace         : {(error_trace[:200] + '...') if len(error_trace) > 200 else error_trace}")
+
+    prompt = REPAIR_PROMPT.format(script=broken_script, error=error_trace or "<env did not surface a trace>")
+    diff = generate_diff(args.base_model, args.lora_repo, prompt)
+
+    print("\n=== model diff ===")
+    print(diff)
+
+    print("\n=== submitting diff to env ===")
+    res = await submit_repair(client, diff)
+    print(f"reward: {res.reward}  done: {res.done}")
+    breakdown = res.observation.get("reward_breakdown") if isinstance(res.observation, dict) else None
+    if breakdown:
+        print("reward_breakdown:")
+        print(json.dumps(breakdown, indent=2))
+
+
+if __name__ == "__main__":
+    p = argparse.ArgumentParser()
+    p.add_argument("--seed", type=int, default=7)
+    p.add_argument("--base-model", default="unsloth/Qwen2.5-Coder-1.5B-Instruct")
+    p.add_argument("--lora-repo", default=LORA_REPO)
+    args = p.parse_args()
+    asyncio.run(main(args))