"""End-to-end tests for the OpenEnv-wrapped ForgeEnvironment.""" import pytest from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff from forgeenv.env.forge_environment import ForgeEnvironment from forgeenv.env.observations import ForgeObservation from forgeenv.roles.teacher import Teacher def test_reset_returns_drift_gen_observation(): env = ForgeEnvironment(seed=0) obs = env.reset() assert isinstance(obs, ForgeObservation) assert obs.current_phase == "drift_gen" assert obs.done is False assert obs.script_content from forgeenv.primitives.breakage_primitives import PRIMITIVE_REGISTRY assert obs.target_category in PRIMITIVE_REGISTRY assert obs.library_versions # non-empty dict def test_full_episode_lifecycle(): env = ForgeEnvironment(seed=0) obs = env.reset() initial_script = obs.script_content breakage = ForgeAction( breakage=BreakageAction( primitive_type="DeprecateImport", params={"old_module": "import torch", "new_module": "import torch.legacy"}, ) ) obs2 = env.step(breakage) assert obs2.current_phase == "repair" assert obs2.done is False assert obs2.info.get("breakage_primitive") == "DeprecateImport" assert obs2.error_trace is not None repair = ForgeAction(repair=RepairAction(unified_diff=initial_script)) obs3 = env.step(repair) assert obs3.current_phase == "done" assert obs3.done is True assert obs3.reward is not None assert isinstance(obs3.reward_breakdown, dict) assert isinstance(obs3.held_out_breakdown, dict) assert {"executed_cleanly", "checkpoint_valid"} <= set(obs3.held_out_breakdown) def test_invalid_action_for_phase(): env = ForgeEnvironment(seed=0) env.reset() repair_first = ForgeAction(repair=RepairAction(unified_diff="print('hi')")) obs = env.step(repair_first) # Should not raise — should return a done=True error observation. assert obs.done is True assert obs.info.get("error") is not None def test_step_before_reset_returns_error(): env = ForgeEnvironment(seed=0) breakage = ForgeAction( breakage=BreakageAction(primitive_type="RenameApiCall", params={}) ) obs = env.step(breakage) assert obs.done is True assert obs.info.get("error") def test_state_property_is_dict(): env = ForgeEnvironment(seed=0) env.reset() state = env.state assert isinstance(state, dict) assert "phase" in state and "library_versions" in state and "teacher" in state def test_action_validation_rejects_both_or_neither(): with pytest.raises(Exception): ForgeAction() with pytest.raises(Exception): ForgeAction( breakage=BreakageAction(primitive_type="RenameApiCall", params={}), repair=RepairAction(unified_diff="x"), ) def test_teacher_updates_after_episode(): teacher = Teacher(categories=["RenameApiCall"]) env = ForgeEnvironment(teacher=teacher, seed=0) env.reset() env.step( ForgeAction( breakage=BreakageAction( primitive_type="RenameApiCall", params={"old_name": "x", "new_name": "y"}, ) ) ) env.step(ForgeAction(repair=RepairAction(unified_diff="print('noop')"))) state = teacher.get_state() assert any(s["attempts"] >= 1 for s in state.values()) def test_unified_diff_round_trip(): before = "hello\nworld\n" after = "hello\nplanet\n" diff = make_unified_diff(before, after) repaired = apply_unified_diff(before, diff) assert repaired == after def test_unified_diff_full_script_replacement(): full_script = """import torch from transformers import Trainer trainer = Trainer() trainer.train() """ repaired = apply_unified_diff("broken stuff", full_script) assert repaired == full_script