File size: 5,070 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Repair Agent helpers: response sanitisation + a deterministic baseline.

The Repair Agent's training output is a unified diff. LLMs frequently emit
prose / fences / chain-of-thought before the diff; this module strips that
preamble. The baseline policy uses the inverse-primitive map from
`repair_primitives.py` to produce ground-truth diffs for warm-start.
"""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Optional

from forgeenv.env.diff_utils import make_unified_diff
from forgeenv.primitives.breakage_primitives import (
    parse_breakage_spec,
    BreakagePrimitive,
)
from forgeenv.primitives.repair_primitives import (
    BREAKAGE_TO_REPAIR,
    REPAIR_REGISTRY,
    RepairPrimitive,
)


# Matches a unified-diff hunk header line such as "@@ -1,3 +1,4 @@". MULTILINE
# makes "^" anchor at the start of every line, not only the start of the text.
_DIFF_HUNK_RE = re.compile(r"^@@.*@@", re.MULTILINE)
# Matches a Markdown code fence with an optional alphabetic language tag and
# captures the fenced body. The body match is non-greedy, so the first closing
# fence (preceded by a newline) ends the capture.
_FENCE_RE = re.compile(r"```[a-zA-Z]*\n([\s\S]*?)\n```")


def extract_diff(raw_text: str) -> str:
    """Pull the unified diff out of an LLM response.

    Handles code fences and leading prose / chain-of-thought. When the
    response contains several fenced blocks (for example the model quotes the
    broken script before giving its patch), the first fenced block that
    contains a diff marker line is used; if none does, the first fenced block
    is used. Trailing notes are only dropped when the diff is fenced.

    Args:
        raw_text: Raw model output; may be empty or ``None``.

    Returns:
        The selected text from its first line starting with ``---``, ``+++``
        or ``@@`` through to its end. If no such line exists, the selected
        text is returned whole (stripped). Empty input yields ``""``.
    """
    if not raw_text:
        return ""

    markers = ("---", "+++", "@@")

    def first_marker_index(lines):
        """Return the index of the first diff-marker line, or None."""
        return next(
            (i for i, line in enumerate(lines) if line.startswith(markers)),
            None,
        )

    text = raw_text.strip()

    # Look at every fenced block, not just the first: models often quote the
    # original code in one fence and emit the actual patch in a later one.
    fenced = [match.group(1).strip() for match in _FENCE_RE.finditer(text)]
    if fenced:
        text = next(
            (
                body
                for body in fenced
                if first_marker_index(body.splitlines()) is not None
            ),
            fenced[0],  # no fence looks like a diff: keep the first one
        )

    lines = text.splitlines()
    # No marker at all -> keep everything (start at line 0).
    start = first_marker_index(lines) or 0
    return "\n".join(lines[start:])


def looks_like_diff(text: str) -> bool:
    """Heuristically decide whether *text* resembles a unified diff.

    A hunk header line (``@@ ... @@``) is mandatory. In addition the text must
    either contain both ``---`` and ``+++`` somewhere, or have at least one
    line beginning with ``+`` or ``-``. Empty input is never a diff.
    """
    if not text:
        return False

    # Without a hunk header neither accepted combination can hold.
    if _DIFF_HUNK_RE.search(text) is None:
        return False

    # File-header markers are checked as plain substrings of the whole text.
    if "---" in text and "+++" in text:
        return True

    for row in text.splitlines():
        if row.startswith(("+", "-")):
            return True
    return False


# ---------------------------------------------------------------- baselines
@dataclass
class BaselineRepairAgent:
    """Rule-based Repair Agent built on the primitive inverse map.

    It has no learned component; it exists to generate warm-start data and to
    serve as the baseline in rollout comparisons.
    """

    def repair(
        self,
        broken_script: str,
        breakage_spec: Optional[dict] = None,
        original_script: str = "",
    ) -> str:
        """Produce a unified diff that repairs ``broken_script``.

        Sources are tried in this order:
          1. Oracle: when ``original_script`` is non-empty and differs from
             the broken script, diff broken -> original.
          2. Inversion: parse ``breakage_spec``, look up the matching repair
             primitive, apply it, and diff broken -> repaired.
          3. Nothing usable: return the empty string.

        The empty string is also returned when the spec cannot be parsed, has
        no inverse, or the inverse leaves the script unchanged.
        """
        # 1. Oracle path: the ground truth is known, so just diff against it.
        oracle_usable = bool(original_script) and original_script != broken_script
        if oracle_usable:
            return make_unified_diff(broken_script, original_script)

        # 2. Inversion path: every failure mode falls through to "".
        if not breakage_spec:
            return ""

        try:
            parsed = parse_breakage_spec(breakage_spec)
        except (ValueError, TypeError):
            return ""
        if parsed is None:
            return ""

        inverse = _invert_breakage(parsed)
        if inverse is None:
            return ""

        candidate = inverse.apply(broken_script)
        if candidate != broken_script:
            return make_unified_diff(broken_script, candidate)

        # 3. The repair was a no-op, so there is nothing to diff.
        return ""


# Translation table consulted by `_invert_breakage`. The outer key is the
# breakage primitive's class name; the inner dict maps a breakage parameter
# name to the keyword-argument name handed to the repair primitive.
# Breakage parameters that are not listed for a class are not forwarded.
# NOTE(review): the right-hand names are presumably the repair dataclasses'
# field names in `repair_primitives.py` — confirm when adding entries.
_PARAM_REMAP: dict[str, dict[str, str]] = {
    "RenameApiCall": {"old_name": "old_name", "new_name": "new_name"},
    "DeprecateImport": {"old_module": "old_module", "new_module": "new_module"},
    "ChangeArgumentSignature": {
        "function_name": "function_name",
        # The only entry whose name differs on the repair side.
        "removed_arg": "arg_name",
    },
    "ModifyConfigField": {"field_name": "field_name"},
    "RestructureDatasetSchema": {
        "old_column": "old_column",
        "new_column": "new_column",
    },
    "ChangeTokenizerBehavior": {
        "old_kwarg": "old_kwarg",
        "old_value": "old_value",
        "new_kwarg": "new_kwarg",
        "new_value": "new_value",
    },
    "RemoveDeprecatedMethod": {"method_name": "method_name"},
    "ChangeReturnType": {"old_access": "old_access", "new_access": "new_access"},
}


def _invert_breakage(breakage: BreakagePrimitive) -> Optional[RepairPrimitive]:
    """Build the repair primitive that undoes ``breakage``, if one is known.

    The breakage's class name selects a repair class through
    ``BREAKAGE_TO_REPAIR`` and ``REPAIR_REGISTRY``. Its parameters are renamed
    through ``_PARAM_REMAP`` and restricted to the repair class's init-able
    dataclass fields before construction.

    Returns:
        The constructed repair primitive, or ``None`` when no repair is
        registered for this breakage or the constructor rejects the
        arguments with ``TypeError``.
    """
    kind = type(breakage).__name__

    inverse_name = BREAKAGE_TO_REPAIR.get(kind)
    inverse_cls = (
        None if inverse_name is None else REPAIR_REGISTRY.get(inverse_name)
    )
    if inverse_cls is None:
        return None

    source_params = breakage._get_params()  # type: ignore[attr-defined]
    key_map = _PARAM_REMAP.get(kind, {})

    # Rename breakage parameters to the repair side's keyword names; keys the
    # breakage does not expose are skipped.
    translated: dict[str, str] = {
        dst: source_params[src]
        for src, dst in key_map.items()
        if src in source_params
    }

    # Keep only keywords the repair dataclass accepts in its __init__.
    accepted = {
        spec.name
        for spec in inverse_cls.__dataclass_fields__.values()  # type: ignore[attr-defined]
        if spec.init
    }
    kwargs = {
        name: value for name, value in translated.items() if name in accepted
    }

    try:
        return inverse_cls(**kwargs)
    except TypeError:
        # Constructor rejected the arguments (e.g. a required one is missing).
        return None