# NOTE(review): removed stray "Spaces: Sleeping" HF page residue — it is not
# valid Python and would break the module before the docstring.
| """Gradio demo Space for the ForgeEnv Repair Agent. | |
| Three-tier repair pipeline so the demo always returns a useful diff: | |
| 1. **Trained LoRA model** — Qwen 2.5 + ForgeEnv GRPO adapter. If the model | |
| emits a diff that, when applied, actually changes the broken script, | |
| we use it. | |
| 2. **Error-trace heuristic** — extracts the fix signal from the Python | |
| traceback (Did you mean / unexpected kwarg / No module named) and | |
| emits a clean canonical diff. Handles the most common drift patterns. | |
| 3. **Model reasoning hint** — if heuristic fails, surface the model's | |
| natural-language reasoning (it usually explains the bug correctly even | |
| when its diff syntax is broken) alongside a "no patch produced" note. | |
| This separation means the demo is robust regardless of how well the | |
| LoRA generalises on a given input — and it's honest about what each | |
| component contributed. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import traceback | |
| from typing import Optional | |
| import gradio as gr | |
# Base checkpoint the GRPO LoRA adapter targets; overridable via env var.
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-Coder-7B-Instruct")
# Hub repo hosting the trained ForgeEnv adapter weights.
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "akhiilll/forgeenv-repair-agent")
# The model tier is opt-in (heavy download): any of 1/true/yes/y enables it,
# case-insensitive; default is heuristic-only mode.
ENABLE_MODEL = os.environ.get("ENABLE_MODEL", "0").strip().lower() in {"1", "true", "yes", "y"}

# Page copy rendered into the Gradio Markdown header below.
_TITLE = "ForgeEnv Repair Agent — fix HuggingFace scripts under library drift"
_DESCRIPTION = (
    "Paste a broken HuggingFace training script and the error trace it "
    "produced. The Repair Agent returns a minimal unified diff. The model "
    "was trained inside [ForgeEnv](https://huggingface.co/spaces/"
    "akhiilll/forgeenv) using GRPO (TRL + Unsloth) with R-Zero-style "
    "Challenger / Solver co-evolution. The agent is backed by a heuristic "
    "fallback that parses error traces directly when the LoRA's diff is "
    "malformed — keeps the demo robust on out-of-distribution inputs."
)
# Demo inputs wired to gr.Examples below. Each entry is
# [broken_script, error_trace] — one per drift pattern the heuristic handles.
_EXAMPLES = [
    # AttributeError with a "Did you mean" hint -> method rename.
    [
        (
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        (
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
    ],
    # ModuleNotFoundError on a removed submodule -> fall back to the parent.
    [
        (
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        "ModuleNotFoundError: No module named 'torch.legacy'",
    ],
    # TypeError on a renamed kwarg with a "use ... instead" hint -> swap kwarg.
    [
        (
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        (
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
    ],
]
# User-turn prompt for the LoRA tier; {versions} / {script} / {trace} are
# filled in by repair_script().
_PROMPT_TEMPLATE = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift.\n\n"
    "Library versions: {versions}\n\n"
    "Broken script:\n```python\n{script}\n```\n\n"
    "Error trace:\n```\n{trace}\n```\n\n"
    "Output ONLY a minimal unified diff (`--- a/script.py` / `+++ "
    "b/script.py` headers, then hunks). No prose."
)

# Lazy-loaded model state, populated by _load_model(): the (PEFT or base)
# model, its tokenizer, and a load-failure / disabled message.
_model = None
_tokenizer = None
_load_error: Optional[str] = None
| # Minimal diff helpers (avoid requiring forgeenv in the Space runtime). | |
| def _make_unified_diff(before: str, after: str, path: str = "script.py") -> str: | |
| import difflib | |
| diff = difflib.unified_diff( | |
| before.splitlines(keepends=True), | |
| after.splitlines(keepends=True), | |
| fromfile=f"a/{path}", | |
| tofile=f"b/{path}", | |
| n=2, | |
| ) | |
| return "".join(diff) | |
def _apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Best-effort application of a (possibly malformed) unified diff.

    Three strategies, in order:
    1. No diff markers but the text looks like a whole Python script ->
       treat it as a full replacement.
    2. Hunk-wise apply: locate each hunk's old block verbatim in the
       remaining source and splice in the new block.
    3. If hunk application changed nothing, fall back to simple
       (-line, +line) pair substring replacement.

    Returns *broken_script* unchanged when nothing can be applied.
    """
    # permissive applier (copied from forgeenv.env.diff_utils, simplified)
    diff_text = diff_text or ""
    if not diff_text.strip():
        return broken_script
    lines = diff_text.lstrip().splitlines()
    if lines and not any(line.startswith(("---", "+++", "@@")) for line in lines[:5]):
        head = "\n".join(lines[:30])
        markers = ("import ", "from ", "def ", "class ", "print(")
        # Two or more code-ish markers in the head => the model emitted a
        # whole script instead of a diff; use it verbatim.
        if sum(1 for m in markers if m in head) >= 2:
            return diff_text  # full replacement script
    # strict-ish apply: find each hunk old block and replace with new block
    src_lines = broken_script.splitlines(keepends=True)
    out: list[str] = []
    diff_lines = diff_text.splitlines()
    i = 0
    src_idx = 0  # index into src_lines of the first not-yet-consumed line
    in_hunk = False
    hunk_old: list[str] = []  # lines the hunk expects to find (context + '-')
    hunk_new: list[str] = []  # lines the hunk writes out (context + '+')
    def _flush() -> bool:
        # Apply the accumulated hunk; False means its old block wasn't found.
        nonlocal src_idx, hunk_old, hunk_new
        if not hunk_old and not hunk_new:
            return True
        target = "".join(hunk_old)
        remainder = "".join(src_lines[src_idx:])
        pos = remainder.find(target)
        if pos == -1:
            return False
        out.append(remainder[:pos])  # untouched text before the hunk
        out.append("".join(hunk_new))
        consumed = remainder[: pos + len(target)]
        src_idx += len(consumed.splitlines(keepends=True))
        hunk_old, hunk_new = [], []
        return True
    while i < len(diff_lines):
        line = diff_lines[i]
        if line.startswith(("---", "+++")):
            i += 1
            continue
        if line.startswith("@@"):
            # New hunk header: flush the previous hunk before collecting.
            if in_hunk and not _flush():
                break
            in_hunk = True
            i += 1
            continue
        if in_hunk:
            if line.startswith("+"):
                hunk_new.append(line[1:] + "\n")
            elif line.startswith("-"):
                hunk_old.append(line[1:] + "\n")
            else:
                # Context line: belongs to both the old and the new block.
                ctx = line[1:] if line.startswith(" ") else line
                hunk_old.append(ctx + "\n")
                hunk_new.append(ctx + "\n")
        i += 1
    if in_hunk:
        _flush()
    out.append("".join(src_lines[src_idx:]))
    candidate = "".join(out)
    # fallback: (-,+) line-pair replacement
    if candidate.strip() == broken_script.strip():
        repaired = broken_script
        pending_minus: str | None = None
        for line in diff_lines:
            if line.startswith(("---", "+++", "@@")):
                pending_minus = None
                continue
            if line.startswith("-"):
                pending_minus = line[1:].strip()
            elif line.startswith("+") and pending_minus is not None:
                old, new = pending_minus, line[1:].strip()
                if old and old in repaired:
                    # Replace only the first occurrence to stay minimal.
                    repaired = repaired.replace(old, new, 1)
                pending_minus = None
        return repaired
    return candidate
| # ----------------------------------------------------------------- model io | |
| def _adapter_compatible_with_base(adapter_repo: str, base_name: str) -> bool: | |
| """Cheap pre-check: pull adapter_config.json and compare base_model_name.""" | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| cfg_path = hf_hub_download( | |
| repo_id=adapter_repo, | |
| filename="adapter_config.json", | |
| token=os.environ.get("HF_TOKEN"), | |
| ) | |
| with open(cfg_path) as f: | |
| cfg = json.load(f) | |
| adapter_base = (cfg.get("base_model_name_or_path") or "").lower() | |
| # Match by family substring -- "qwen2.5-coder-7b" must be present in | |
| # the base name, otherwise the adapter targets a different arch. | |
| family = base_name.split("/")[-1].lower().replace("-instruct", "") | |
| return family in adapter_base | |
| except Exception as e: # noqa: BLE001 | |
| print(f"[demo] adapter_config check failed ({e}); attempting load anyway") | |
| return True | |
def _load_model() -> None:
    """Lazy-load the trained LoRA on first GPU invocation.

    Side effects only: on success sets the module globals `_model` (in eval
    mode) and `_tokenizer`; on failure records the message/traceback in
    `_load_error`. Idempotent -- returns immediately once either a model or
    an error has been recorded.
    """
    global _model, _tokenizer, _load_error
    if not ENABLE_MODEL:
        _load_error = "Model disabled (ENABLE_MODEL=0). Using heuristic-only demo."
        return
    if _model is not None or _load_error is not None:
        return
    try:
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        # ZeroGPU sometimes initialises models on the "meta" device when using
        # `device_map="auto"`. Load onto a real device explicitly.
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=dtype,
            device_map=None,
            low_cpu_mem_usage=False,
        ).to(device)
        if _adapter_compatible_with_base(ADAPTER_REPO, BASE_MODEL):
            try:
                model = PeftModel.from_pretrained(base, ADAPTER_REPO).to(device)
                print(f"[demo] LoRA attached: {ADAPTER_REPO}")
            except Exception as e:  # noqa: BLE001
                # Adapter failures are non-fatal: fall back to the base model.
                print(f"[demo] adapter load failed ({e}); using base model")
                model = base
        else:
            print(
                f"[demo] adapter at {ADAPTER_REPO} was trained on a different "
                f"base; using {BASE_MODEL} alone until matching adapter ships"
            )
            model = base
        _model = model.eval()
        _tokenizer = tokenizer
    except Exception as e:  # noqa: BLE001
        _load_error = f"{type(e).__name__}: {e}\n{traceback.format_exc()}"
# System turn for chat-template generation; mirrors _PROMPT_TEMPLATE's
# diff-only instruction.
_SYSTEM_PROMPT = (
    "You are an expert ML engineer who fixes broken HuggingFace training "
    "scripts caused by library version drift. Output ONLY a unified diff."
)
def _generate_with_model(prompt: str, max_new_tokens: int = 384) -> str:
    """Greedy decode using the base model's chat template (Qwen ChatML).

    Requires `_load_model()` to have populated the `_model` / `_tokenizer`
    globals. Returns only the newly generated text, prompt stripped.
    """
    import torch
    messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        text = _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:  # noqa: BLE001
        # Tokenizer without a chat template: fall back to the raw prompt.
        text = prompt
    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            # FIX: dropped `temperature=0.0` -- transformers ignores sampling
            # params when do_sample=False and emits a warning for them.
            repetition_penalty=1.15,
            pad_token_id=_tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens so only the completion is decoded.
    completion = _tokenizer.decode(
        out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    return completion.strip()
| # -------------------------------------------------------- diff extraction | |
# First ``` fenced block (optionally tagged ```diff / ```patch), non-greedy.
_FENCE_RE = re.compile(r"```(?:diff|patch)?\n([\s\S]*?)```", re.IGNORECASE)
# Unified-diff hunk header at line start.
# NOTE(review): _HUNK_RE is not referenced anywhere in this file -- confirm
# before removing.
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
| def _extract_diff_block(raw: str) -> str: | |
| """Pull the *first* fenced diff out of the model's raw output.""" | |
| if not raw: | |
| return "" | |
| m = _FENCE_RE.search(raw) | |
| if m: | |
| return m.group(1).strip() | |
| # otherwise grab from the first '---' / '+++' / '@@' onwards | |
| for marker in ("--- ", "+++ ", "@@"): | |
| idx = raw.find(marker) | |
| if idx >= 0: | |
| return raw[idx:].strip() | |
| return "" | |
def _diff_actually_changes_script(broken: str, diff_text: str) -> bool:
    """Try to apply the diff. Returns True iff the result differs from input."""
    if not diff_text:
        return False
    try:
        patched = _apply_unified_diff(broken, diff_text)
    except Exception:  # noqa: BLE001
        return False
    return bool(patched) and patched.strip() != broken.strip()
def _canonicalise(broken: str, diff_text: str) -> str:
    """Apply diff -> rebuild a clean canonical unified diff.

    FIX: previously imported `forgeenv.env.diff_utils`, which is not a
    dependency of the Space runtime (the module-level helpers exist precisely
    to avoid it); the ImportError silently disabled the model tier. Use the
    local `_apply_unified_diff` / `_make_unified_diff` instead.

    Returns "" when the diff does not change the script.
    """
    repaired = _apply_unified_diff(broken, diff_text)
    if not repaired or repaired.strip() == broken.strip():
        return ""
    return _make_unified_diff(broken, repaired)
| def _extract_model_reasoning(raw: str) -> str: | |
| """Pull the natural-language reasoning out of the model's output (if any).""" | |
| if not raw: | |
| return "" | |
| text = re.sub(_FENCE_RE, "", raw).strip() | |
| text = re.sub(r"^[\s\-+@]+", "", text, flags=re.MULTILINE).strip() | |
| lines = [ln.strip() for ln in text.splitlines() if ln.strip()] | |
| sentences: list[str] = [] | |
| for ln in lines: | |
| if ln.startswith(("---", "+++", "@@", "-", "+")): | |
| continue | |
| if len(ln) < 10: | |
| continue | |
| sentences.append(ln) | |
| if len(sentences) >= 3: | |
| break | |
| return " ".join(sentences) | |
| # ---------------------------------------------------- error-trace heuristic | |
| _DID_YOU_MEAN_RE = re.compile(r"Did you mean[:\s]+['`\"]?(\w+)['`\"]?", re.IGNORECASE) | |
| _NO_ATTR_RE = re.compile( | |
| r"has no attribute ['`\"]?(\w+)['`\"]?", re.IGNORECASE | |
| ) | |
| _NO_MODULE_RE = re.compile( | |
| r"No module named ['`\"]([\w\.]+)['`\"]", re.IGNORECASE | |
| ) | |
| _BAD_KWARG_RE = re.compile( | |
| r"unexpected keyword argument ['`\"](\w+)['`\"]", re.IGNORECASE | |
| ) | |
| _USE_INSTEAD_RE = re.compile( | |
| r"use\s+[`'\"]*(\w+)[\w=`'\"\s.\-]*instead", re.IGNORECASE | |
| ) | |
def _heuristic_repair(broken: str, error_trace: str) -> tuple[str, str]:
    """Produce a (repaired_script, fix_description) pair from the trace.
    Patterns covered:
    * AttributeError + "Did you mean: 'X'?" -> rename method
    * AttributeError without hint -> remove the call (rarely useful)
    * ModuleNotFoundError 'X.Y' -> drop the .Y submodule
    * TypeError unexpected kwarg + 'use Y' -> swap kwarg
    * TypeError unexpected kwarg, no hint -> drop the kwarg

    Returns (broken, "") unchanged when no pattern matches; an empty
    description signals "no repair" to the caller.
    """
    if not error_trace:
        return broken, ""
    trace = error_trace.strip()
    repaired = broken
    description = ""
    # 1. AttributeError 'X' + Did you mean 'Y'
    if "AttributeError" in trace or "has no attribute" in trace:
        old = _NO_ATTR_RE.search(trace)
        new = _DID_YOU_MEAN_RE.search(trace)
        if old and new and old.group(1) != new.group(1):
            old_name, new_name = old.group(1), new.group(1)
            # Whole-word rename everywhere in the script (all occurrences).
            pattern = re.compile(rf"\b{re.escape(old_name)}\b")
            if pattern.search(repaired):
                repaired = pattern.sub(new_name, repaired)
                description = (
                    f"`{old_name}` is no longer an attribute on this object; "
                    f"renamed call to `{new_name}` per the traceback hint."
                )
    # 2. ModuleNotFoundError 'X.Y' (or 'X')
    if not description and "No module named" in trace:
        m = _NO_MODULE_RE.search(trace)
        if m:
            mod = m.group(1)
            # Only dotted paths are repairable: fall back to the parent module.
            if "." in mod:
                parent, child = mod.rsplit(".", 1)
                pat_full = re.compile(rf"\b{re.escape(mod)}\b")
                if pat_full.search(repaired):
                    repaired = pat_full.sub(parent, repaired)
                    description = (
                        f"`{mod}` was removed; replaced with parent module "
                        f"`{parent}`."
                    )
    # 3. TypeError unexpected kwarg
    if not description and "unexpected keyword argument" in trace:
        bad = _BAD_KWARG_RE.search(trace)
        good = _USE_INSTEAD_RE.search(trace)
        if bad:
            bad_kw = bad.group(1)
            if good:
                # Trace suggests a replacement: rename the keyword in place.
                good_kw = good.group(1)
                pat = re.compile(rf"\b{re.escape(bad_kw)}\s*=")
                if pat.search(repaired):
                    repaired = pat.sub(f"{good_kw}=", repaired)
                    # if old kwarg was a boolean-ish, also swap the value
                    # (pad_to_max_length=True -> padding=True is fine)
                    description = (
                        f"`{bad_kw}` was renamed to `{good_kw}`; updated "
                        f"keyword to match the new API."
                    )
            else:
                # remove the kwarg entirely (best-effort)
                pat = re.compile(rf",?\s*\b{re.escape(bad_kw)}\s*=\s*[^,)\n]+")
                if pat.search(repaired):
                    repaired = pat.sub("", repaired)
                    description = (
                        f"`{bad_kw}` is no longer accepted; removed the "
                        f"keyword argument."
                    )
    return repaired, description
| # ------------------------------------------------------------- entry point | |
| try: | |
| import spaces # type: ignore | |
| _gpu_decorator = spaces.GPU(duration=60) | |
| except Exception: # noqa: BLE001 | |
| def _gpu_decorator(fn): | |
| return fn | |
def repair_script(script: str, error_trace: str) -> str:
    """Run the repair pipeline and return a commented unified diff (or note).

    Runtime order: deterministic trace heuristic first (cheap, keeps the
    Space responsive), then the optional LoRA model, then an honest
    "no patch" report. (The module docstring numbers the model as tier 1;
    the execution order here deliberately differs.)
    """
    if not script.strip():
        return "# Paste a broken script first."
    # Tier 1: error-trace heuristic (fast + reliable; keeps Space memory low)
    repaired, description = _heuristic_repair(script, error_trace)
    if description and repaired != script:
        diff = _make_unified_diff(script, repaired)
        header_lines = [
            "# Source: error-trace heuristic (deterministic repair).",
            f"# Fix: {description}",
        ]
        return "\n".join(header_lines) + "\n\n" + diff
    # Tier 2: trained LoRA (optional; disabled by default on 16Gi Spaces)
    model_raw = ""
    model_diff_canonical = ""
    model_reasoning = ""
    if ENABLE_MODEL:
        _load_model()
        if _model is not None and _tokenizer is not None:
            try:
                versions = json.dumps(
                    {"transformers": "4.45.0", "datasets": "2.20.0", "torch": "2.4.0"}
                )
                prompt = _PROMPT_TEMPLATE.format(
                    versions=versions,
                    script=script,
                    trace=error_trace or "(no trace)",
                )
                model_raw = _generate_with_model(prompt)
                model_diff_text = _extract_diff_block(model_raw)
                # Only trust the model's diff if applying it changes the script.
                if _diff_actually_changes_script(script, model_diff_text):
                    model_diff_canonical = _canonicalise(script, model_diff_text)
                model_reasoning = _extract_model_reasoning(model_raw)
            except Exception as e:  # noqa: BLE001
                print(f"[demo] model generation failed: {e}")
    if model_diff_canonical:
        header = (
            "# Source: trained LoRA (ForgeEnv GRPO adapter)\n"
            "# The model produced a valid diff that successfully patches the script.\n"
        )
        return header + "\n" + model_diff_canonical
    # Tier 3: nothing worked -- surface what we know
    msg_lines = ["# Could not produce a confident patch."]
    # FIX: the original guard was `_load_error and not ENABLE_MODEL`, which is
    # unreachable -- _load_error is only ever set by _load_model(), and
    # _load_model() is only called above when ENABLE_MODEL is truthy. Surface
    # the load error whenever one was recorded.
    if _load_error:
        msg_lines.append(f"# Note: {_load_error}")
    if model_reasoning:
        msg_lines.append(f"# Trained model reasoning: {model_reasoning}")
    if error_trace:
        msg_lines.append(f"# Error trace summary: {error_trace.splitlines()[-1]}")
    msg_lines.append(
        "# Try a more specific error trace (the heuristic looks for "
        "'Did you mean', 'No module named', or 'unexpected keyword argument')."
    )
    return "\n".join(msg_lines)
| # ----------------------------------------------------------------- gradio | |
# Two-column layout: broken script + trace on the left, diff on the right.
with gr.Blocks(title="ForgeEnv Repair Agent") as demo:
    gr.Markdown(f"# {_TITLE}\n\n{_DESCRIPTION}")
    with gr.Row():
        with gr.Column():
            # Inputs: the failing script and the traceback it produced.
            in_script = gr.Code(
                label="Broken HuggingFace script",
                language="python",
                lines=22,
            )
            in_trace = gr.Textbox(
                label="Error trace",
                lines=6,
                placeholder="Traceback...",
            )
            run_btn = gr.Button("Repair", variant="primary")
        with gr.Column():
            # Output: the repair rendered as a commented unified diff.
            out_diff = gr.Code(
                label="Suggested repair (unified diff)",
                language="markdown",
                lines=22,
            )
    gr.Examples(examples=_EXAMPLES, inputs=[in_script, in_trace])
    run_btn.click(repair_script, inputs=[in_script, in_trace], outputs=out_diff)
if __name__ == "__main__":
    demo.launch()