"""Gradio demo Space for the ForgeEnv Repair Agent.
Three-tier repair pipeline so the demo always returns a useful result:
1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model
emits a diff that, when applied, actually changes the broken script,
we use it.
2. **Error-trace heuristic** — extracts the fix signal from the Python
traceback (Did you mean / unexpected kwarg / No module named) and
emits a clean canonical diff. Handles the most common drift patterns.
3. **Model reasoning hint** — if heuristic fails, surface the model's
natural-language reasoning (it usually explains the bug correctly even
when its diff syntax is broken) alongside a "no patch produced" note.
This separation means the demo is robust regardless of how well the
LoRA generalises on a given input — and it's honest about what each
component contributed.
"""
from __future__ import annotations
import json
import os
import re
import traceback
from typing import Optional
import gradio as gr
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent")
_TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift"
_DESCRIPTION = (
"Paste a broken HuggingFace training script and the error trace it "
"produced. The Repair Agent returns a minimal unified diff. The model "
"was trained inside [ForgeEnv](https://huggingface.co/spaces/"
"akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style "
"Challenger / Solver co-evolution. The agent is backed by a heuristic "
"fallback that parses error traces directly when the LoRA's diff is "
"malformed — keeps the demo robust on out-of-distribution inputs."
)
_EXAMPLES = [
[
(
"from transformers import Trainer, TrainingArguments\n"
"from datasets import load_dataset\n\n"
"ds = load_dataset('glue', 'sst2')\n"
"args = TrainingArguments(output_dir='out')\n"
"trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
"trainer.start_training()\n"
),
(
"AttributeError: 'Trainer' object has no attribute 'start_training'. "
"Did you mean: 'train'?"
),
],
[
(
"import torch.legacy as torch\n"
"x = torch.randn(2, 3)\n"
"print(x)\n"
),
"ModuleNotFoundError: No module named 'torch.legacy'",
],
[
(
"from transformers import AutoTokenizer\n"
"tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
"out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
"print(out)\n"
),
(
"TypeError: __call__() got an unexpected keyword argument "
"'pad_to_max_length' (use `padding=True` instead)."
),
],
]
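# The three examples above mirror the drift patterns the heuristic tier
# covers: a renamed method, a removed submodule, and a renamed keyword
# argument.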
_PROMPT_TEMPLATE = (
"You are an expert ML engineer who fixes broken HuggingFace training "
"scripts caused by library version drift.\n\n"
"Library versions: {versions}\n\n"
"Broken script:\n```python\n{script}\n```\n\n"
"Error trace:\n```\n{trace}\n```\n\n"
"Output ONLY a minimal unified diff (`--- a/script.py` / `+++ "
"b/script.py` headers, then hunks). No prose."
)
_model = None
_tokenizer = None
_load_error: Optional[str] = None
# ----------------------------------------------------------------- model io
def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool:
"""Cheap pre-check: pull adapter_config.json and compare base_model_name."""
try:
from huggingface_hub import hf_hub_download
cfg_path = hf_hub_download(
repo_id=adapter_repo,
filename="adapter_config.json",
token=os.environ.get("HF_TOKEN"),
)
with open(cfg_path) as f:
cfg = json.load(f)
adapter_base = (cfg.get("base_model_name_or_path") or "").lower()
# Match by family substring -- "qwen2.5-coder-7b" must be present in
# the base name, otherwise the adapter targets a different arch.
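        # e.g. BASE_MODEL "Qwen/Qwen2.5-Coder-7B-Instruct" yields the family
        # string "qwen2.5-coder-7b", which must appear in adapter_base.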
family = base_name.split("/")[-1].lower().replace("-instruct", "")
return family in adapter_base
except Exception as e: # noqa: BLE001
print(f"[demo] adapter_config check failed ({e}); attempting load anyway")
return True
def _load_model() -> None:
"""Lazy-load the trained LoRA on first GPU invocation."""
global _model, _tokenizer, _load_error
if _model is not None or _load_error is not None:
return
try:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
torch_dtype=torch.float16,
device_map="auto",
)
if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
try:
model = PeftModel.from_pretrained(base, ADAPTER_REPO)
print(f"[demo] LoRA attached: {ADAPTER_REPO}")
except Exception as e: # noqa: BLE001
print(f"[demo] adapter load failed ({e}); using base model")
model = base
else:
print(
f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
f"base; using {BASE_MODEL} alone until matching adapter ships"
)
model = base
_model = model.eval()
_tokenizer = tokenizer
except Exception as e: # noqa: BLE001
_load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
_SYSTEM_PROMPT = (
"You are an expert ML engineer who fixes broken HuggingFace training "
"scripts caused by library version drift. Output ONLY a unified diff."
)
def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
"""Greedy decode using the base model's chat template (Qwen ChatML)."""
import torch
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": prompt},
]
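    # Qwen instruct tokenizers ship a ChatML chat template; the except branch
    # below falls back to the raw prompt for tokenizers without one.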
try:
text = _tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception: # noqa: BLE001
text = prompt
inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
with torch.no_grad():
out = _model.generate(
**inputs,
max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decode; temperature is not needed
repetition_penalty=1.15,
pad_token_id=_tokenizer.eos_token_id,
)
completion = _tokenizer.decode(
out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
)
return completion.strip()
# -------------------------------------------------------- diff extraction
_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
def _extract_diff_block(raw: str) -> str:
"""Pull the *first* fenced diff out of the model's raw output."""
if not raw:
return ""
m = _FENCE_RE.search(raw)
if m:
return m.group(1).strip()
# otherwise grab from the first '---' / '+++' / '@@' onwards
for marker in ("--- ", "+++ ", "@@"):
idx = raw.find(marker)
if idx >= 0:
return raw[idx:].strip()
return ""
def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
"""Try to apply the diff. Returns True iff the result differs from input."""
if not diff_text:
return False
try:
from forgeenv.env.diff_utils import apply_unified_diff
repaired = apply_unified_diff(broken, diff_text)
return bool(repaired) and repaired.strip() != broken.strip()
except Exception: # noqa: BLE001
return False
def _canonicalise(broken: str, diff_text: str) -> str:
"""Apply diff -> rebuild a clean canonical unified diff."""
from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
repaired = apply_unified_diff(broken, diff_text)
if not repaired or repaired.strip() == broken.strip():
return ""
return make_unified_diff(broken, repaired)
def _extract_model_reasoning(raw: str) -> str:
"""Pull the natural-language reasoning out of the model's output (if any)."""
if not raw:
return ""
text = re.sub(_FENCE_RE, "", raw).strip()
text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip()
lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
sentences: list[str] = []
for ln in lines:
if ln.startswith(("---", "+++", "@@", "-", "+")):
continue
if len(ln) < 10:
continue
sentences.append(ln)
if len(sentences) >= 3:
break
return " ".join(sentences)
# ---------------------------------------------------- error-trace heuristic
_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE)
_NO_ATTR_RE = re.compile(
r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE
)
_NO_MODULE_RE = re.compile(
r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE
)
_BAD_KWARG_RE = re.compile(
r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE
)
_USE_INSTEAD_RE = re.compile(
r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE
)
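# Illustrative trace fragments each pattern is meant to catch:
#   "Did you mean: 'train'?"                          -> captures "train"
#   "No module named 'torch.legacy'"                  -> captures "torch.legacy"
#   "unexpected keyword argument 'pad_to_max_length'" -> captures "pad_to_max_length"
#   "use `padding=True` instead"                      -> captures "padding"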
def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
"""Produce a (repaired_script, fix_description) pair from the trace.
Patterns covered:
* AttributeError + "Did you mean: 'X'?" -> rename method
* AttributeError without hint -> remove the call (rarely useful)
* ModuleNotFoundError 'X.Y' -> drop the .Y submodule
* TypeError unexpected kwarg + 'use Y' -> swap kwarg
* TypeError unexpected kwarg, no hint -> drop the kwarg
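
    Example (illustrative, exercising the first pattern):

        >>> _heuristic_repair(
        ...     "trainer.start_training()",
        ...     "AttributeError: 'Trainer' object has no attribute "
        ...     "'start_training'. Did you mean: 'train'?",
        ... )[0]
        'trainer.train()'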
"""
if not error_trace:
return broken, ""
trace = error_trace.strip()
repaired = broken
description = ""
# 1. AttributeError 'X' + Did you mean 'Y'
if "AttributeError" in trace or "has no attribute" in trace:
old = _NO_ATTR_RE.search(trace)
new = _DID_YOU_MEAN_RE.search(trace)
if old and new and old.group(1) != new.group(1):
old_name, new_name = old.group(1), new.group(1)
pattern = re.compile(rf"\b{re.escape(old_name)}\b")
if pattern.search(repaired):
repaired = pattern.sub(new_name, repaired)
description = (
f"`{old_name}` is no longer an attribute on this object; "
f"renamed call to `{new_name}` per the traceback hint."
)
# 2. ModuleNotFoundError 'X.Y' (or 'X')
if not description and "No module named" in trace:
m = _NO_MODULE_RE.search(trace)
if m:
mod = m.group(1)
if "." in mod:
parent, child = mod.rsplit(".", 1)
pat_full = re.compile(rf"\b{re.escape(mod)}\b")
if pat_full.search(repaired):
repaired = pat_full.sub(parent, repaired)
description = (
f"`{mod}` was removed; replaced with parent module "
f"`{parent}`."
)
# 3. TypeError unexpected kwarg
if not description and "unexpected keyword argument" in trace:
bad = _BAD_KWARG_RE.search(trace)
good = _USE_INSTEAD_RE.search(trace)
if bad:
bad_kw = bad.group(1)
if good:
good_kw = good.group(1)
pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
if pat.search(repaired):
repaired = pat.sub(f"{good_kw}=", repaired)
                    # the original value is kept as-is, which is fine for
                    # boolean renames (pad_to_max_length=True -> padding=True)
description = (
f"`{bad_kw}` was renamed to `{good_kw}`; updated "
f"keyword to match the new API."
)
else:
# remove the kwarg entirely (best-effort)
pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
if pat.search(repaired):
repaired = pat.sub("", repaired)
description = (
f"`{bad_kw}` is no longer accepted; removed the "
f"keyword argument."
)
return repaired, description
# ------------------------------------------------------------- entry point
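# On Hugging Face Spaces with ZeroGPU hardware, `spaces.GPU` allocates a GPU
# for each decorated call; elsewhere the import fails and the no-op fallback
# keeps the demo runnable on CPU.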
try:
import spaces # type: ignore
_gpu_decorator = spaces.GPU(duration=60)
except Exception: # noqa: BLE001
def _gpu_decorator(fn):
return fn
@_gpu_decorator
def repair_script(script: str, error_trace: str) -> str:
if not script.strip():
return "# Paste a broken script first."
# Tier 1: trained LoRA
model_raw = ""
model_diff_canonical = ""
model_reasoning = ""
_load_model()
if _model is not None:
try:
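            # Library versions are pinned for the demo; the prompt template
            # just needs a JSON mapping of library name -> version string.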
versions = json.dumps(
{"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
)
prompt = _PROMPT_TEMPLATE.format(
versions=versions,
script=script,
trace=error_trace or "(no trace)",
)
            model_raw = _generate_with_model(prompt)
            model_diff_text = _extract_diff_block(model_raw)
            # Extract the reasoning unconditionally so tiers 2 and 3 can
            # surface it even when the diff turns out to be malformed.
            model_reasoning = _extract_model_reasoning(model_raw)
            if _diff_actually_changes_script(script, model_diff_text):
                model_diff_canonical = _canonicalise(script, model_diff_text)
except Exception as e: # noqa: BLE001
print(f"[demo] model generation failed: {e}")
if model_diff_canonical:
header = (
"# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
"# The model produced a valid diff that successfully patches the script.\n"
)
return header + "\n" + model_diff_canonical
# Tier 2: error-trace heuristic
repaired, description = _heuristic_repair(script, error_trace)
if description and repaired != script:
from forgeenv.env.diff_utils import make_unified_diff
diff = make_unified_diff(script, repaired)
header_lines = [
"# Source: error-trace heuristic (LoRA diff was malformed; "
"fell back to deterministic repair).",
f"# Fix: {description}",
]
if model_reasoning:
header_lines.append(f"# Trained model said: {model_reasoning}")
return "\n".join(header_lines) + "\n\n" + diff
# Tier 3: nothing worked -- surface what we know
msg_lines = ["# Could not produce a confident patch."]
if model_reasoning:
msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
if error_trace:
msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}")
msg_lines.append(
"# Try a more specific error trace (the heuristic looks for "
"'Did you mean', 'No module named', or 'unexpected keyword argument')."
)
return "\n".join(msg_lines)
# ----------------------------------------------------------------- gradio
with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
with gr.Row():
with gr.Column():
in_script = gr.Code(
label="Broken HuggingFace script",
language="python",
lines=22,
)
in_trace = gr.Textbox(
label="Error trace",
lines=6,
placeholder="Traceback...",
)
run_btn = gr.Button("Repair", variant="primary")
with gr.Column():
out_diff = gr.Code(
label="Suggested repair (unified diff)",
language="markdown",
lines=22,
)
gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace])
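    # Examples only prefill the inputs; generation (and GPU time) happens
    # only when the user clicks "Repair".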
run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff)
if __name__ == "__main__":
demo.launch()