"""Repair Agent helpers: response sanitisation + a deterministic baseline. The Repair Agent's training output is a unified diff. LLMs frequently emit prose / fences / chain-of-thought before the diff; this module strips that preamble. The baseline policy uses the inverse-primitive map from `repair_primitives.py` to produce ground-truth diffs for warm-start. """ from __future__ import annotations import re from dataclasses import dataclass from typing import Optional from forgeenv.env.diff_utils import make_unified_diff from forgeenv.primitives.breakage_primitives import ( parse_breakage_spec, BreakagePrimitive, ) from forgeenv.primitives.repair_primitives import ( BREAKAGE_TO_REPAIR, REPAIR_REGISTRY, RepairPrimitive, ) _DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE) _FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```") def extract_diff(raw_text: str) -> str: """Pull the unified diff out of an LLM response. Handles: code fences, leading prose / chain-of-thought, trailing notes. """ if not raw_text: return "" raw_text = raw_text.strip() fence_match = _FENCE_RE.search(raw_text) if fence_match: raw_text = fence_match.group(1).strip() lines = raw_text.splitlines() start = 0 for i, line in enumerate(lines): if line.startswith(("---", "+++", "@@")): start = i break return "\n".join(lines[start:]) def looks_like_diff(text: str) -> bool: if not text: return False has_header = "---" in text and "+++" in text has_hunk = bool(_DIFF_HUNK_RE.search(text)) has_pm = any(line.startswith(("+", "-")) for line in text.splitlines()) return (has_header and has_hunk) or (has_hunk and has_pm) # ---------------------------------------------------------------- baselines @dataclass class BaselineRepairAgent: """Deterministic Repair Agent that uses the primitive inverse map. Used for warm-start dataset generation and baseline rollout comparisons. """ def repair( self, broken_script: str, breakage_spec: Optional[dict] = None, original_script: str = "", ) -> str: """Return a unified diff (or full replacement script) that fixes the broken script. Strategy preference: 1. If `original_script` is provided, return a diff between the broken script and the original (oracle). This is the warm-start path — we always know the ground truth. 2. Otherwise try to invert the structured breakage_spec via the repair-primitive registry. 3. Otherwise return an empty diff. """ if original_script and original_script != broken_script: return make_unified_diff(broken_script, original_script) if breakage_spec: try: breakage = parse_breakage_spec(breakage_spec) except (ValueError, TypeError): breakage = None if breakage is not None: repair = _invert_breakage(breakage) if repair is not None: repaired = repair.apply(broken_script) if repaired != broken_script: return make_unified_diff(broken_script, repaired) return "" _PARAM_REMAP: dict[str, dict[str, str]] = { "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"}, "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"}, "ChangeArgumentSignature": { "function_name": "function_name", "removed_arg": "arg_name", }, "ModifyConfigField": {"field_name": "field_name"}, "RestructureDatasetSchema": { "old_column": "old_column", "new_column": "new_column", }, "ChangeTokenizerBehavior": { "old_kwarg": "old_kwarg", "old_value": "old_value", "new_kwarg": "new_kwarg", "new_value": "new_value", }, "RemoveDeprecatedMethod": {"method_name": "method_name"}, "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"}, } def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]: breakage_name = type(breakage).__name__ repair_name = BREAKAGE_TO_REPAIR.get(breakage_name) if repair_name is None: return None repair_cls = REPAIR_REGISTRY.get(repair_name) if repair_cls is None: return None breakage_params = breakage._get_params() # type: ignore[attr-defined] remap = _PARAM_REMAP.get(breakage_name, {}) mapped: dict[str, str] = {} for src_key, dst_key in remap.items(): if src_key in breakage_params: mapped[dst_key] = breakage_params[src_key] valid_fields = { f.name for f in repair_cls.__dataclass_fields__.values() # type: ignore[attr-defined] if f.init } filtered = {k: v for k, v in mapped.items() if k in valid_fields} try: return repair_cls(**filtered) except TypeError: return None