"""Tests for the role helpers (drift generator + repair agent)."""

import json

from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.primitives.breakage_primitives import RenameApiCall
from forgeenv.roles.drift_generator import (
    BaselineDriftGenerator,
    parse_drift_output,
    parse_drift_to_primitive,
)
from forgeenv.roles.prompts import (
    DRIFT_GENERATOR_SYSTEM_PROMPT,
    REPAIR_AGENT_SYSTEM_PROMPT,
    render_drift_generator_prompt,
    render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import (
    BaselineRepairAgent,
    extract_diff,
    looks_like_diff,
)

# Minimal Trainer script shared by the baseline drift-generator tests.
_TRAINER_SCRIPT = """from transformers import Trainer
trainer = Trainer()
trainer.train()
"""

# Every primitive type the baseline drift generator is allowed to emit.
_KNOWN_PRIMITIVE_TYPES = frozenset(
    {
        "RenameApiCall",
        "DeprecateImport",
        "ChangeArgumentSignature",
        "ModifyConfigField",
        "RestructureDatasetSchema",
        "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    }
)


# ------------------------------------------------------------------- prompts


def test_prompts_are_nonempty():
    """Both system prompts identify the role they are written for."""
    assert "Drift Generator" in DRIFT_GENERATOR_SYSTEM_PROMPT
    assert "Repair Agent" in REPAIR_AGENT_SYSTEM_PROMPT


def test_render_drift_generator_prompt_includes_inputs():
    """The rendered drift prompt embeds the script, category and versions."""
    text = render_drift_generator_prompt(
        "import torch", "RenameApiCall", {"transformers": "4.40.0"}
    )
    # Separate asserts so a failure pinpoints the missing input.
    assert "RenameApiCall" in text
    assert "transformers=4.40.0" in text
    assert "import torch" in text


def test_render_repair_agent_prompt_includes_error_trace():
    """The rendered repair prompt embeds the error trace and versions."""
    text = render_repair_agent_prompt(
        "broken", "AttributeError: foo", {"transformers": "4.50.0"}
    )
    assert "AttributeError" in text
    assert "transformers=4.50.0" in text


# ------------------------------------------------------------ drift generator


def test_parse_drift_output_handles_fences():
    """JSON wrapped in a ```json fence is still parsed."""
    text = (
        "```json\n"
        '{"primitive_type": "RenameApiCall", '
        '"params": {"old_name": "a", "new_name": "b"}}\n'
        "```"
    )
    parsed = parse_drift_output(text)
    assert parsed is not None
    assert parsed["primitive_type"] == "RenameApiCall"


def test_parse_drift_output_handles_prose():
    """JSON surrounded by free-form prose is still parsed."""
    text = (
        "Here is my breakage idea, it's a rename:\n"
        '{"primitive_type": "RenameApiCall", '
        '"params": {"old_name": "x", "new_name": "y"}}\n'
        "Hope this works!"
    )
    parsed = parse_drift_output(text)
    # Guard first so a parse failure reports clearly instead of a TypeError.
    assert parsed is not None
    assert parsed["primitive_type"] == "RenameApiCall"


def test_parse_drift_output_returns_none_on_garbage():
    """Input with no JSON object (including empty input) yields None."""
    assert parse_drift_output("no JSON here at all") is None
    assert parse_drift_output("") is None


def test_parse_drift_to_primitive_validates():
    """A well-formed spec is turned into the named primitive."""
    text = (
        '{"primitive_type": "DeprecateImport", '
        '"params": {"old_module": "a", "new_module": "b"}}'
    )
    primitive = parse_drift_to_primitive(text)
    assert primitive is not None
    assert primitive.name == "DeprecateImport"


def test_parse_drift_to_primitive_unknown_type():
    """An unknown primitive type is rejected with None."""
    text = '{"primitive_type": "NonExistent", "params": {}}'
    assert parse_drift_to_primitive(text) is None


def test_baseline_drift_generator_produces_valid_spec():
    """The baseline generator emits a known primitive type with dict params."""
    gen = BaselineDriftGenerator(seed=0)
    spec = gen.propose(target_category="RenameApiCall", script=_TRAINER_SCRIPT)
    assert spec["primitive_type"] in _KNOWN_PRIMITIVE_TYPES
    assert isinstance(spec["params"], dict)


def test_baseline_drift_generator_spec_actually_breaks_script():
    """A baseline spec round-trips through the parser and alters the script."""
    gen = BaselineDriftGenerator(seed=42)
    spec = gen.propose(target_category="RenameApiCall", script=_TRAINER_SCRIPT)
    primitive = parse_drift_to_primitive(json.dumps(spec))
    # The baseline's own output must be accepted by the parser; without this
    # guard a rejection would surface as an AttributeError on `.apply`.
    assert primitive is not None
    broken = primitive.apply(_TRAINER_SCRIPT)
    # If we got a 'RenameApiCall' on a name present in the script, it must
    # have changed something. Index the key directly: a RenameApiCall spec
    # without `old_name` is malformed and should fail with a clear KeyError.
    if spec["primitive_type"] == "RenameApiCall":
        old_name = spec["params"]["old_name"]
        if old_name in _TRAINER_SCRIPT:
            assert broken != _TRAINER_SCRIPT


# -------------------------------------------------------------- repair agent


def test_extract_diff_strips_fences():
    """A diff inside a ```diff fence is extracted without the fence markers."""
    text = "Here's my fix:\n```diff\n--- a/x\n+++ b/x\n@@\n-foo\n+bar\n```\n"
    diff = extract_diff(text)
    assert diff.startswith("---")
    assert "foo" in diff
    assert "bar" in diff
    assert "```" not in diff


def test_extract_diff_strips_chain_of_thought():
    """Reasoning text preceding the diff is dropped."""
    text = (
        "Let me think... the error is X, so I should rename Y to Z.\n"
        "Here is the diff:\n"
        "--- a/train.py\n+++ b/train.py\n@@ -1 +1 @@\n"
        "-import torch\n+import torch.legacy\n"
    )
    diff = extract_diff(text)
    assert diff.startswith("---")
    assert "Let me think" not in diff


def test_looks_like_diff_positive():
    """A minimal unified diff is recognised."""
    diff = "--- a/x\n+++ b/x\n@@ -1 +1 @@\n-foo\n+bar\n"
    assert looks_like_diff(diff)


def test_looks_like_diff_negative():
    """Plain prose is not mistaken for a diff."""
    assert not looks_like_diff("just some prose without any diff structure")


def test_baseline_repair_agent_oracle_path():
    """With the original script available, the diff restores it exactly."""
    agent = BaselineRepairAgent()
    original = "import torch\nprint('hi')\n"
    broken = "import torch.legacy\nprint('hi')\n"
    diff = agent.repair(broken, breakage_spec=None, original_script=original)
    assert diff
    assert "torch.legacy" in diff
    repaired = apply_unified_diff(broken, diff)
    assert repaired == original


def test_baseline_repair_agent_inverts_breakage_spec():
    """Given only the breakage spec, the agent undoes a RenameApiCall."""
    agent = BaselineRepairAgent()
    original = "from transformers import Trainer\ntrainer.train()\n"
    breakage = RenameApiCall(
        old_name="trainer.train", new_name="trainer.start_training"
    )
    broken = breakage.apply(original)
    # Precondition: the breakage must actually change the script, otherwise
    # the repair assertions below would pass vacuously.
    assert broken != original
    spec = breakage.to_spec()
    diff = agent.repair(broken, breakage_spec=spec)
    assert diff
    repaired = apply_unified_diff(broken, diff)
    assert "trainer.train()" in repaired
    assert "trainer.start_training" not in repaired