| """Repair Agent helpers: response sanitisation + a deterministic baseline.
|
|
|
| The Repair Agent's training output is a unified diff. LLMs frequently emit
|
| prose / fences / chain-of-thought before the diff; this module strips that
|
| preamble. The baseline policy uses the inverse-primitive map from
|
| `repair_primitives.py` to produce ground-truth diffs for warm-start.
|
| """
|
| from __future__ import annotations
|
|
|
| import re
|
| from dataclasses import dataclass
|
| from typing import Optional
|
|
|
| from forgeenv.env.diff_utils import make_unified_diff
|
| from forgeenv.primitives.breakage_primitives import (
|
| parse_breakage_spec,
|
| BreakagePrimitive,
|
| )
|
| from forgeenv.primitives.repair_primitives import (
|
| BREAKAGE_TO_REPAIR,
|
| REPAIR_REGISTRY,
|
| RepairPrimitive,
|
| )
|
|
|
|
|
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)

# A fenced code block: an opening ``` followed by any info string ("diff",
# "python", "c++", or nothing), the body, then a closing ``` after a newline.
_FENCE_RE = re.compile(r"```[^\n`]*\n([\s\S]*?)\n```")

# Prefixes a line inside a unified diff body can start with: context (" "),
# addition, removal, hunk header, and "\ No newline at end of file".
_DIFF_LINE_PREFIXES = (" ", "+", "-", "@@", "\\")


def _find_diff_start(lines: list[str]) -> Optional[int]:
    """Return the index of the first diff header / hunk line, or None.

    A ``---`` line only counts when the next line is a ``+++`` header, so a
    Markdown horizontal rule (``---``) in leading prose is not mistaken for
    the start of the diff.
    """
    for i, line in enumerate(lines):
        if line.startswith(("@@", "+++")):
            return i
        if (
            line.startswith("---")
            and i + 1 < len(lines)
            and lines[i + 1].startswith("+++")
        ):
            return i
    return None


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles:
      * code fences with any info string; when several fences are present,
        the first one whose body contains a diff is used (otherwise the
        first fence),
      * leading prose / chain-of-thought before the first diff header,
      * trailing notes after the last diff line.

    If no diff marker is found at all, the (fence-unwrapped, stripped) text
    is returned unchanged rather than discarded, so non-diff responses are
    passed through for the caller to judge.

    Args:
        raw_text: The raw model response; ``None`` or ``""`` yields ``""``.

    Returns:
        The extracted diff lines joined with ``"\\n"`` (no trailing newline),
        or ``""`` for empty input.
    """
    if not raw_text:
        return ""
    text = raw_text.strip()

    fence_bodies = [m.group(1).strip() for m in _FENCE_RE.finditer(text)]
    if fence_bodies:
        # Prefer a fence that actually holds a diff, e.g. when the model
        # first quotes the broken code and then gives the patch.
        text = next(
            (
                body
                for body in fence_bodies
                if _find_diff_start(body.splitlines()) is not None
            ),
            fence_bodies[0],
        )

    lines = text.splitlines()
    start = _find_diff_start(lines)
    if start is None:
        # No diff markers anywhere: keep the pass-through behaviour.
        return "\n".join(lines)

    # Drop trailing notes ("This restores the old import.") and blank lines
    # after the last line that can belong to a diff body.
    end = len(lines)
    while end > start + 1 and not lines[end - 1].startswith(_DIFF_LINE_PREFIXES):
        end -= 1
    return "\n".join(lines[start:end])
|
|
|
|
|
def looks_like_diff(text: str) -> bool:
    """Heuristically decide whether *text* resembles a unified diff.

    A hunk header (a line starting with ``@@`` and containing a second
    ``@@``) is required. Given one, the text qualifies if it also contains
    both ``---`` and ``+++`` somewhere, or if at least one line begins with
    ``+`` or ``-``. Empty input is never a diff.
    """
    if not text:
        return False
    if not _DIFF_HUNK_RE.search(text):
        return False
    if "---" in text and "+++" in text:
        return True
    return any(row[:1] in ("+", "-") for row in text.splitlines())
|
|
|
|
|
|
|
@dataclass
class BaselineRepairAgent:
    """Deterministic Repair Agent built on the breakage -> repair inverse map.

    Used for warm-start dataset generation and baseline rollout comparisons.
    The agent holds no state.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Return a unified diff that fixes ``broken_script``, or ``""``.

        Strategies, in order of preference:

        1. Oracle: if ``original_script`` is non-empty and differs from
           ``broken_script``, diff the broken script against it. This is the
           warm-start path, where the ground truth is always known.
        2. Inversion: parse ``breakage_spec`` into a breakage primitive, look
           up its inverse repair primitive, and apply that to the script.
        3. Otherwise return ``""``. This covers: no spec, a spec that
           ``parse_breakage_spec`` rejects with ValueError/TypeError, no known
           inverse, or a repair that leaves the script unchanged.

        Args:
            broken_script: Source text of the broken script.
            breakage_spec: Structured breakage description accepted by
                ``parse_breakage_spec``; optional.
            original_script: Known-good script, if available.

        Returns:
            A diff from ``make_unified_diff(broken, fixed)``, or ``""`` when
            no repair could be derived.
        """
        if original_script and original_script != broken_script:
            return make_unified_diff(broken_script, original_script)

        if not breakage_spec:
            return ""
        try:
            breakage = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            # Malformed specs are expected from noisy rollouts; fall back to "".
            return ""
        if breakage is None:
            return ""

        repair_primitive = _invert_breakage(breakage)
        if repair_primitive is None:
            return ""
        repaired = repair_primitive.apply(broken_script)
        if repaired == broken_script:
            # The inverse did not match anything in this script.
            return ""
        return make_unified_diff(broken_script, repaired)
|
|
|
|
|
def _same_names(*names: str) -> dict[str, str]:
    """Map each name to itself (breakage parameter == repair field name)."""
    return {name: name for name in names}


# Keyed by breakage primitive class name: breakage parameter name -> keyword
# argument name for the corresponding repair primitive. Parameters not
# listed here are not forwarded to the repair.
_PARAM_REMAP: dict[str, dict[str, str]] = {
    "RenameApiCall": _same_names("old_name", "new_name"),
    "DeprecateImport": _same_names("old_module", "new_module"),
    # The only non-identity entry: ``removed_arg`` is passed as ``arg_name``.
    "ChangeArgumentSignature": {
        **_same_names("function_name"),
        "removed_arg": "arg_name",
    },
    "ModifyConfigField": _same_names("field_name"),
    "RestructureDatasetSchema": _same_names("old_column", "new_column"),
    "ChangeTokenizerBehavior": _same_names(
        "old_kwarg", "old_value", "new_kwarg", "new_value"
    ),
    "RemoveDeprecatedMethod": _same_names("method_name"),
    "ChangeReturnType": _same_names("old_access", "new_access"),
}
|
|
|
|
|
def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that inverts *breakage*, or return None.

    The breakage's class name is looked up in ``BREAKAGE_TO_REPAIR`` to get a
    repair name, and that name is looked up in ``REPAIR_REGISTRY`` to get the
    repair class. The breakage's parameters are renamed through
    ``_PARAM_REMAP``. Any that are not init fields of the repair class are
    dropped. This reads ``__dataclass_fields__``, so the repair class is
    assumed to be a dataclass.

    Returns None if no inverse is registered, or if constructing the repair
    raises TypeError (e.g. a required field could not be supplied).
    """
    source_name = type(breakage).__name__
    target_name = BREAKAGE_TO_REPAIR.get(source_name)
    if target_name is None:
        return None
    target_cls = REPAIR_REGISTRY.get(target_name)
    if target_cls is None:
        return None

    params = breakage._get_params()
    init_names = {
        fld.name
        for fld in target_cls.__dataclass_fields__.values()
        if fld.init
    }
    kwargs = {
        dst: params[src]
        for src, dst in _PARAM_REMAP.get(source_name, {}).items()
        if src in params and dst in init_names
    }
    try:
        return target_cls(**kwargs)
    except TypeError:
        # Best effort: an inverse we cannot fully parameterise is skipped.
        return None
|
|
|