| """End-to-end tests for the OpenEnv-wrapped ForgeEnvironment."""
|
| import pytest
|
|
|
| from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction
|
| from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff
|
| from forgeenv.env.forge_environment import ForgeEnvironment
|
| from forgeenv.env.observations import ForgeObservation
|
| from forgeenv.roles.teacher import Teacher
|
|
|
|
|
def test_reset_returns_drift_gen_observation():
    """Check the observation handed back by ``reset()`` on a seeded environment.

    The observation must be a ``ForgeObservation`` in the ``drift_gen`` phase,
    not done, with a non-empty script, a target category that is a key of the
    primitive registry, and non-empty library versions.
    """
    forge_env = ForgeEnvironment(seed=0)
    first_obs = forge_env.reset()

    # Type, phase and termination flag of the initial observation.
    assert isinstance(first_obs, ForgeObservation)
    assert first_obs.current_phase == "drift_gen"
    assert first_obs.done is False

    # The observation must carry a (truthy) script to work on.
    assert first_obs.script_content

    # Function-scope import, performed right before its only use.
    from forgeenv.primitives.breakage_primitives import PRIMITIVE_REGISTRY

    assert first_obs.target_category in PRIMITIVE_REGISTRY
    assert first_obs.library_versions
|
|
|
|
|
def test_full_episode_lifecycle():
    """Drive one episode: reset, a breakage step, then a repair step.

    Verifies the phase reported after each step (``repair``, then ``done``),
    the breakage metadata and error trace on the intermediate observation,
    and the reward plus both breakdown dicts on the final observation.
    """
    forge_env = ForgeEnvironment(seed=0)
    start_obs = forge_env.reset()
    pristine_script = start_obs.script_content

    # Step 1: submit a breakage action.
    break_payload = BreakageAction(
        primitive_type="DeprecateImport",
        params={"old_module": "import torch", "new_module": "import torch.legacy"},
    )
    after_break = forge_env.step(ForgeAction(breakage=break_payload))

    assert after_break.current_phase == "repair"
    assert after_break.done is False
    assert after_break.info.get("breakage_primitive") == "DeprecateImport"
    assert after_break.error_trace is not None

    # Step 2: the script captured at reset is sent as the repair payload.
    fix_payload = RepairAction(unified_diff=pristine_script)
    after_repair = forge_env.step(ForgeAction(repair=fix_payload))

    assert after_repair.current_phase == "done"
    assert after_repair.done is True
    assert after_repair.reward is not None
    assert isinstance(after_repair.reward_breakdown, dict)
    assert isinstance(after_repair.held_out_breakdown, dict)

    # Both held-out keys must be present in the breakdown.
    required_keys = {"executed_cleanly", "checkpoint_valid"}
    assert required_keys <= set(after_repair.held_out_breakdown)
|
|
|
|
|
def test_invalid_action_for_phase():
    """A repair action sent as the first step after reset ends the episode.

    The returned observation must be done and expose a non-``None`` ``error``
    entry in ``info``.
    """
    forge_env = ForgeEnvironment(seed=0)
    forge_env.reset()

    # Repair is submitted before any breakage step has been taken.
    premature_fix = RepairAction(unified_diff="print('hi')")
    result = forge_env.step(ForgeAction(repair=premature_fix))

    assert result.done is True
    assert result.info.get("error") is not None
|
|
|
|
|
def test_step_before_reset_returns_error():
    """Stepping an environment that was never reset yields a done, errored observation."""
    unreset_env = ForgeEnvironment(seed=0)

    # Note: reset() is deliberately not called before stepping.
    rename_payload = BreakageAction(primitive_type="RenameApiCall", params={})
    result = unreset_env.step(ForgeAction(breakage=rename_payload))

    assert result.done is True
    # The error entry must be present and truthy.
    assert result.info.get("error")
|
|
|
|
|
def test_state_property_is_dict():
    """``env.state`` is a dict containing the phase, library versions and teacher keys."""
    forge_env = ForgeEnvironment(seed=0)
    forge_env.reset()

    snapshot = forge_env.state
    assert isinstance(snapshot, dict)

    # Each expected top-level key must be present.
    for expected_key in ("phase", "library_versions", "teacher"):
        assert expected_key in snapshot
|
|
|
|
|
def test_action_validation_rejects_both_or_neither():
    """``ForgeAction`` must reject having neither or both of its payloads.

    Constructing it with no arguments, or with ``breakage`` and ``repair``
    supplied together, is expected to raise.
    """
    # Build the payloads outside the ``raises`` blocks: an exception raised
    # while constructing BreakageAction/RepairAction must fail this test
    # rather than be mistaken for the ForgeAction rejection under test.
    breakage_payload = BreakageAction(primitive_type="RenameApiCall", params={})
    repair_payload = RepairAction(unified_diff="x")

    # Neither payload supplied.
    # NOTE(review): the concrete exception type raised by ForgeAction is not
    # visible from this file; narrow ``Exception`` once it is confirmed.
    with pytest.raises(Exception):
        ForgeAction()

    # Both payloads supplied at once.
    with pytest.raises(Exception):
        ForgeAction(breakage=breakage_payload, repair=repair_payload)
|
|
|
|
|
def test_teacher_updates_after_episode():
    """After a breakage step and a repair step, the teacher reports an attempt.

    The environment is built around an explicit ``Teacher``; once both steps
    have run, at least one entry of ``teacher.get_state()`` must have
    ``attempts >= 1``.
    """
    tracked_teacher = Teacher(categories=["RenameApiCall"])
    forge_env = ForgeEnvironment(teacher=tracked_teacher, seed=0)
    forge_env.reset()

    # Breakage step.
    rename_payload = BreakageAction(
        primitive_type="RenameApiCall",
        params={"old_name": "x", "new_name": "y"},
    )
    forge_env.step(ForgeAction(breakage=rename_payload))

    # Repair step.
    noop_fix = RepairAction(unified_diff="print('noop')")
    forge_env.step(ForgeAction(repair=noop_fix))

    teacher_state = tracked_teacher.get_state()
    assert any(entry["attempts"] >= 1 for entry in teacher_state.values())
|
|
|
|
|
def test_unified_diff_round_trip():
    """A diff made from two texts, applied to the first, reproduces the second."""
    source_text = "hello\nworld\n"
    target_text = "hello\nplanet\n"

    patch = make_unified_diff(source_text, target_text)

    assert apply_unified_diff(source_text, patch) == target_text
|
|
|
|
|
def test_unified_diff_full_script_replacement():
    """Passing a whole script as the "diff" must return that script unchanged."""
    replacement_script = """import torch
from transformers import Trainer
trainer = Trainer()
trainer.train()
"""
    # The second argument is a complete script, not a diff of the first.
    outcome = apply_unified_diff("broken stuff", replacement_script)

    assert outcome == replacement_script
|
|
|