File size: 5,562 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Tests for the role helpers (drift generator + repair agent)."""
import json

from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.primitives.breakage_primitives import RenameApiCall
from forgeenv.roles.drift_generator import (
    BaselineDriftGenerator,
    parse_drift_output,
    parse_drift_to_primitive,
)
from forgeenv.roles.prompts import (
    DRIFT_GENERATOR_SYSTEM_PROMPT,
    REPAIR_AGENT_SYSTEM_PROMPT,
    render_drift_generator_prompt,
    render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import (
    BaselineRepairAgent,
    extract_diff,
    looks_like_diff,
)


# ------------------------------------------------------------------- prompts
def test_prompts_are_nonempty():
    """Each role's system prompt should mention that role's title."""
    for role_title, system_prompt in (
        ("Drift Generator", DRIFT_GENERATOR_SYSTEM_PROMPT),
        ("Repair Agent", REPAIR_AGENT_SYSTEM_PROMPT),
    ):
        assert role_title in system_prompt


def test_render_drift_generator_prompt_includes_inputs():
    """Every input handed to the drift prompt renderer should appear in its output."""
    rendered = render_drift_generator_prompt(
        "import torch", "RenameApiCall", {"transformers": "4.40.0"}
    )
    for expected_fragment in ("RenameApiCall", "transformers=4.40.0", "import torch"):
        assert expected_fragment in rendered, expected_fragment


def test_render_repair_agent_prompt_includes_error_trace():
    """The repair prompt should contain the error text and the versions as name=version."""
    versions = {"transformers": "4.50.0"}
    rendered = render_repair_agent_prompt("broken", "AttributeError: foo", versions)
    assert "AttributeError" in rendered
    assert "transformers=4.50.0" in rendered


# ------------------------------------------------------------ drift generator
def test_parse_drift_output_handles_fences():
    """A JSON object wrapped in a ```json markdown fence is still parsed."""
    fenced = '```json\n{"primitive_type": "RenameApiCall", "params": {"old_name": "a", "new_name": "b"}}\n```'
    result = parse_drift_output(fenced)
    assert result is not None
    assert result["primitive_type"] == "RenameApiCall"


def test_parse_drift_output_handles_prose():
    """A JSON object embedded between lines of prose is still extracted.

    ``parse_drift_output`` returns ``None`` when it finds no JSON, so check for
    that first: otherwise a parse failure shows up as a ``TypeError`` from
    subscripting ``None`` instead of a readable assertion failure.
    """
    text = (
        "Here is my breakage idea, it's a rename:\n"
        "{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"x\", \"new_name\": \"y\"}}\n"
        "Hope this works!"
    )
    parsed = parse_drift_output(text)
    assert parsed is not None
    assert parsed["primitive_type"] == "RenameApiCall"


def test_parse_drift_output_returns_none_on_garbage():
    """Inputs that contain no JSON object (including empty text) yield ``None``."""
    for junk in ("no JSON here at all", ""):
        assert parse_drift_output(junk) is None


def test_parse_drift_to_primitive_validates():
    """A well-formed DeprecateImport spec becomes a primitive carrying that name."""
    spec_text = '{"primitive_type": "DeprecateImport", "params": {"old_module": "a", "new_module": "b"}}'
    result = parse_drift_to_primitive(spec_text)
    assert result is not None
    assert result.name == "DeprecateImport"


def test_parse_drift_to_primitive_unknown_type():
    """A spec naming an unknown primitive type is rejected with ``None``."""
    unknown_spec = '{"primitive_type": "NonExistent", "params": {}}'
    result = parse_drift_to_primitive(unknown_spec)
    assert result is None


def test_baseline_drift_generator_produces_valid_spec():
    """The seeded baseline generator proposes a known primitive type with dict params."""
    known_types = {
        "RenameApiCall",
        "DeprecateImport",
        "ChangeArgumentSignature",
        "ModifyConfigField",
        "RestructureDatasetSchema",
        "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    }
    generator = BaselineDriftGenerator(seed=0)
    script = (
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )
    proposal = generator.propose(target_category="RenameApiCall", script=script)
    assert proposal["primitive_type"] in known_types
    assert isinstance(proposal["params"], dict)


def test_baseline_drift_generator_spec_actually_breaks_script():
    """A seeded baseline spec must round-trip (via JSON) into an applicable primitive.

    When the spec is a rename whose ``old_name`` occurs in the script, applying
    it must actually change the script.
    """
    gen = BaselineDriftGenerator(seed=42)
    script = """from transformers import Trainer
trainer = Trainer()
trainer.train()
"""
    spec = gen.propose(target_category="RenameApiCall", script=script)
    primitive = parse_drift_to_primitive(json.dumps(spec))
    # Report a parse failure clearly instead of as AttributeError on None.apply.
    assert primitive is not None, f"spec did not parse into a primitive: {spec!r}"
    broken = primitive.apply(script)
    if spec["primitive_type"] == "RenameApiCall":
        old_name = spec["params"].get("old_name")
        # `None in script` would raise TypeError, so require a string first.
        if isinstance(old_name, str) and old_name in script:
            assert broken != script


# -------------------------------------------------------------- repair agent
def test_extract_diff_strips_fences():
    """A diff inside a ```diff fence comes back without the preamble or the fence markers."""
    text = "Here's my fix:\n```diff\n--- a/x\n+++ b/x\n@@\n-foo\n+bar\n```\n"
    diff = extract_diff(text)
    # Separate asserts so a failure pinpoints which expectation broke.
    assert diff.startswith("---")
    assert "foo" in diff
    assert "bar" in diff
    # The behavior named by this test: no fence markers survive extraction.
    assert "```" not in diff


def test_extract_diff_strips_chain_of_thought():
    """Reasoning text preceding an unfenced diff is dropped from the extraction."""
    reasoning = "Let me think... the error is X, so I should rename Y to Z.\n"
    lead_in = "Here is the diff:\n"
    patch = "--- a/train.py\n+++ b/train.py\n@@ -1 +1 @@\n-import torch\n+import torch.legacy\n"
    extracted = extract_diff(reasoning + lead_in + patch)
    assert extracted.startswith("---")
    assert "Let me think" not in extracted


def test_looks_like_diff_positive():
    """A minimal unified diff with file headers and one hunk is recognized."""
    assert looks_like_diff("--- a/x\n+++ b/x\n@@ -1 +1 @@\n-foo\n+bar\n")


def test_looks_like_diff_negative():
    """Plain prose with no diff markers is not mistaken for a diff."""
    prose = "just some prose without any diff structure"
    assert not looks_like_diff(prose)


def test_baseline_repair_agent_oracle_path():
    """Given the original script, the agent's diff restores it exactly."""
    original = "import torch\nprint('hi')\n"
    broken = "import torch.legacy\nprint('hi')\n"
    agent = BaselineRepairAgent()
    patch = agent.repair(broken, breakage_spec=None, original_script=original)
    assert patch
    assert "torch.legacy" in patch
    assert apply_unified_diff(broken, patch) == original


def test_baseline_repair_agent_inverts_breakage_spec():
    """Given only the breakage spec (no original script), the agent undoes the rename."""
    agent = BaselineRepairAgent()
    original = "from transformers import Trainer\ntrainer.train()\n"
    breakage = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training")
    broken = breakage.apply(original)
    # Precondition: if the breakage changed nothing, the repair checks below prove nothing.
    assert broken != original
    spec = breakage.to_spec()

    diff = agent.repair(broken, breakage_spec=spec)
    assert diff
    repaired = apply_unified_diff(broken, diff)
    assert "trainer.train()" in repaired
    # The inverse rename must also remove the injected identifier.
    assert "trainer.start_training" not in repaired