| """Gradio demo Space for the ForgeEnv Repair Agent. |
| |
| Three-tier repair pipeline so the demo always returns a useful diff: |
| |
| 1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model |
| emits a diff that, when applied, actually changes the broken script, |
| we use it. |
| 2. **Error-trace heuristic** — extracts the fix signal from the Python |
| traceback (Did you mean / unexpected kwarg / No module named) and |
| emits a clean canonical diff. Handles the most common drift patterns. |
| 3. **Model reasoning hint** — if heuristic fails, surface the model's |
| natural-language reasoning (it usually explains the bug correctly even |
| when its diff syntax is broken) alongside a "no patch produced" note. |
| |
| This separation means the demo is robust regardless of how well the |
| LoRA generalises on a given input — and it's honest about what each |
| component contributed. |
| """ |
from __future__ import annotations

import json
import os
import re
import traceback
from typing import Optional

import gradio as gr

BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent")
|
|
| _TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift" |
| _DESCRIPTION = ( |
| "Paste a broken HuggingFace training script and the error trace it " |
| "produced. The Repair Agent returns a minimal unified diff. The model " |
| "was trained inside [ForgeEnv](https://huggingface.co/spaces/" |
| "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style " |
| "Challenger / Solver co-evolution. The agent is backed by a heuristic " |
| "fallback that parses error traces directly when the LoRA's diff is " |
| "malformed — keeps the demo robust on out-of-distribution inputs." |
| ) |
|
|
_EXAMPLES = [
    [
        (
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        (
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
    ],
    [
        (
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        "ModuleNotFoundError: No module named 'torch.legacy'",
    ],
    [
        (
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        (
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
    ],
]
|
|
_PROMPT_TEMPLATE = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift.\n\n"
    "Library versions: {versions}\n\n"
    "Broken script:\n```python\n{script}\n```\n\n"
    "Error trace:\n```\n{trace}\n```\n\n"
    "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ "
    "b/script.py` headers, then hunks). No prose."
)
|
|
_model = None
_tokenizer = None
_load_error: Optional[str] = None
|
|
|
|
def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool:
    """Cheap pre-check: pull adapter_config.json and compare base_model_name."""
    try:
        from huggingface_hub import hf_hub_download

        cfg_path = hf_hub_download(
            repo_id=adapter_repo,
            filename="adapter_config.json",
            token=os.environ.get("HF_TOKEN"),
        )
        with open(cfg_path) as f:
            cfg = json.load(f)
        adapter_base = (cfg.get("base_model_name_or_path") or "").lower()

        family = base_name.split("/")[-1].lower().replace("-instruct", "")
        return family in adapter_base
    except Exception as e:
        print(f"[demo] adapter_config check failed ({e}); attempting load anyway")
        return True
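
# Example (illustrative): with the default BASE_MODEL "Qwen/Qwen2.5-Coder-7B-Instruct",
# family == "qwen2.5-coder-7b", so any adapter whose base_model_name_or_path
# contains that substring (case-insensitively) passes the pre-check.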
|
|
|
|
def _load_model() -> None:
    """Lazy-load the trained LoRA on first GPU invocation."""
    global _model, _tokenizer, _load_error
    if _model is not None or _load_error is not None:
        return
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
            try:
                model = PeftModel.from_pretrained(base, ADAPTER_REPO)
                print(f"[demo] LoRA attached: {ADAPTER_REPO}")
            except Exception as e:
                print(f"[demo] adapter load failed ({e}); using base model")
                model = base
        else:
            print(
                f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
                f"base; using {BASE_MODEL} alone until a matching adapter ships"
            )
            model = base
        _model = model.eval()
        _tokenizer = tokenizer
    except Exception as e:
        _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
|
|
|
|
_SYSTEM_PROMPT = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift. Output ONLY a unified diff."
)
|
|
|
|
def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
    """Greedy decode using the base model's chat template (Qwen ChatML)."""
    import torch

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        text = _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        text = prompt
    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decode; a temperature would be ignored here
            repetition_penalty=1.15,
            pad_token_id=_tokenizer.eos_token_id,
        )
    completion = _tokenizer.decode(
        out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return completion.strip()
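
# For reference, with add_generation_prompt=True Qwen's ChatML template renders
# the conversation roughly as (sketch, assuming the stock Qwen2.5 template):
#   <|im_start|>system\n{_SYSTEM_PROMPT}<|im_end|>\n
#   <|im_start|>user\n{prompt}<|im_end|>\n
#   <|im_start|>assistant\n
# and generation continues from the open assistant turn.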
|
|
|
|
_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)  # currently unused in this module
|
|
|
|
def _extract_diff_block(raw: str) -> str:
    """Pull the *first* fenced diff out of the model's raw output."""
    if not raw:
        return ""
    m = _FENCE_RE.search(raw)
    if m:
        return m.group(1).strip()
    # No fence found: fall back to the first bare unified-diff marker.
    for marker in ("--- ", "+++ ", "@@"):
        idx = raw.find(marker)
        if idx >= 0:
            return raw[idx:].strip()
    return ""
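
# Example (illustrative): both of these return the same diff text,
#   _extract_diff_block("```diff\n--- a/script.py\n+++ b/script.py\n@@ ... ```")
#   _extract_diff_block("Here is the fix:\n--- a/script.py\n+++ b/script.py\n@@ ...")
# the first via the fence regex, the second via the bare-marker scan.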
|
|
|
|
def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
    """Try to apply the diff. Returns True iff the result differs from input."""
    if not diff_text:
        return False
    try:
        from forgeenv.env.diff_utils import apply_unified_diff

        repaired = apply_unified_diff(broken, diff_text)
        return bool(repaired) and repaired.strip() != broken.strip()
    except Exception:
        return False
|
|
|
|
def _canonicalise(broken: str, diff_text: str) -> str:
    """Apply diff -> rebuild a clean canonical unified diff."""
    from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff

    repaired = apply_unified_diff(broken, diff_text)
    if not repaired or repaired.strip() == broken.strip():
        return ""
    return make_unified_diff(broken, repaired)
|
|
|
|
def _extract_model_reasoning(raw: str) -> str:
    """Pull the natural-language reasoning out of the model's output (if any)."""
    if not raw:
        return ""
    # Drop fenced diff blocks, then keep the first few prose lines, skipping
    # any bare diff lines (---, +++, @@, -/+ hunks) that escaped the fences.
    text = _FENCE_RE.sub("", raw).strip()
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    sentences: list[str] = []
    for ln in lines:
        if ln.startswith(("---", "+++", "@@", "-", "+")):
            continue
        if len(ln) < 10:
            continue
        sentences.append(ln)
        if len(sentences) >= 3:
            break
    return " ".join(sentences)
|
|
|
|
_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE)
_NO_ATTR_RE = re.compile(
    r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE
)
_NO_MODULE_RE = re.compile(
    r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE
)
_BAD_KWARG_RE = re.compile(
    r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE
)
_USE_INSTEAD_RE = re.compile(
    r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE
)
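
# What each pattern captures (illustrative trace fragments):
#   _DID_YOU_MEAN_RE: "Did you mean: 'train'?"                          -> "train"
#   _NO_ATTR_RE:      "has no attribute 'start_training'"               -> "start_training"
#   _NO_MODULE_RE:    "No module named 'torch.legacy'"                  -> "torch.legacy"
#   _BAD_KWARG_RE:    "unexpected keyword argument 'pad_to_max_length'" -> "pad_to_max_length"
#   _USE_INSTEAD_RE:  "use `padding=True` instead"                      -> "padding"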
|
|
|
|
def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
    """Produce a (repaired_script, fix_description) pair from the trace.

    Patterns covered:
    * AttributeError + "Did you mean: 'X'?" -> rename method
    * AttributeError without a hint -> left unchanged (no safe fix)
    * ModuleNotFoundError 'X.Y' -> drop the .Y submodule
    * TypeError unexpected kwarg + 'use Y' -> swap kwarg
    * TypeError unexpected kwarg, no hint -> drop the kwarg
    """
    if not error_trace:
        return broken, ""
    trace = error_trace.strip()
    repaired = broken
    description = ""

    # Renamed attribute/method: trust the traceback's "Did you mean" hint.
    if "AttributeError" in trace or "has no attribute" in trace:
        old = _NO_ATTR_RE.search(trace)
        new = _DID_YOU_MEAN_RE.search(trace)
        if old and new and old.group(1) != new.group(1):
            old_name, new_name = old.group(1), new.group(1)
            pattern = re.compile(rf"\b{re.escape(old_name)}\b")
            if pattern.search(repaired):
                repaired = pattern.sub(new_name, repaired)
                description = (
                    f"`{old_name}` is no longer an attribute on this object; "
                    f"renamed call to `{new_name}` per the traceback hint."
                )

    # Removed submodule: fall back to the parent module.
    if not description and "No module named" in trace:
        m = _NO_MODULE_RE.search(trace)
        if m:
            mod = m.group(1)
            if "." in mod:
                parent = mod.rsplit(".", 1)[0]
                pat_full = re.compile(rf"\b{re.escape(mod)}\b")
                if pat_full.search(repaired):
                    repaired = pat_full.sub(parent, repaired)
                    description = (
                        f"`{mod}` was removed; replaced with parent module "
                        f"`{parent}`."
                    )

    # Renamed or removed keyword argument.
    if not description and "unexpected keyword argument" in trace:
        bad = _BAD_KWARG_RE.search(trace)
        good = _USE_INSTEAD_RE.search(trace)
        if bad:
            bad_kw = bad.group(1)
            if good:
                good_kw = good.group(1)
                pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
                if pat.search(repaired):
                    repaired = pat.sub(f"{good_kw}=", repaired)
                    description = (
                        f"`{bad_kw}` was renamed to `{good_kw}`; updated "
                        f"keyword to match the new API."
                    )
            else:
                # No replacement suggested: drop the offending kwarg entirely.
                pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
                if pat.search(repaired):
                    repaired = pat.sub("", repaired)
                    description = (
                        f"`{bad_kw}` is no longer accepted; removed the "
                        f"keyword argument."
                    )

    return repaired, description
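
# Worked example (mirrors the third demo example): given the line
#   out = tok(['hello world'], pad_to_max_length=True, truncate=True)
# and the trace "TypeError: ... unexpected keyword argument 'pad_to_max_length'
# (use `padding=True` instead).", the kwarg branch rewrites the call to
#   out = tok(['hello world'], padding=True, truncate=True)
# touching only the kwarg named in the trace.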
|
|
|
|
try:
    import spaces

    # Hugging Face Spaces ZeroGPU decorator; budget 60s of GPU per call.
    _gpu_decorator = spaces.GPU(duration=60)
except Exception:
    # Not running on Spaces hardware: make the decorator a no-op.
    def _gpu_decorator(fn):
        return fn
|
|
|
|
@_gpu_decorator
def repair_script(script: str, error_trace: str) -> str:
    if not script.strip():
        return "# Paste a broken script first."

    # Tier 1: ask the trained LoRA for a diff and only trust it if it applies.
    model_raw = ""
    model_diff_canonical = ""
    model_reasoning = ""

    _load_model()
    if _model is not None:
        try:
            # Representative library versions surfaced to the model as context.
            versions = json.dumps(
                {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
            )
            prompt = _PROMPT_TEMPLATE.format(
                versions=versions,
                script=script,
                trace=error_trace or "(no trace)",
            )
            model_raw = _generate_with_model(prompt)
            model_diff_text = _extract_diff_block(model_raw)
            if _diff_actually_changes_script(script, model_diff_text):
                model_diff_canonical = _canonicalise(script, model_diff_text)
            # Extract reasoning unconditionally so tiers 2 and 3 can surface it
            # even when the model's diff was unusable.
            model_reasoning = _extract_model_reasoning(model_raw)
        except Exception as e:
            print(f"[demo] model generation failed: {e}")

    if model_diff_canonical:
        header = (
            "# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
            "# The model produced a valid diff that successfully patches the script.\n"
        )
        return header + "\n" + model_diff_canonical

    # Tier 2: deterministic repair parsed straight from the error trace.
    repaired, description = _heuristic_repair(script, error_trace)
    if description and repaired != script:
        from forgeenv.env.diff_utils import make_unified_diff

        diff = make_unified_diff(script, repaired)
        header_lines = [
            "# Source: error-trace heuristic (no usable LoRA diff; "
            "fell back to deterministic repair).",
            f"# Fix: {description}",
        ]
        if model_reasoning:
            header_lines.append(f"# Trained model said: {model_reasoning}")
        return "\n".join(header_lines) + "\n\n" + diff

    # Tier 3: no confident patch; surface the model's reasoning plus a hint.
    msg_lines = ["# Could not produce a confident patch."]
    if model_reasoning:
        msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
    if error_trace and error_trace.strip():
        msg_lines.append(
            f"# Error trace summary: {error_trace.strip().splitlines()[-1]}"
        )
    msg_lines.append(
        "# Try a more specific error trace (the heuristic looks for "
        "'Did you mean', 'No module named', or 'unexpected keyword argument')."
    )
    return "\n".join(msg_lines)
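
# Illustrative tier-2 response for the first demo example (header lines plus a
# unified diff; exact hunk text depends on forgeenv.env.diff_utils):
#   # Source: error-trace heuristic (no usable LoRA diff; fell back to deterministic repair).
#   # Fix: `start_training` is no longer an attribute on this object; renamed call to `train` per the traceback hint.
#   --- a/script.py
#   +++ b/script.py
#   @@ ... @@
#   -trainer.start_training()
#   +trainer.train()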
|
|
|
|
with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
    gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            in_script = gr.Code(
                label="Broken HuggingFace script",
                language="python",
                lines=22,
            )
            in_trace = gr.Textbox(
                label="Error trace",
                lines=6,
                placeholder="Traceback...",
            )
            run_btn = gr.Button("Repair", variant="primary")
        with gr.Column():
            out_diff = gr.Code(
                label="Suggested repair (unified diff)",
                language="markdown",
                lines=22,
            )

    gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace])
    run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff)
|
|
|
|
if __name__ == "__main__":
    demo.launch()
|
|