"""Gradio demo Space for the ForgeEnv Repair Agent. Three-tier repair pipeline so the demo always returns a useful diff: 1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model emits a diff that, when applied, actually changes the broken script, we use it. 2. **Error-trace heuristic** — extracts the fix signal from the Python traceback (Did you mean / unexpected kwarg / No module named) and emits a clean canonical diff. Handles the most common drift patterns. 3. **Model reasoning hint** — if heuristic fails, surface the model's natural-language reasoning (it usually explains the bug correctly even when its diff syntax is broken) alongside a "no patch produced" note. This separation means the demo is robust regardless of how well the LoRA generalises on a given input — and it's honest about what each component contributed. """ from __future__ import annotations import json import os import re import traceback from typing import Optional import gradio as gr BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct") ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent") ENABLE_MODEL = os.environ.get("ENABLE_MODEL", "0").strip().lower() in {"1", "true", "yes", "y"} _TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift" _DESCRIPTION = ( "Paste a broken HuggingFace training script and the error trace it " "produced. The Repair Agent returns a minimal unified diff. The model " "was trained inside [ForgeEnv](https://huggingface.co/spaces/" "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style " "Challenger / Solver co-evolution. The agent is backed by a heuristic " "fallback that parses error traces directly when the LoRA's diff is " "malformed — keeps the demo robust on out-of-distribution inputs." ) _EXAMPLES = [ [ ( "from transformers import Trainer, TrainingArguments\n" "from datasets import load_dataset\n\n" "ds = load_dataset('glue', 'sst2')\n" "args = TrainingArguments(output_dir='out')\n" "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n" "trainer.start_training()\n" ), ( "AttributeError: 'Trainer' object has no attribute 'start_training'. " "Did you mean: 'train'?" ), ], [ ( "import torch.legacy as torch\n" "x = torch.randn(2, 3)\n" "print(x)\n" ), "ModuleNotFoundError: No module named 'torch.legacy'", ], [ ( "from transformers import AutoTokenizer\n" "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n" "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n" "print(out)\n" ), ( "TypeError: __call__() got an unexpected keyword argument " "'pad_to_max_length' (use `padding=True` instead)." ), ], ] _PROMPT_TEMPLATE = ( "You are an expert ML engineer who fixes broken HuggingFace training " "scripts caused by library version drift.\n\n" "Library versions: {versions}\n\n" "Broken script:\n```python\n{script}\n```\n\n" "Error trace:\n```\n{trace}\n```\n\n" "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ " "b/script.py` headers, then hunks). No prose." ) _model = None _tokenizer = None _load_error: Optional[str] = None # Minimal diff helpers (avoid requiring forgeenv in the Space runtime). 

# Minimal diff helpers (avoid requiring forgeenv in the Space runtime).


def _make_unified_diff(before: str, after: str, path: str = "script.py") -> str:
    import difflib

    diff = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    return "".join(diff)


def _apply_unified_diff(broken_script: str, diff_text: str) -> str:
    # permissive applier (copied from forgeenv.env.diff_utils, simplified)
    diff_text = diff_text or ""
    if not diff_text.strip():
        return broken_script

    lines = diff_text.lstrip().splitlines()
    if lines and not any(line.startswith(("---", "+++", "@@")) for line in lines[:5]):
        head = "\n".join(lines[:30])
        markers = ("import ", "from ", "def ", "class ", "print(")
        if sum(1 for m in markers if m in head) >= 2:
            return diff_text  # full replacement script

    # strict-ish apply: find each hunk's old block and replace with the new block
    src_lines = broken_script.splitlines(keepends=True)
    out: list[str] = []
    diff_lines = diff_text.splitlines()
    i = 0
    src_idx = 0
    in_hunk = False
    hunk_old: list[str] = []
    hunk_new: list[str] = []

    def _flush() -> bool:
        nonlocal src_idx, hunk_old, hunk_new
        if not hunk_old and not hunk_new:
            return True
        target = "".join(hunk_old)
        remainder = "".join(src_lines[src_idx:])
        pos = remainder.find(target)
        if pos == -1:
            return False
        out.append(remainder[:pos])
        out.append("".join(hunk_new))
        consumed = remainder[: pos + len(target)]
        src_idx += len(consumed.splitlines(keepends=True))
        hunk_old, hunk_new = [], []
        return True

    while i < len(diff_lines):
        line = diff_lines[i]
        if line.startswith(("---", "+++")):
            i += 1
            continue
        if line.startswith("@@"):
            if in_hunk and not _flush():
                break
            in_hunk = True
            i += 1
            continue
        if in_hunk:
            if line.startswith("+"):
                hunk_new.append(line[1:] + "\n")
            elif line.startswith("-"):
                hunk_old.append(line[1:] + "\n")
            else:
                ctx = line[1:] if line.startswith(" ") else line
                hunk_old.append(ctx + "\n")
                hunk_new.append(ctx + "\n")
        i += 1

    if in_hunk:
        _flush()
    out.append("".join(src_lines[src_idx:]))
    candidate = "".join(out)

    # fallback: (-, +) line-pair replacement
    if candidate.strip() == broken_script.strip():
        repaired = broken_script
        pending_minus: str | None = None
        for line in diff_lines:
            if line.startswith(("---", "+++", "@@")):
                pending_minus = None
                continue
            if line.startswith("-"):
                pending_minus = line[1:].strip()
            elif line.startswith("+") and pending_minus is not None:
                old, new = pending_minus, line[1:].strip()
                if old and old in repaired:
                    repaired = repaired.replace(old, new, 1)
                pending_minus = None
        return repaired
    return candidate
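
# A minimal sanity sketch for the two helpers above. The function name is
# introduced here purely for illustration; nothing in the app calls it.
def _diff_helpers_roundtrip_example() -> bool:
    """Round-trip: a diff built by `_make_unified_diff` should re-apply
    cleanly with `_apply_unified_diff`.

    >>> _diff_helpers_roundtrip_example()
    True
    """
    before = "args = TrainingArguments(output_dir='out')\ntrainer.start_training()\n"
    after = "args = TrainingArguments(output_dir='out')\ntrainer.train()\n"
    diff = _make_unified_diff(before, after)
    return _apply_unified_diff(before, diff) == after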

# ----------------------------------------------------------------- model io


def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool:
    """Cheap pre-check: pull adapter_config.json and compare base_model_name."""
    try:
        from huggingface_hub import hf_hub_download

        cfg_path = hf_hub_download(
            repo_id=adapter_repo,
            filename="adapter_config.json",
            token=os.environ.get("HF_TOKEN"),
        )
        with open(cfg_path) as f:
            cfg = json.load(f)
        adapter_base = (cfg.get("base_model_name_or_path") or "").lower()
        # Match by family substring -- "qwen2.5-coder-7b" must be present in
        # the base name, otherwise the adapter targets a different arch.
        family = base_name.split("/")[-1].lower().replace("-instruct", "")
        return family in adapter_base
    except Exception as e:  # noqa: BLE001
        print(f"[demo] adapter_config check failed ({e}); attempting load anyway")
        return True


def _load_model() -> None:
    """Lazy-load the trained LoRA on first GPU invocation."""
    global _model, _tokenizer, _load_error
    if not ENABLE_MODEL:
        _load_error = "Model disabled (ENABLE_MODEL=0). Using heuristic-only demo."
        return
    if _model is not None or _load_error is not None:
        return
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token

        # ZeroGPU sometimes initialises models on the "meta" device when using
        # `device_map="auto"`. Load onto a real device explicitly.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=False,
        ).to(device)

        if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
            try:
                model = PeftModel.from_pretrained(base, ADAPTER_REPO).to(device)
                print(f"[demo] LoRA attached: {ADAPTER_REPO}")
            except Exception as e:  # noqa: BLE001
                print(f"[demo] adapter load failed ({e}); using base model")
                model = base
        else:
            print(
                f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
                f"base; using {BASE_MODEL} alone until a matching adapter ships"
            )
            model = base

        _model = model.eval()
        _tokenizer = tokenizer
    except Exception as e:  # noqa: BLE001
        _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"


_SYSTEM_PROMPT = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift. Output ONLY a unified diff."
)


def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
    """Greedy decode using the base model's chat template (Qwen ChatML)."""
    import torch

    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        text = _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:  # noqa: BLE001
        text = prompt

    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    # Greedy decode; `temperature` is deliberately omitted because transformers
    # warns when it is set alongside `do_sample=False`.
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=1.15,
            pad_token_id=_tokenizer.eos_token_id,
        )
    completion = _tokenizer.decode(
        out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return completion.strip()


# -------------------------------------------------------- diff extraction

_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)


def _extract_diff_block(raw: str) -> str:
    """Pull the *first* fenced diff out of the model's raw output."""
    if not raw:
        return ""
    m = _FENCE_RE.search(raw)
    if m:
        return m.group(1).strip()
    # otherwise grab from the first '---' / '+++' / '@@' onwards
    for marker in ("--- ", "+++ ", "@@"):
        idx = raw.find(marker)
        if idx >= 0:
            return raw[idx:].strip()
    return ""


def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
    """Try to apply the diff; return True iff the result differs from the input."""
    if not diff_text:
        return False
    try:
        repaired = _apply_unified_diff(broken, diff_text)
        return bool(repaired) and repaired.strip() != broken.strip()
    except Exception:  # noqa: BLE001
        return False
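
# Example of the extraction helpers above on a typical fenced completion
# (the raw string is hypothetical, not a recorded model output):
#
#     >>> raw = "Fix below:\n```diff\n--- a/script.py\n+++ b/script.py\n```"
#     >>> _extract_diff_block(raw)
#     '--- a/script.py\n+++ b/script.py'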


def _canonicalise(broken: str, diff_text: str) -> str:
    """Apply diff -> rebuild a clean canonical unified diff."""
    # Use the local helpers rather than importing forgeenv.env.diff_utils, so
    # the Space runtime does not need the forgeenv package installed.
    repaired = _apply_unified_diff(broken, diff_text)
    if not repaired or repaired.strip() == broken.strip():
        return ""
    return _make_unified_diff(broken, repaired)


def _extract_model_reasoning(raw: str) -> str:
    """Pull the natural-language reasoning out of the model's output (if any)."""
    if not raw:
        return ""
    text = re.sub(_FENCE_RE, "", raw).strip()
    text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip()
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    sentences: list[str] = []
    for ln in lines:
        if ln.startswith(("---", "+++", "@@", "-", "+")):
            continue
        if len(ln) < 10:
            continue
        sentences.append(ln)
        if len(sentences) >= 3:
            break
    return " ".join(sentences)


# ---------------------------------------------------- error-trace heuristic

_DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE)
_NO_ATTR_RE = re.compile(
    r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE
)
_NO_MODULE_RE = re.compile(
    r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE
)
_BAD_KWARG_RE = re.compile(
    r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE
)
_USE_INSTEAD_RE = re.compile(
    r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE
)
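
# Quick illustration of the trace patterns above (sample messages assumed,
# mirroring the bundled examples):
#
#     >>> _DID_YOU_MEAN_RE.search("Did you mean: 'train'?").group(1)
#     'train'
#     >>> _NO_MODULE_RE.search("No module named 'torch.legacy'").group(1)
#     'torch.legacy'
#     >>> _BAD_KWARG_RE.search("unexpected keyword argument 'pad_to_max_length'").group(1)
#     'pad_to_max_length'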


def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
    """Produce a (repaired_script, fix_description) pair from the trace.

    Patterns covered:

    * AttributeError + "Did you mean: 'X'?" -> rename method
    * AttributeError without a hint         -> no repair (falls through to tier 3)
    * ModuleNotFoundError 'X.Y'             -> drop the .Y submodule
    * TypeError unexpected kwarg + 'use Y'  -> swap kwarg
    * TypeError unexpected kwarg, no hint   -> drop the kwarg
    """
    if not error_trace:
        return broken, ""
    trace = error_trace.strip()
    repaired = broken
    description = ""

    # 1. AttributeError 'X' + Did you mean 'Y'
    if "AttributeError" in trace or "has no attribute" in trace:
        old = _NO_ATTR_RE.search(trace)
        new = _DID_YOU_MEAN_RE.search(trace)
        if old and new and old.group(1) != new.group(1):
            old_name, new_name = old.group(1), new.group(1)
            pattern = re.compile(rf"\b{re.escape(old_name)}\b")
            if pattern.search(repaired):
                repaired = pattern.sub(new_name, repaired)
                description = (
                    f"`{old_name}` is no longer an attribute on this object; "
                    f"renamed call to `{new_name}` per the traceback hint."
                )

    # 2. ModuleNotFoundError 'X.Y' (or 'X')
    if not description and "No module named" in trace:
        m = _NO_MODULE_RE.search(trace)
        if m:
            mod = m.group(1)
            if "." in mod:
                parent, child = mod.rsplit(".", 1)
                pat_full = re.compile(rf"\b{re.escape(mod)}\b")
                if pat_full.search(repaired):
                    repaired = pat_full.sub(parent, repaired)
                    description = (
                        f"`{mod}` was removed; replaced with parent module "
                        f"`{parent}`."
                    )

    # 3. TypeError unexpected kwarg
    if not description and "unexpected keyword argument" in trace:
        bad = _BAD_KWARG_RE.search(trace)
        good = _USE_INSTEAD_RE.search(trace)
        if bad:
            bad_kw = bad.group(1)
            if good:
                good_kw = good.group(1)
                pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
                if pat.search(repaired):
                    repaired = pat.sub(f"{good_kw}=", repaired)
                    # if the old kwarg was boolean-ish, the value carries over
                    # (pad_to_max_length=True -> padding=True is fine)
                    description = (
                        f"`{bad_kw}` was renamed to `{good_kw}`; updated "
                        f"keyword to match the new API."
                    )
            else:
                # remove the kwarg entirely (best-effort)
                pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
                if pat.search(repaired):
                    repaired = pat.sub("", repaired)
                    description = (
                        f"`{bad_kw}` is no longer accepted; removed the "
                        f"keyword argument."
                    )

    return repaired, description


# ------------------------------------------------------------- entry point

try:
    import spaces  # type: ignore

    _gpu_decorator = spaces.GPU(duration=60)
except Exception:  # noqa: BLE001

    def _gpu_decorator(fn):
        return fn


@_gpu_decorator
def repair_script(script: str, error_trace: str) -> str:
    """Three-tier repair: heuristic first, then optional LoRA, then diagnostics."""
    if not script.strip():
        return "# Paste a broken script first."

    # Tier 1: error-trace heuristic (fast + reliable; keeps Space memory low)
    repaired, description = _heuristic_repair(script, error_trace)
    if description and repaired != script:
        diff = _make_unified_diff(script, repaired)
        header_lines = [
            "# Source: error-trace heuristic (deterministic repair).",
            f"# Fix: {description}",
        ]
        return "\n".join(header_lines) + "\n\n" + diff

    # Tier 2: trained LoRA (optional; disabled by default on 16Gi Spaces)
    model_raw = ""
    model_diff_canonical = ""
    model_reasoning = ""
    if ENABLE_MODEL:
        _load_model()
        if _model is not None and _tokenizer is not None:
            try:
                versions = json.dumps(
                    {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
                )
                prompt = _PROMPT_TEMPLATE.format(
                    versions=versions,
                    script=script,
                    trace=error_trace or "(no trace)",
                )
                model_raw = _generate_with_model(prompt)
                model_diff_text = _extract_diff_block(model_raw)
                if _diff_actually_changes_script(script, model_diff_text):
                    model_diff_canonical = _canonicalise(script, model_diff_text)
                # keep the reasoning even when no usable diff emerged (tier 3)
                model_reasoning = _extract_model_reasoning(model_raw)
            except Exception as e:  # noqa: BLE001
                print(f"[demo] model generation failed: {e}")

    if model_diff_canonical:
        header = (
            "# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
            "# The model produced a valid diff that successfully patches the script.\n"
        )
        return header + "\n" + model_diff_canonical

    # Tier 3: nothing worked -- surface what we know
    msg_lines = ["# Could not produce a confident patch."]
    if not ENABLE_MODEL:
        # _load_model() never runs when the model is disabled, so report the
        # disabled state directly rather than relying on _load_error being set.
        msg_lines.append("# Note: model disabled (ENABLE_MODEL=0); heuristic-only demo.")
    elif _load_error:
        msg_lines.append(f"# Note: {_load_error.splitlines()[0]}")
    if model_reasoning:
        msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
    if error_trace:
        msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}")
    msg_lines.append(
        "# Try a more specific error trace (the heuristic looks for "
        "'Did you mean', 'No module named', or 'unexpected keyword argument')."
    )
    return "\n".join(msg_lines)
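
# Illustrative REPL usage (not executed at import; output shape only, the
# exact hunk depends on the heuristic):
#
#     >>> print(repair_script(*_EXAMPLES[0]))  # doctest: +SKIP
#     # Source: error-trace heuristic (deterministic repair).
#     # Fix: `start_training` is no longer an attribute on this object; ...
#     ...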

# ----------------------------------------------------------------- gradio

with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
    gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            in_script = gr.Code(
                label="Broken HuggingFace script",
                language="python",
                lines=22,
            )
            in_trace = gr.Textbox(
                label="Error trace",
                lines=6,
                placeholder="Traceback...",
            )
            run_btn = gr.Button("Repair", variant="primary")
        with gr.Column():
            out_diff = gr.Code(
                label="Suggested repair (unified diff)",
                language="markdown",
                lines=22,
            )
    gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace])
    run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff)


if __name__ == "__main__":
    demo.launch()