# File size: 5,562 Bytes (revision a15535e)
"""Tests for the role helpers (drift generator + repair agent)."""
import json
from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.primitives.breakage_primitives import RenameApiCall
from forgeenv.roles.drift_generator import (
BaselineDriftGenerator,
parse_drift_output,
parse_drift_to_primitive,
)
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import (
BaselineRepairAgent,
extract_diff,
looks_like_diff,
)
# ------------------------------------------------------------------- prompts
def test_prompts_are_nonempty():
    """Both system prompts exist and mention the role they configure."""
    expectations = (
        ("Drift Generator", DRIFT_GENERATOR_SYSTEM_PROMPT),
        ("Repair Agent", REPAIR_AGENT_SYSTEM_PROMPT),
    )
    for role_name, prompt in expectations:
        assert role_name in prompt
def test_render_drift_generator_prompt_includes_inputs():
    """The rendered drift prompt embeds the script, the category and the version mapping."""
    rendered = render_drift_generator_prompt(
        "import torch", "RenameApiCall", {"transformers": "4.40.0"}
    )
    # One assert per fragment so a failure points at the missing piece.
    assert "RenameApiCall" in rendered
    assert "transformers=4.40.0" in rendered
    assert "import torch" in rendered
def test_render_repair_agent_prompt_includes_error_trace():
    """The rendered repair prompt carries the error trace and the version mapping."""
    versions = {"transformers": "4.50.0"}
    rendered = render_repair_agent_prompt("broken", "AttributeError: foo", versions)
    assert "AttributeError" in rendered
    assert "transformers=4.50.0" in rendered
# ------------------------------------------------------------ drift generator
def test_parse_drift_output_handles_fences():
    """A JSON object wrapped in a json code fence is still parsed."""
    fenced = "```json\n{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"a\", \"new_name\": \"b\"}}\n```"
    result = parse_drift_output(fenced)
    assert result is not None
    assert result["primitive_type"] == "RenameApiCall"
def test_parse_drift_output_handles_prose():
    """A JSON object surrounded by free-form prose is still extracted."""
    text = (
        "Here is my breakage idea, it's a rename:\n"
        "{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"x\", \"new_name\": \"y\"}}\n"
        "Hope this works!"
    )
    parsed = parse_drift_output(text)
    # Check for None first: subscripting a None result would raise TypeError
    # instead of reporting a clean assertion failure.
    assert parsed is not None
    assert parsed["primitive_type"] == "RenameApiCall"
def test_parse_drift_output_returns_none_on_garbage():
    """Input that contains no JSON object yields None."""
    for junk in ("no JSON here at all", ""):
        assert parse_drift_output(junk) is None
def test_parse_drift_to_primitive_validates():
    """A well-formed spec for a known primitive type yields an object with that name."""
    payload = '{"primitive_type": "DeprecateImport", "params": {"old_module": "a", "new_module": "b"}}'
    result = parse_drift_to_primitive(payload)
    assert result is not None
    assert result.name == "DeprecateImport"
def test_parse_drift_to_primitive_unknown_type():
    """An unrecognised primitive_type is rejected with None."""
    payload = '{"primitive_type": "NonExistent", "params": {}}'
    result = parse_drift_to_primitive(payload)
    assert result is None
def test_baseline_drift_generator_produces_valid_spec():
    """The seeded baseline generator emits a spec naming a known primitive."""
    known_primitives = frozenset({
        "RenameApiCall", "DeprecateImport", "ChangeArgumentSignature",
        "ModifyConfigField", "RestructureDatasetSchema", "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod", "ChangeReturnType",
    })
    generator = BaselineDriftGenerator(seed=0)
    script = (
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )
    spec = generator.propose(target_category="RenameApiCall", script=script)
    assert spec["primitive_type"] in known_primitives
    assert isinstance(spec["params"], dict)
def test_baseline_drift_generator_spec_actually_breaks_script():
    """A spec from the seeded baseline generator parses into an applicable primitive.

    When the generator picks a ``RenameApiCall`` whose ``old_name`` occurs in
    the script, applying it must actually alter the script.
    """
    gen = BaselineDriftGenerator(seed=42)
    script = """from transformers import Trainer
trainer = Trainer()
trainer.train()
"""
    spec = gen.propose(target_category="RenameApiCall", script=script)
    primitive = parse_drift_to_primitive(json.dumps(spec))
    # The generator's own output must survive its parser; otherwise the
    # ``.apply`` call below would fail with an opaque AttributeError on None.
    assert primitive is not None, f"generator produced an unparseable spec: {spec!r}"
    broken = primitive.apply(script)
    if spec["primitive_type"] == "RenameApiCall":
        old_name = spec["params"].get("old_name")
        # ``None in script`` raises TypeError, so require a real, non-empty
        # name that occurs in the script before demanding a visible change.
        if isinstance(old_name, str) and old_name and old_name in script:
            assert broken != script
# -------------------------------------------------------------- repair agent
def test_extract_diff_strips_fences():
    """A diff wrapped in a diff code fence, preceded by prose, is unwrapped."""
    fenced = "Here's my fix:\n```diff\n--- a/x\n+++ b/x\n@@\n-foo\n+bar\n```\n"
    extracted = extract_diff(fenced)
    assert extracted.startswith("---")
    for fragment in ("foo", "bar"):
        assert fragment in extracted
def test_extract_diff_strips_chain_of_thought():
    """Reasoning text preceding an unfenced diff is dropped from the result."""
    response = "".join([
        "Let me think... the error is X, so I should rename Y to Z.\n",
        "Here is the diff:\n",
        "--- a/train.py\n+++ b/train.py\n@@ -1 +1 @@\n-import torch\n+import torch.legacy\n",
    ])
    extracted = extract_diff(response)
    assert extracted.startswith("---")
    assert "Let me think" not in extracted
def test_looks_like_diff_positive():
    """A minimal unified diff (headers plus one hunk) is recognised."""
    sample = "--- a/x\n+++ b/x\n@@ -1 +1 @@\n-foo\n+bar\n"
    assert looks_like_diff(sample)
def test_looks_like_diff_negative():
    """Plain prose with no diff headers or hunks is rejected."""
    prose = "just some prose without any diff structure"
    assert not looks_like_diff(prose)
def test_baseline_repair_agent_oracle_path():
    """Given the original script, the agent's diff restores it exactly."""
    original = "import torch\nprint('hi')\n"
    broken = "import torch.legacy\nprint('hi')\n"
    agent = BaselineRepairAgent()
    patch = agent.repair(broken, breakage_spec=None, original_script=original)
    assert patch
    assert "torch.legacy" in patch
    assert apply_unified_diff(broken, patch) == original
def test_baseline_repair_agent_inverts_breakage_spec():
    """Given only the breakage spec, the agent's diff undoes the rename."""
    agent = BaselineRepairAgent()
    original = "from transformers import Trainer\ntrainer.train()\n"
    breakage = RenameApiCall(old_name="trainer.train", new_name="trainer.start_training")
    broken = breakage.apply(original)
    # Precondition: if the rename were a no-op, the final
    # ``"trainer.train()" in repaired`` check would pass vacuously.
    assert broken != original
    spec = breakage.to_spec()
    diff = agent.repair(broken, breakage_spec=spec)
    assert diff
    repaired = apply_unified_diff(broken, diff)
    assert "trainer.train()" in repaired
|