| """Repair Agent helpers: response sanitisation + a deterministic baseline. |
| |
| The Repair Agent's training output is a unified diff. LLMs frequently emit |
| prose / fences / chain-of-thought before the diff; this module strips that |
| preamble. The baseline policy uses the inverse-primitive map from |
| `repair_primitives.py` to produce ground-truth diffs for warm-start. |
| """ |
| from __future__ import annotations |
|
|
| import re |
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| from forgeenv.env.diff_utils import make_unified_diff |
| from forgeenv.primitives.breakage_primitives import ( |
| parse_breakage_spec, |
| BreakagePrimitive, |
| ) |
| from forgeenv.primitives.repair_primitives import ( |
| BREAKAGE_TO_REPAIR, |
| REPAIR_REGISTRY, |
| RepairPrimitive, |
| ) |
|
|
|
|
# Matches a unified-diff hunk header such as "@@ -1,3 +1,4 @@" at line start.
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Matches a Markdown code fence (optional alphabetic language tag) and
# captures its body.
_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```")

# A line starting with one of these marks the beginning of the diff.
_DIFF_START_MARKERS = ("---", "+++", "@@")
# Prefixes a line inside a unified-diff body may carry: context, addition,
# removal, hunk header, and the "\ No newline at end of file" marker.
_DIFF_BODY_PREFIXES = (" ", "+", "-", "@@", "\\")


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles code fences, leading prose / chain-of-thought, and trailing
    notes:

    * If the response contains fenced blocks, the first one holding a hunk
      header (``@@ ... @@``) is used; when none does, the first fenced
      block is used.
    * Everything before the first line starting with ``---``, ``+++`` or
      ``@@`` is dropped.
    * When the remaining text contains a hunk header, trailing lines that
      cannot belong to a diff body (blank lines, prose) are dropped. A
      trailing note that itself starts with ``+``, ``-`` or a space cannot
      be told apart from diff content and is kept.

    Args:
        raw_text: The raw model response.

    Returns:
        The extracted diff text. If no diff marker line is present, the
        stripped (and unfenced) text is returned unchanged so callers can
        still treat it as a full replacement script. Empty input yields
        ``""``.
    """
    if not raw_text:
        return ""
    text = raw_text.strip()

    fenced_bodies = [m.group(1).strip() for m in _FENCE_RE.finditer(text)]
    if fenced_bodies:
        # Prefer the fence that actually carries a diff: models often show
        # the broken snippet in a first fence and the patch in a later one.
        text = next(
            (body for body in fenced_bodies if _DIFF_HUNK_RE.search(body)),
            fenced_bodies[0],
        )

    lines = text.splitlines()
    start = next(
        (
            idx
            for idx, line in enumerate(lines)
            if line.startswith(_DIFF_START_MARKERS)
        ),
        None,
    )
    if start is None:
        # No diff marker at all: hand back the whole text untouched.
        return "\n".join(lines)

    kept = lines[start:]
    # Only trim when a real hunk exists; the hunk header itself is a valid
    # body line, so the loop below can never empty the list.
    if any(_DIFF_HUNK_RE.match(line) for line in kept):
        while kept and not kept[-1].startswith(_DIFF_BODY_PREFIXES):
            kept.pop()
    return "\n".join(kept)
|
|
|
|
def looks_like_diff(text: str) -> bool:
    """Heuristically decide whether *text* resembles a unified diff.

    A hunk header (``@@ ... @@`` at the start of a line) is mandatory. In
    addition the text must either contain both ``---`` and ``+++``
    somewhere, or have at least one line beginning with ``+`` or ``-``.

    Args:
        text: Candidate diff text.

    Returns:
        ``True`` when the heuristic matches, ``False`` otherwise (including
        for empty input).
    """
    if not text:
        return False
    # Both accepted shapes require a hunk header, so check it first.
    if _DIFF_HUNK_RE.search(text) is None:
        return False
    if "---" in text and "+++" in text:
        return True
    return any(line[:1] in ("+", "-") for line in text.splitlines())
|
|
|
|
| |
@dataclass
class BaselineRepairAgent:
    """Deterministic Repair Agent that uses the primitive inverse map.

    Used for warm-start dataset generation and baseline rollout comparisons.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Produce a unified diff that fixes ``broken_script``.

        The strategies are tried in order:

        1. Oracle: when ``original_script`` is non-empty and differs from
           ``broken_script``, diff the broken script against it.
        2. Inversion: parse ``breakage_spec``, look up the matching repair
           primitive, apply it, and diff the result against the broken
           script.
        3. Give up and return an empty string.

        Args:
            broken_script: Source text of the script to repair.
            breakage_spec: Structured description of the breakage, if known.
            original_script: Ground-truth script, if known.

        Returns:
            The diff produced by ``make_unified_diff``, or ``""`` when no
            strategy yields a change.
        """
        # 1. Oracle path: the ground truth is available and actually differs.
        if original_script and original_script != broken_script:
            return make_unified_diff(broken_script, original_script)

        # 2. Inversion path: needs a non-empty, parseable spec.
        if not breakage_spec:
            return ""
        try:
            parsed = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            return ""
        if parsed is None:
            return ""

        inverse = _invert_breakage(parsed)
        if inverse is None:
            return ""

        candidate = inverse.apply(broken_script)
        if candidate != broken_script:
            return make_unified_diff(broken_script, candidate)

        # 3. The repair primitive changed nothing.
        return ""
|
|
|
|
| _PARAM_REMAP: dict[str, dict[str, str]] = { |
| "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"}, |
| "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"}, |
| "ChangeArgumentSignature": { |
| "function_name": "function_name", |
| "removed_arg": "arg_name", |
| }, |
| "ModifyConfigField": {"field_name": "field_name"}, |
| "RestructureDatasetSchema": { |
| "old_column": "old_column", |
| "new_column": "new_column", |
| }, |
| "ChangeTokenizerBehavior": { |
| "old_kwarg": "old_kwarg", |
| "old_value": "old_value", |
| "new_kwarg": "new_kwarg", |
| "new_value": "new_value", |
| }, |
| "RemoveDeprecatedMethod": {"method_name": "method_name"}, |
| "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"}, |
| } |
|
|
|
|
def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that undoes ``breakage``.

    The breakage's class name is looked up in ``BREAKAGE_TO_REPAIR`` to get
    a repair name, which is resolved to a class through ``REPAIR_REGISTRY``.
    The breakage's parameters are then renamed via ``_PARAM_REMAP`` and
    narrowed to the repair class's init-able dataclass fields before the
    class is instantiated.

    Args:
        breakage: The breakage primitive instance to invert.

    Returns:
        A repair primitive instance, or ``None`` when no repair is
        registered for this breakage or the constructor rejects the
        arguments with ``TypeError``.
    """
    source_name = type(breakage).__name__

    target_name = BREAKAGE_TO_REPAIR.get(source_name)
    if target_name is None:
        return None
    target_cls = REPAIR_REGISTRY.get(target_name)
    if target_cls is None:
        return None

    source_params = breakage._get_params()
    # Only keywords the repair dataclass accepts in __init__ may be passed.
    accepted = {
        spec.name
        for spec in target_cls.__dataclass_fields__.values()
        if spec.init
    }
    kwargs = {
        dst_key: source_params[src_key]
        for src_key, dst_key in _PARAM_REMAP.get(source_name, {}).items()
        if src_key in source_params and dst_key in accepted
    }

    try:
        return target_cls(**kwargs)
    except TypeError:
        # Constructor rejected the arguments (e.g. a required one is missing).
        return None
|
|