# Source: forgeenv-source/forgeenv/roles/repair_agent.py
# Uploaded by akhiilll: "forgeenv source snapshot for training job" (commit a15535e, verified)
"""Repair Agent helpers: response sanitisation + a deterministic baseline.
The Repair Agent's training output is a unified diff. LLMs frequently emit
prose / fences / chain-of-thought before the diff; this module strips that
preamble. The baseline policy uses the inverse-primitive map from
`repair_primitives.py` to produce ground-truth diffs for warm-start.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
parse_breakage_spec,
BreakagePrimitive,
)
from forgeenv.primitives.repair_primitives import (
BREAKAGE_TO_REPAIR,
REPAIR_REGISTRY,
RepairPrimitive,
)
# Matches a unified-diff hunk header such as "@@ -1,3 +1,4 @@" at line start.
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Matches a Markdown fenced block; group 1 is the payload between the fences.
_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```")


def _find_diff_start(lines: list[str]) -> int:
    """Return the index of the first line that opens a unified diff.

    A line opens the diff when it starts with ``+++`` or ``@@``, or when it
    starts with ``---`` and is immediately followed by a ``+++`` line. A lone
    ``---`` line (e.g. a Markdown horizontal rule) is only used as a last
    resort when no other marker exists. Returns 0 when nothing matches.
    """
    first_bare_dashes = None
    for i, line in enumerate(lines):
        if line.startswith(("+++", "@@")):
            return i
        if line.startswith("---"):
            if i + 1 < len(lines) and lines[i + 1].startswith("+++"):
                return i
            if first_bare_dashes is None:
                first_bare_dashes = i
    return first_bare_dashes if first_bare_dashes is not None else 0


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles code fences and leading prose / chain-of-thought. Trailing notes
    are dropped only when they follow a closing code fence.

    When the response contains several fenced blocks, the first one holding a
    hunk header (``@@ ... @@``) is preferred; otherwise the first fenced block
    is used. Within the selected text everything before the diff start is
    discarded. If no diff marker is found the text is returned unchanged, so
    a full replacement script still passes through.

    Returns an empty string for empty / ``None`` input.
    """
    if not raw_text:
        return ""

    text = raw_text.strip()

    fenced_blocks = [m.group(1).strip() for m in _FENCE_RE.finditer(text)]
    if fenced_blocks:
        # An explanatory ```python snippet may precede the real ```diff fence;
        # pick the block that actually looks like a diff when there is one.
        text = next(
            (block for block in fenced_blocks if _DIFF_HUNK_RE.search(block)),
            fenced_blocks[0],
        )

    lines = text.splitlines()
    return "\n".join(lines[_find_diff_start(lines):])
def looks_like_diff(text: str) -> bool:
    """Heuristically decide whether ``text`` resembles a unified diff.

    A hunk header (``@@ ... @@`` at the start of a line) is mandatory. On top
    of that the text must either mention both ``---`` and ``+++`` somewhere,
    or contain at least one line beginning with ``+`` or ``-``.
    """
    if not text:
        return False

    # Without a hunk header neither accepted combination can hold.
    if _DIFF_HUNK_RE.search(text) is None:
        return False

    if "---" in text and "+++" in text:
        return True

    return any(row.startswith(("+", "-")) for row in text.splitlines())
# ---------------------------------------------------------------- baselines
@dataclass
class BaselineRepairAgent:
    """Rule-based Repair Agent driven by the breakage-to-repair inverse map.

    Per the module docstring, this baseline produces ground-truth diffs for
    warm-start; it involves no model call.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Produce a unified diff that repairs ``broken_script``.

        Resolution order:
          1. Oracle path: when ``original_script`` is non-empty and differs
             from ``broken_script``, diff the broken script against it.
          2. Primitive path: parse ``breakage_spec``, look up the inverse
             repair primitive, apply it, and diff the result.
          3. Otherwise return an empty string.

        A spec that fails to parse (``ValueError`` / ``TypeError``), has no
        known inverse, or whose repair leaves the script unchanged also
        yields an empty string.
        """
        oracle_available = bool(original_script) and original_script != broken_script
        if oracle_available:
            return make_unified_diff(broken_script, original_script)

        if not breakage_spec:
            return ""

        try:
            parsed = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            return ""
        if parsed is None:
            return ""

        inverse = _invert_breakage(parsed)
        if inverse is None:
            return ""

        candidate = inverse.apply(broken_script)
        if candidate == broken_script:
            # The repair primitive was a no-op; there is nothing to diff.
            return ""
        return make_unified_diff(broken_script, candidate)
# Per breakage primitive class name: how its parameter names translate to the
# keyword names passed to the matching repair primitive. Most entries keep
# the same name; only ChangeArgumentSignature renames one
# ("removed_arg" -> "arg_name").
_PARAM_REMAP: dict[str, dict[str, str]] = {
    "RenameApiCall": {key: key for key in ("old_name", "new_name")},
    "DeprecateImport": {key: key for key in ("old_module", "new_module")},
    "ChangeArgumentSignature": {
        "function_name": "function_name",
        "removed_arg": "arg_name",
    },
    "ModifyConfigField": {key: key for key in ("field_name",)},
    "RestructureDatasetSchema": {
        key: key for key in ("old_column", "new_column")
    },
    "ChangeTokenizerBehavior": {
        key: key
        for key in ("old_kwarg", "old_value", "new_kwarg", "new_value")
    },
    "RemoveDeprecatedMethod": {key: key for key in ("method_name",)},
    "ChangeReturnType": {key: key for key in ("old_access", "new_access")},
}
def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that undoes ``breakage``, if one is known.

    The breakage's class name is looked up in ``BREAKAGE_TO_REPAIR`` and then
    ``REPAIR_REGISTRY``. Its parameters are renamed through ``_PARAM_REMAP``,
    and only names that the repair class accepts as init fields are passed
    on. Returns ``None`` when no repair is registered or when the repair
    class rejects the resulting keyword arguments with ``TypeError``.
    """
    kind = type(breakage).__name__

    repair_name = BREAKAGE_TO_REPAIR.get(kind)
    if repair_name is None:
        return None

    repair_cls = REPAIR_REGISTRY.get(repair_name)
    if repair_cls is None:
        return None

    source_params = breakage._get_params()  # type: ignore[attr-defined]

    # Rename breakage parameters to the repair primitive's keyword names.
    translated: dict[str, str] = {
        target: source_params[origin]
        for origin, target in _PARAM_REMAP.get(kind, {}).items()
        if origin in source_params
    }

    # Keep only names the repair class takes in its generated __init__.
    accepted = {
        spec.name
        for spec in repair_cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if spec.init
    }
    kwargs = {name: value for name, value in translated.items() if name in accepted}

    try:
        return repair_cls(**kwargs)
    except TypeError:
        # Typically a required field was not supplied by the mapping.
        return None