# forgeenv-source / tests/test_roles.py
# akhiilll's picture
# forgeenv source snapshot for training job
# a15535e verified
"""Tests for the role helpers (drift generator + repair agent)."""
import json
from forgeenv.env.diff_utils import apply_unified_diff
from forgeenv.primitives.breakage_primitives import RenameApiCall
from forgeenv.roles.drift_generator import (
BaselineDriftGenerator,
parse_drift_output,
parse_drift_to_primitive,
)
from forgeenv.roles.prompts import (
DRIFT_GENERATOR_SYSTEM_PROMPT,
REPAIR_AGENT_SYSTEM_PROMPT,
render_drift_generator_prompt,
render_repair_agent_prompt,
)
from forgeenv.roles.repair_agent import (
BaselineRepairAgent,
extract_diff,
looks_like_diff,
)
# ------------------------------------------------------------------- prompts
def test_prompts_are_nonempty():
    """Check that each system prompt names the role it is written for."""
    role_markers = (
        ("Drift Generator", DRIFT_GENERATOR_SYSTEM_PROMPT),
        ("Repair Agent", REPAIR_AGENT_SYSTEM_PROMPT),
    )
    for marker, prompt in role_markers:
        assert marker in prompt
def test_render_drift_generator_prompt_includes_inputs():
    """Check that the rendered drift-generator prompt echoes every input.

    The script text, the target primitive name and the library version
    (formatted as ``name=version``) must all appear in the output.
    """
    rendered = render_drift_generator_prompt(
        "import torch", "RenameApiCall", {"transformers": "4.40.0"}
    )
    # One assertion per fragment so a failure names the missing piece.
    for fragment in ("RenameApiCall", "transformers=4.40.0", "import torch"):
        assert fragment in rendered
def test_render_repair_agent_prompt_includes_error_trace():
    """Check that the repair-agent prompt carries the error trace and versions."""
    rendered = render_repair_agent_prompt(
        "broken", "AttributeError: foo", {"transformers": "4.50.0"}
    )
    # Checked separately so a failure shows which fragment is missing.
    assert "AttributeError" in rendered
    assert "transformers=4.50.0" in rendered
# ------------------------------------------------------------ drift generator
def test_parse_drift_output_handles_fences():
    """Check that a spec wrapped in a Markdown ``json`` code fence is parsed."""
    fenced = "```json\n{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"a\", \"new_name\": \"b\"}}\n```"
    result = parse_drift_output(fenced)
    # Guard first: a None result must fail here, not on the subscript below.
    assert result is not None
    assert result["primitive_type"] == "RenameApiCall"
def test_parse_drift_output_handles_prose():
    """Check that a JSON spec embedded in surrounding prose is still parsed.

    The model output has a sentence before and after the JSON object; the
    parser is expected to return the object itself.
    """
    text = (
        "Here is my breakage idea, it's a rename:\n"
        "{\"primitive_type\": \"RenameApiCall\", \"params\": {\"old_name\": \"x\", \"new_name\": \"y\"}}\n"
        "Hope this works!"
    )
    parsed = parse_drift_output(text)
    # parse_drift_output returns None when it finds no JSON; assert that
    # explicitly so a parse failure is reported as a failed assertion
    # instead of a TypeError from subscripting None.
    assert parsed is not None
    assert parsed["primitive_type"] == "RenameApiCall"
def test_parse_drift_output_returns_none_on_garbage():
    """Check that text without JSON, and the empty string, both yield None."""
    for junk in ("no JSON here at all", ""):
        assert parse_drift_output(junk) is None
def test_parse_drift_to_primitive_validates():
    """Check that a well-formed DeprecateImport spec becomes a primitive."""
    raw_spec = '{"primitive_type": "DeprecateImport", "params": {"old_module": "a", "new_module": "b"}}'
    built = parse_drift_to_primitive(raw_spec)
    # Separate the None check from the attribute access for clearer failures.
    assert built is not None
    assert built.name == "DeprecateImport"
def test_parse_drift_to_primitive_unknown_type():
    """Check that a spec naming an unknown primitive type is rejected as None."""
    unknown_spec = '{"primitive_type": "NonExistent", "params": {}}'
    result = parse_drift_to_primitive(unknown_spec)
    assert result is None
def test_baseline_drift_generator_produces_valid_spec():
    """Check that the baseline generator proposes a well-shaped spec.

    The proposal must name one of the eight primitive types listed below
    and carry its parameters as a dict.
    """
    generator = BaselineDriftGenerator(seed=0)
    training_script = (
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )
    accepted_types = {
        "RenameApiCall",
        "DeprecateImport",
        "ChangeArgumentSignature",
        "ModifyConfigField",
        "RestructureDatasetSchema",
        "ChangeTokenizerBehavior",
        "RemoveDeprecatedMethod",
        "ChangeReturnType",
    }
    proposal = generator.propose(
        target_category="RenameApiCall", script=training_script
    )
    assert proposal["primitive_type"] in accepted_types
    assert isinstance(proposal["params"], dict)
def test_baseline_drift_generator_spec_actually_breaks_script():
    """Check that a proposed rename really modifies the script it targets.

    The baseline generator's spec is round-tripped through JSON into a
    primitive and applied to the script. The "script changed" assertion is
    made only when the spec is a ``RenameApiCall`` whose ``old_name`` occurs
    in the script; for any other proposal the test checks only that the
    spec parses and applies without raising.
    """
    gen = BaselineDriftGenerator(seed=42)
    script = (
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )
    spec = gen.propose(target_category="RenameApiCall", script=script)
    primitive = parse_drift_to_primitive(json.dumps(spec))
    # parse_drift_to_primitive returns None for specs it rejects; fail with a
    # readable message instead of an AttributeError on `.apply`.
    assert primitive is not None, f"baseline spec was rejected: {spec!r}"
    broken = primitive.apply(script)
    # `.get` may return None, and `None in script` raises TypeError, so the
    # missing-key case is excluded before the membership test.
    old_name = spec["params"].get("old_name")
    if (
        spec["primitive_type"] == "RenameApiCall"
        and old_name is not None
        and old_name in script
    ):
        assert broken != script
# -------------------------------------------------------------- repair agent
def test_extract_diff_strips_fences():
    """Check that a diff inside a Markdown ``diff`` fence is extracted."""
    reply = "Here's my fix:\n```diff\n--- a/x\n+++ b/x\n@@\n-foo\n+bar\n```\n"
    extracted = extract_diff(reply)
    # The result must begin at the diff header and keep both changed lines.
    assert extracted.startswith("---")
    for token in ("foo", "bar"):
        assert token in extracted
def test_extract_diff_strips_chain_of_thought():
    """Check that reasoning text preceding the diff is dropped."""
    reasoning = (
        "Let me think... the error is X, so I should rename Y to Z.\n"
        "Here is the diff:\n"
    )
    patch_body = (
        "--- a/train.py\n+++ b/train.py\n@@ -1 +1 @@\n-import torch\n+import torch.legacy\n"
    )
    extracted = extract_diff(reasoning + patch_body)
    assert extracted.startswith("---")
    assert "Let me think" not in extracted
def test_looks_like_diff_positive():
    """Check that a minimal unified diff is recognised as a diff."""
    sample = "--- a/x\n+++ b/x\n@@ -1 +1 @@\n-foo\n+bar\n"
    verdict = looks_like_diff(sample)
    assert verdict
def test_looks_like_diff_negative():
    """Check that plain prose is not mistaken for a diff."""
    prose = "just some prose without any diff structure"
    verdict = looks_like_diff(prose)
    assert not verdict
def test_baseline_repair_agent_oracle_path():
    """Check the repair path where the agent is given the original script.

    With no breakage spec and the pristine script supplied, the returned
    diff must mention the broken import and, once applied to the broken
    script, reproduce the pristine script exactly.
    """
    repairer = BaselineRepairAgent()
    pristine = "import torch\nprint('hi')\n"
    damaged = "import torch.legacy\nprint('hi')\n"
    patch = repairer.repair(damaged, breakage_spec=None, original_script=pristine)
    # A falsy patch (None or empty) must fail before the substring check.
    assert patch
    assert "torch.legacy" in patch
    assert apply_unified_diff(damaged, patch) == pristine
def test_baseline_repair_agent_inverts_breakage_spec():
    """Check the repair path where the agent is given only the breakage spec.

    A ``RenameApiCall`` breakage is applied to a script, and the agent is
    handed the broken script plus the breakage's spec. Applying the
    returned diff must bring back the original ``trainer.train()`` call.
    """
    repairer = BaselineRepairAgent()
    pristine = "from transformers import Trainer\ntrainer.train()\n"
    rename = RenameApiCall(
        old_name="trainer.train", new_name="trainer.start_training"
    )
    damaged = rename.apply(pristine)
    patch = repairer.repair(damaged, breakage_spec=rename.to_spec())
    assert patch
    restored = apply_unified_diff(damaged, patch)
    assert "trainer.train()" in restored