forgeenv-source / forgeenv /env /diff_utils.py
akhiilll's picture
forgeenv source snapshot for training job
a15535e verified
"""Unified-diff application utilities.
The Repair Agent submits a unified diff. We need a permissive applier
because LLM diffs are often malformed (wrong line numbers, missing
context, extra prose). We try the strict applier first, then fall
back to applying hunks via plain string replacement.
The agent may also submit a full Python script instead of a diff
(common when the model's diff format breaks). We detect this and
treat it as a complete replacement.
"""
from __future__ import annotations
import difflib
import re
# Matches a unified-diff hunk header such as "@@ -1,3 +1,4 @@" at the
# start of any line.
_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Substrings whose presence suggests the text is Python source code.
_SCRIPT_MARKERS = ("import ", "from ", "def ", "class ", "print(")


def looks_like_full_script(text: str) -> bool:
    """Heuristic: text is probably a full python script, not a diff.

    The text is rejected as a script when any of the following holds:

    * it is empty (or whitespace only);
    * one of its first 5 lines starts with ``---``, ``+++`` or ``@@``;
    * a hunk header (``@@ ... @@``) starts a line anywhere in the first
      30 lines, or a ``--- `` line is immediately followed by a ``+++ ``
      line there. This catches diffs that the model prefixed with more
      than 5 lines of prose, which would otherwise be mistaken for a
      script because their ``+import ...`` / `` def ...`` lines contain
      script markers.

    Otherwise the text counts as a script when at least two distinct
    entries of ``_SCRIPT_MARKERS`` occur in the first 30 lines.

    Args:
        text: Raw agent submission.

    Returns:
        True if the submission should be treated as a complete
        replacement script, False if it should be handled as a diff.
    """
    lines = text.lstrip().splitlines()
    if not lines:
        return False

    # Diff markers right at the top: definitely a diff.
    if any(line.startswith(("---", "+++", "@@")) for line in lines[:5]):
        return False

    window = lines[:30]
    head = "\n".join(window)

    # A hunk header further down still means "diff", even after a long
    # prose preamble. A line starting with "@@ ... @@" is not valid Python.
    if _HUNK_RE.search(head):
        return False

    # Same for an adjacent "--- old" / "+++ new" file-header pair. Requiring
    # the pair (with trailing space) avoids tripping on docstring underlines.
    if any(
        first.startswith("--- ") and second.startswith("+++ ")
        for first, second in zip(window, window[1:])
    ):
        return False

    # Two or more distinct script-style markers => full replacement script.
    hits = sum(1 for marker in _SCRIPT_MARKERS if marker in head)
    return hits >= 2
def _find_line_block(lines: list[str], block: list[str], start: int) -> int:
    """Return the first index >= ``start`` where ``block`` matches whole lines.

    An empty ``block`` matches immediately at ``start`` (pure insertion).
    Returns -1 when ``block`` does not occur in ``lines[start:]``.
    """
    if not block:
        return start
    width = len(block)
    for idx in range(start, len(lines) - width + 1):
        if lines[idx : idx + width] == block:
            return idx
    return -1


def _strict_apply(broken_script: str, diff_text: str) -> str | None:
    """Apply a unified diff strictly. Returns None on any failure.

    Line numbers in hunk headers are ignored (LLM diffs often get them
    wrong). Instead, each hunk's "old" side (context + removed lines) is
    located in the source, searching forward from the end of the previous
    hunk, and replaced by its "new" side (context + added lines).

    Matching is done on whole lines, so a hunk can never be spliced into
    the middle of an unrelated, longer line.

    Parsing rules:
      * ``---`` / ``+++`` lines are skipped as file headers.
      * ``@@`` starts a new hunk; anything before the first ``@@`` is
        ignored.
      * ``\\ No newline at end of file`` markers are skipped.
      * Inside a hunk, ``+`` lines are additions, ``-`` lines removals,
        and everything else is context (a leading space, if present, is
        dropped).

    A source whose final line has no line terminator is handled by
    temporarily terminating it for matching; the result is returned
    without a trailing newline in that case.

    Args:
        broken_script: The source text to patch.
        diff_text: The unified diff to apply.

    Returns:
        The patched text, or None if any hunk's old side cannot be found.
        If the diff contains no hunks, the source is returned unchanged.
    """
    src_lines = broken_script.splitlines(keepends=True)

    # Hunk lines are always rebuilt with a trailing "\n"; give the last
    # source line one too so it can match, and undo that at the end.
    missing_eol = bool(src_lines) and not src_lines[-1].endswith(("\n", "\r"))
    if missing_eol:
        src_lines[-1] += "\n"

    # Pass 1: parse the diff into (old_lines, new_lines) hunks.
    hunks: list[tuple[list[str], list[str]]] = []
    for line in diff_text.splitlines():
        if line.startswith(("---", "+++")):
            continue
        if line.startswith("@@"):
            hunks.append(([], []))
            continue
        if not hunks:
            # Prose or other noise before the first hunk header.
            continue
        if line.startswith("\\ No newline"):
            # Diff metadata, not file content.
            continue
        old, new = hunks[-1]
        if line.startswith("+"):
            new.append(line[1:] + "\n")
        elif line.startswith("-"):
            old.append(line[1:] + "\n")
        else:
            # Context line; tolerate a missing leading space.
            ctx = line[1:] if line.startswith(" ") else line
            old.append(ctx + "\n")
            new.append(ctx + "\n")

    # Pass 2: apply hunks in order, each searched from where the last ended.
    out: list[str] = []
    cursor = 0
    for old, new in hunks:
        start = _find_line_block(src_lines, old, cursor)
        if start == -1:
            return None
        out.extend(src_lines[cursor:start])
        out.extend(new)
        cursor = start + len(old)
    out.extend(src_lines[cursor:])

    result = "".join(out)
    if missing_eol and result.endswith("\n"):
        result = result[:-1]
    return result
def _permissive_apply(broken_script: str, diff_text: str) -> str:
    """Best-effort application of a malformed diff.

    Scans the diff for a removed line (``-``) followed by an added line
    (``+``), and treats each such pair as a search-and-replace on the
    whitespace-stripped line contents. Only the first occurrence of each
    removed snippet is replaced; pairs whose removed side is empty or
    absent from the script have no effect.

    Header lines (``---``, ``+++``, ``@@``) and any other line that does
    not start with a space cancel a pending removal, so standalone
    deletions are ignored (without context we cannot tell what to delete).

    Args:
        broken_script: The source text to patch.
        diff_text: The (possibly malformed) diff.

    Returns:
        The script with every applicable replacement made; the input
        script itself when nothing applies.
    """
    replacements: list[tuple[str, str]] = []
    removed: str | None = None

    for raw in diff_text.splitlines():
        if raw.startswith(("---", "+++", "@@")):
            # File/hunk headers break any pending "-" line.
            removed = None
        elif raw.startswith("-"):
            removed = raw[1:].strip()
        elif removed is None:
            # Nothing pending: additions and context carry no information.
            continue
        elif raw.startswith("+"):
            replacements.append((removed, raw[1:].strip()))
            removed = None
        elif not raw.startswith(" "):
            # Unprefixed line after a "-": treat the removal as standalone
            # and drop it. Space-prefixed context keeps it pending.
            removed = None

    result = broken_script
    for before, after in replacements:
        # An empty needle would insert at position 0, so skip it;
        # a needle that is not present makes replace() a no-op.
        if before:
            result = result.replace(before, after, 1)
    return result
def _strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence that wraps the payload, if present.

    Only acts when the (whitespace-stripped) text *starts* with a fence
    line such as ```` ``` ```` or ```` ```diff ````. The opening line is
    dropped and the body is cut at the first bare closing fence, which
    also discards any prose after it. Text that does not start with a
    fence is returned unchanged.
    """
    lines = text.strip().splitlines(keepends=True)
    if len(lines) < 2 or not lines[0].startswith("```"):
        return text
    body = lines[1:]
    for idx, line in enumerate(body):
        if line.strip() == "```":
            body = body[:idx]
            break
    return "".join(body)


def apply_unified_diff(broken_script: str, diff_text: str) -> str:
    """Apply an agent-submitted patch to ``broken_script``.

    A Markdown code fence wrapping the submission is stripped first, so
    that fence lines neither end up inside a replacement script nor
    break strict hunk matching.

    Strategies, tried in order:
      1. Empty submission: return the broken script unchanged.
      2. If the submission looks like a full script, return it directly.
      3. If it contains a hunk header or ``---`` / ``+++``, try strict
         diff application and accept the result if it changed the script.
      4. Fall back to permissive (-,+) line-pair replacement, which
         returns the broken script unchanged when nothing applies.

    Args:
        broken_script: The source text to repair.
        diff_text: The agent's submission (diff or full script). A falsy
            value is treated as an empty submission.

    Returns:
        The repaired script, or ``broken_script`` if nothing could be
        applied.
    """
    diff_text = _strip_code_fence(diff_text or "")
    if not diff_text.strip():
        return broken_script

    if looks_like_full_script(diff_text):
        return diff_text

    looks_like_diff = (
        _HUNK_RE.search(diff_text) or "---" in diff_text or "+++" in diff_text
    )
    if looks_like_diff:
        strict = _strict_apply(broken_script, diff_text)
        # A strict result identical to the input is treated as "no fix";
        # give the permissive applier a chance instead.
        if strict is not None and strict != broken_script:
            return strict

    return _permissive_apply(broken_script, diff_text)
def make_unified_diff(before: str, after: str, path: str = "train.py") -> str:
    """Produce a canonical unified diff from before -> after.

    Uses two lines of context and ``a/<path>`` / ``b/<path>`` file
    headers. Every line of the result is newline-terminated, even when
    ``before`` or ``after`` lacks a final newline. Without this,
    ``difflib`` emits the unterminated line verbatim and adjacent diff
    lines would be fused (e.g. ``-a+b``) into a malformed diff.

    Args:
        before: Original text.
        after: Modified text.
        path: File name used in the ``---`` / ``+++`` headers.

    Returns:
        The unified diff as a single string; empty when the two texts
        are identical.
    """
    diff = difflib.unified_diff(
        before.splitlines(keepends=True),
        after.splitlines(keepends=True),
        fromfile=f"a/{path}",
        tofile=f"b/{path}",
        n=2,
    )
    # Content lines come from the inputs untouched, so a final line with
    # no newline must be terminated here before joining.
    return "".join(
        line if line.endswith("\n") else line + "\n" for line in diff
    )