File size: 5,070 Bytes
"""Repair Agent helpers: response sanitisation + a deterministic baseline.

The Repair Agent's training output is a unified diff. LLMs frequently emit
prose / fences / chain-of-thought before the diff; this module strips that
preamble. The baseline policy uses the inverse-primitive map from
`repair_primitives.py` to produce ground-truth diffs for warm-start.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
parse_breakage_spec,
BreakagePrimitive,
)
from forgeenv.primitives.repair_primitives import (
BREAKAGE_TO_REPAIR,
REPAIR_REGISTRY,
RepairPrimitive,
)
# Matches a unified-diff hunk header such as "@@ -1,3 +1,4 @@" at line start.
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Matches a Markdown fenced block and captures its body (group 1). The info
# string may contain digits/symbols ("python3", "c++") and be followed by
# trailing whitespace or "\r"; otherwise such an opening fence is skipped and
# the *closing* fence gets paired with the next block's opening fence.
_FENCE_RE = re.compile(r"```[\w+#.-]*[ \t\r]*\n([\s\S]*?)\n```")
# Line prefixes that mark the beginning of a unified diff (headers or a hunk).
_DIFF_MARKERS = ("---", "+++", "@@")


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles code fences, leading prose / chain-of-thought, and notes that
    trail a fenced block.

    Args:
        raw_text: The raw model response. Falsy values yield ``""``.

    Returns:
        The response starting at the first line that begins with ``---``,
        ``+++`` or ``@@``, with lines joined by ``"\\n"``. When the response
        contains fenced blocks, only one block is considered: the first block
        that contains a hunk header, or the first block if none does. If no
        diff marker line is found, the (unfenced, stripped) text is returned
        as-is so that a full replacement script still passes through.
    """
    if not raw_text:
        return ""
    text = raw_text.strip()

    fenced_blocks = [m.group(1).strip() for m in _FENCE_RE.finditer(text)]
    if fenced_blocks:
        # Models often quote the broken code in an earlier fence; prefer the
        # fence that actually holds a diff over blindly taking the first one.
        text = next(
            (block for block in fenced_blocks if _DIFF_HUNK_RE.search(block)),
            fenced_blocks[0],
        )

    lines = text.splitlines()
    for index, line in enumerate(lines):
        if line.startswith(_DIFF_MARKERS):
            return "\n".join(lines[index:])
    # No diff marker anywhere: hand back everything we have.
    return "\n".join(lines)
def looks_like_diff(text: str) -> bool:
    """Heuristically decide whether *text* resembles a unified diff.

    A hunk header (a line starting with ``@@`` that contains a second ``@@``)
    is mandatory. On top of that the text must either contain both ``---``
    and ``+++`` somewhere, or have at least one line that starts with ``+``
    or ``-``. Empty input is never a diff.
    """
    if not text:
        return False
    # Without a hunk header neither accepted combination can hold.
    if _DIFF_HUNK_RE.search(text) is None:
        return False
    if "---" in text and "+++" in text:
        return True
    for row in text.splitlines():
        if row.startswith(("+", "-")):
            return True
    return False
# ---------------------------------------------------------------- baselines
@dataclass
class BaselineRepairAgent:
    """Deterministic Repair Agent built on the primitive inverse map.

    Serves warm-start dataset generation and baseline rollout comparisons.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Produce a unified diff that fixes ``broken_script``.

        The first applicable strategy wins:

        1. Oracle: when ``original_script`` is given and differs from the
           broken script, diff the broken script against it. This is the
           warm-start path, where the ground truth is known.
        2. Inversion: parse ``breakage_spec`` and apply the matching repair
           primitive from the registry, diffing broken vs. repaired text.
        3. Fallback: return an empty string.

        Args:
            broken_script: Source text of the script to repair.
            breakage_spec: Structured description of the breakage, if any.
            original_script: Known-good script, or ``""`` when unavailable.

        Returns:
            The diff text, or ``""`` when no strategy yields a change (spec
            missing or unparsable, no inverse primitive, or the repair left
            the script untouched).
        """
        no_patch = ""

        # Strategy 1: oracle diff against the known-good script.
        have_oracle = bool(original_script) and original_script != broken_script
        if have_oracle:
            return make_unified_diff(broken_script, original_script)

        # Strategy 2: invert the structured breakage description.
        if not breakage_spec:
            return no_patch
        try:
            parsed = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            # An unparsable spec is treated the same as having no spec.
            return no_patch
        if parsed is None:
            return no_patch

        inverse = _invert_breakage(parsed)
        if inverse is None:
            return no_patch

        candidate = inverse.apply(broken_script)
        if candidate == broken_script:
            # The repair was a no-op, so there is nothing to diff.
            return no_patch
        return make_unified_diff(broken_script, candidate)
_PARAM_REMAP: dict[str, dict[str, str]] = {
"RenameApiCall": {"old_name": "old_name", "new_name": "new_name"},
"DeprecateImport": {"old_module": "old_module", "new_module": "new_module"},
"ChangeArgumentSignature": {
"function_name": "function_name",
"removed_arg": "arg_name",
},
"ModifyConfigField": {"field_name": "field_name"},
"RestructureDatasetSchema": {
"old_column": "old_column",
"new_column": "new_column",
},
"ChangeTokenizerBehavior": {
"old_kwarg": "old_kwarg",
"old_value": "old_value",
"new_kwarg": "new_kwarg",
"new_value": "new_value",
},
"RemoveDeprecatedMethod": {"method_name": "method_name"},
"ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"},
}
def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that undoes ``breakage``.

    The breakage's class name is looked up in ``BREAKAGE_TO_REPAIR`` to get a
    repair name, which is resolved to a class via ``REPAIR_REGISTRY``. The
    breakage's parameters are then renamed according to ``_PARAM_REMAP`` and
    restricted to the init fields the repair class declares.

    Returns:
        An instance of the repair class, or ``None`` when no repair is
        registered for this breakage or the constructor rejects the
        arguments with ``TypeError``.
    """
    kind = type(breakage).__name__

    target_name = BREAKAGE_TO_REPAIR.get(kind)
    target_cls = REPAIR_REGISTRY.get(target_name) if target_name is not None else None
    if target_cls is None:
        return None

    source_params = breakage._get_params()  # type: ignore[attr-defined]
    # Only fields that take part in __init__ may be passed as keyword args.
    accepted = {
        spec.name
        for spec in target_cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if spec.init
    }
    # Rename breakage params to repair field names, dropping anything the
    # breakage does not provide or the repair class does not accept.
    kwargs = {
        dst: source_params[src]
        for src, dst in _PARAM_REMAP.get(kind, {}).items()
        if src in source_params and dst in accepted
    }

    try:
        return target_cls(**kwargs)
    except TypeError:
        # E.g. a required field had no counterpart among the mapped params.
        return None
|