| """Unified-diff application utilities.
|
|
|
| The Repair Agent submits a unified diff. We need a permissive applier
|
| because LLM diffs are often malformed (wrong line numbers, missing
|
| context, extra prose). We try the strict applier first, then fall
|
| back to applying hunks via plain string replacement.
|
|
|
| The agent may also submit a full Python script instead of a diff
|
| (common when the model's diff format breaks). We detect this and
|
| treat it as a complete replacement.
|
| """
|
| from __future__ import annotations
|
|
|
| import difflib
|
| import re
|
|
|
|
|
# A line that starts with "@@" and contains a second "@@" (a hunk header).
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)

# Line-leading tokens that suggest Python source rather than a diff.
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: return True if *text* is probably a full Python script, not a diff.

    *text* is treated as a diff (returns False) when any of its first five
    lines starts with ``---``, ``+++`` or ``@@``, or when a hunk header or a
    ``--- ``/``+++ `` file-header pair appears anywhere -- LLM replies often
    put several lines of prose before the diff itself.

    Otherwise, count how many distinct ``_SCRIPT_MARKERS`` begin some line
    (ignoring indentation) within the first 30 lines; two or more means
    "script". Requiring the marker at the start of a line keeps diff bodies
    (``+import os``) and prose ("changed the import from X") from being
    mistaken for code.
    """
    lines = text.lstrip().splitlines()
    if not lines:
        return False

    # Fast path: diff headers right at the top.
    if any(line.startswith(("---", "+++", "@@")) for line in lines[:5]):
        return False

    # Diff headers further down, after a prose preamble.
    for idx, line in enumerate(lines):
        if _HUNK_RE.match(line):
            return False
        if (
            line.startswith("--- ")
            and idx + 1 < len(lines)
            and lines[idx + 1].startswith("+++ ")
        ):
            return False

    head = [line.lstrip() for line in lines[:30]]
    hits = sum(
        1
        for marker in _SCRIPT_MARKERS
        if any(line.startswith(marker) for line in head)
    )
    return hits >= 2
|
|
|
|
|
def _parse_hunks(diff_text: str) -> list[list[tuple[str, str]]]:
    """Split *diff_text* into hunks, each a list of ``(tag, text)`` operations.

    ``tag`` is ``"+"`` (added line; ``text`` ends with ``"\\n"`` unless a
    ``\\ No newline at end of file`` marker follows it), ``"-"`` (removed
    line) or ``" "`` (context line); removed/context ``text`` carries no line
    terminator. Lines before the first ``@@`` header and every line starting
    with ``---``/``+++`` are ignored.
    """
    diff_lines = diff_text.splitlines()
    # Trailing blank lines and a closing Markdown fence are LLM wrapper noise.
    # Dropping trailing context never changes the edit, it only narrows it.
    while diff_lines and (
        not diff_lines[-1].strip() or diff_lines[-1].startswith("```")
    ):
        diff_lines.pop()

    hunks: list[list[tuple[str, str]]] = []
    ops: list[tuple[str, str]] | None = None
    for line in diff_lines:
        if line.startswith(("---", "+++")):
            continue
        if line.startswith("@@"):
            ops = []
            hunks.append(ops)
            continue
        if ops is None:
            continue  # preamble before the first hunk (prose, fences, ...)
        if line.startswith("+"):
            ops.append(("+", line[1:] + "\n"))
        elif line.startswith("-"):
            ops.append(("-", line[1:]))
        elif line.startswith("\\"):
            # "\ No newline at end of file": the preceding added line is the
            # new file's last line and must not get a terminator.
            if ops and ops[-1][0] == "+" and ops[-1][1].endswith("\n"):
                ops[-1] = ("+", ops[-1][1][:-1])
        else:
            # Context. Tolerate a missing leading space (editors and LLMs
            # often strip it, e.g. on blank lines).
            ops.append((" ", line[1:] if line.startswith(" ") else line))
    return hunks


def _locate_block(keys: list[str], block: list[str], start: int) -> int:
    """Return the first index >= *start* at which *block* equals a contiguous
    run of *keys* (whole lines), or -1. An empty *block* matches at *start*."""
    width = len(block)
    for at in range(start, len(keys) - width + 1):
        if keys[at : at + width] == block:
            return at
    return -1


def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure.

    Each hunk's removed + context lines must appear in *broken_script* as a
    contiguous run of whole lines, at or after the end of the previous hunk.
    Line numbers in ``@@`` headers are ignored. Comparison ignores only the
    line terminators, so CRLF sources and a final line without a newline
    still match. Context lines are copied from the source verbatim; added
    lines are written with ``"\\n"``. A diff with no hunks returns the script
    unchanged.
    """
    src_lines = broken_script.splitlines(keepends=True)
    keys = broken_script.splitlines()  # same lines, terminators removed
    out: list[str] = []

    def emit(piece: str) -> None:
        # Never glue a line onto an unterminated predecessor (the source's
        # last line, or an added line that carried "\ No newline").
        if out and out[-1].splitlines() == [out[-1]]:
            out[-1] += "\n"
        out.append(piece)

    cursor = 0
    for ops in _parse_hunks(diff_text):
        old = [text for tag, text in ops if tag != "+"]
        at = _locate_block(keys, old, cursor)
        if at == -1:
            return None
        for piece in src_lines[cursor:at]:
            emit(piece)
        src = at
        for tag, text in ops:
            if tag == "+":
                emit(text)
                continue
            if tag == " ":
                emit(src_lines[src])
            src += 1  # both context and removed lines consume a source line
        cursor = src

    for piece in src_lines[cursor:]:
        emit(piece)
    return "".join(out)
|
|
|
|
|
def _pair_changes(diff_text: str) -> list[tuple[str, str]]:
    """Extract whitespace-stripped (removed, added) line pairs from a diff.

    A run of ``-`` lines followed by a run of ``+`` lines forms one change
    group, whose lines are paired positionally; surplus lines on either side
    are dropped. Context lines (leading space) between the two runs are
    tolerated. Headers (``---``/``+++``/``@@``), blank lines and any other
    text end the group. ``+`` lines with no preceding ``-`` run are ignored,
    and ``\\ No newline at end of file`` markers are skipped.
    """
    pairs: list[tuple[str, str]] = []
    minus: list[str] = []
    plus: list[str] = []

    def close_group() -> None:
        pairs.extend(zip(minus, plus))
        minus.clear()
        plus.clear()

    for line in diff_text.splitlines():
        if line.startswith(("---", "+++", "@@")):
            close_group()
        elif line.startswith("\\"):
            continue
        elif line.startswith("-"):
            if plus:  # a removal after additions starts a new group
                close_group()
            minus.append(line[1:].strip())
        elif line.startswith("+"):
            if minus:
                plus.append(line[1:].strip())
        elif line.startswith(" "):
            if plus:
                close_group()
        else:
            close_group()
    close_group()
    return pairs


def _find_free_span(
    text: str,
    needle: str,
    start: int,
    taken: list[tuple[int, int]],
    whole_line: bool,
) -> int:
    """Return the first index >= *start* where *needle* occurs in *text*
    without overlapping a span in *taken*, or -1. With *whole_line*, the
    match must also be the only non-whitespace content on its line."""
    pos = text.find(needle, start)
    while pos != -1:
        end = pos + len(needle)
        if all(end <= lo or pos >= hi for lo, hi in taken):
            if not whole_line:
                return pos
            line_start = text.rfind("\n", 0, pos) + 1
            line_end = text.find("\n", end)
            if line_end == -1:
                line_end = len(text)
            if not text[line_start:pos].strip() and not text[end:line_end].strip():
                return pos
        pos = text.find(needle, pos + 1)
    return -1


def _permissive_apply(broken_script: str, diff_text: str) -> str:
    """Apply a malformed diff by extracting (-,+) line pairs and doing
    a tolerant search-and-replace.

    Each stripped removed line is located in the ORIGINAL *broken_script*
    (so text inserted by one pair can never be re-matched by a later pair)
    and swapped for the stripped added line; because both sides are
    stripped, the source line keeps its own indentation. Search order:
    whole-line match after the previous replacement, whole-line match
    anywhere, then substring match after the previous replacement, then
    anywhere. Pairs with empty or unfound old text are skipped, so with no
    usable pair the script is returned unchanged.
    """
    edits: list[tuple[int, int, str]] = []
    taken: list[tuple[int, int]] = []
    cursor = 0  # diff hunks are ordered, so prefer matches after the last edit

    for old, new in _pair_changes(diff_text):
        if not old:
            continue
        pos = -1
        for whole_line, origin in ((True, cursor), (True, 0), (False, cursor), (False, 0)):
            pos = _find_free_span(broken_script, old, origin, taken, whole_line)
            if pos != -1:
                break
        if pos == -1:
            continue
        end = pos + len(old)
        taken.append((pos, end))
        edits.append((pos, end, new))
        cursor = end

    pieces: list[str] = []
    prev = 0
    for start, end, new in sorted(edits):
        pieces.append(broken_script[prev:start])
        pieces.append(new)
        prev = end
    pieces.append(broken_script[prev:])
    return "".join(pieces)
|
|
|
|
|
def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Try every strategy in order and return the first that produces a change.

    If *diff_text*, after stripping surrounding whitespace, is exactly one
    Markdown code fence (```lang ... ```), the fence lines are removed first,
    so a fenced full script or fenced diff is handled like an unfenced one.

    Strategies:
    1. If `diff_text` looks like a full script, return it directly.
    2. Try strict diff application (only when the text contains a hunk
       header, ``---`` or ``+++``).
    3. Fall back to permissive (-,+) line-pair replacement.
    4. As last resort, return the broken script unchanged.

    An empty, ``None`` or whitespace-only *diff_text* returns *broken_script*.
    """
    diff_text = diff_text or ""

    # LLMs commonly wrap the whole answer in one ```python / ```diff fence;
    # returning the fence lines as part of a "full script" would not parse.
    body = diff_text.strip()
    if body.startswith("```") and body.endswith("```") and "\n" in body:
        inner = body[body.index("\n") + 1 : -3]
        if "```" not in inner:  # exactly one fenced block, nothing else
            diff_text = inner

    if not diff_text.strip():
        return broken_script

    if looks_like_full_script(diff_text):
        return diff_text

    if _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text:
        strict = _strict_apply(broken_script, diff_text)
        if strict is not None and strict != broken_script:
            return strict

    # Also covers strategy 4: with no usable pairs the script comes back as is.
    return _permissive_apply(broken_script, diff_text)
|
|
|
|
|
def make_unified_diff(
    before: str, after: str, path: str = "train.py", context: int = 2
) -> str:
    """Produce a canonical unified diff from before -> after.

    Headers are ``--- a/<path>`` / ``+++ b/<path>``; *context* is the number
    of unchanged lines shown around each change (default 2). Lines are split
    on LF only, as the unified format expects. A final line without a newline
    is terminated and followed by the standard
    ``\\ No newline at end of file`` marker instead of being glued to the
    next diff line. Identical inputs yield ``""``.
    """

    def split_lines(text: str) -> list[str]:
        # str.splitlines() would also break on \x0c, \x1c, \u2028, ...,
        # yielding unterminated pieces in the middle of a real line.
        parts = text.split("\n")
        lines = [part + "\n" for part in parts[:-1]]
        if parts[-1]:
            lines.append(parts[-1])
        return lines

    out: list[str] = []
    for line in difflib.unified_diff(
        split_lines(before),
        split_lines(after),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=context,
    ):
        if line.endswith("\n"):
            out.append(line)
        else:
            out.append(line + "\n\\ No newline at end of file\n")
    return "".join(out)
|
|
|