| """End-to-end tests for the OpenEnv-wrapped ForgeEnvironment.""" |
| import pytest |
|
|
| from forgeenv.env.actions import BreakageAction, ForgeAction, RepairAction |
| from forgeenv.env.diff_utils import apply_unified_diff, make_unified_diff |
| from forgeenv.env.forge_environment import ForgeEnvironment |
| from forgeenv.env.observations import ForgeObservation |
| from forgeenv.roles.teacher import Teacher |
|
|
|
|
def test_reset_returns_drift_gen_observation():
    """A fresh reset lands in the drift_gen phase with a populated observation."""
    from forgeenv.primitives.breakage_primitives import PRIMITIVE_REGISTRY

    environment = ForgeEnvironment(seed=0)
    observation = environment.reset()

    assert isinstance(observation, ForgeObservation)
    assert observation.current_phase == "drift_gen"
    assert observation.done is False
    # The episode must start with a script to break, a category drawn from the
    # known breakage primitives, and the library versions in play.
    assert observation.script_content
    assert observation.target_category in PRIMITIVE_REGISTRY
    assert observation.library_versions
|
|
|
|
def test_full_episode_lifecycle():
    """Drive one episode through drift_gen -> repair -> done, checking each step."""
    env = ForgeEnvironment(seed=0)
    original_script = env.reset().script_content

    # Step 1: the breaker injects a deprecated-import breakage.
    break_action = ForgeAction(
        breakage=BreakageAction(
            primitive_type="DeprecateImport",
            params={"old_module": "import torch", "new_module": "import torch.legacy"},
        )
    )
    after_break = env.step(break_action)
    assert after_break.current_phase == "repair"
    assert after_break.done is False
    assert after_break.info.get("breakage_primitive") == "DeprecateImport"
    assert after_break.error_trace is not None

    # Step 2: the repairer submits the original, pre-breakage script verbatim
    # as the repair payload.
    after_repair = env.step(ForgeAction(repair=RepairAction(unified_diff=original_script)))
    assert after_repair.current_phase == "done"
    assert after_repair.done is True
    assert after_repair.reward is not None
    assert isinstance(after_repair.reward_breakdown, dict)
    assert isinstance(after_repair.held_out_breakdown, dict)
    assert {"executed_cleanly", "checkpoint_valid"} <= set(after_repair.held_out_breakdown)
|
|
|
|
def test_invalid_action_for_phase():
    """Sending a repair action while in drift_gen ends the episode with an error."""
    env = ForgeEnvironment(seed=0)
    env.reset()
    repair_first = ForgeAction(repair=RepairAction(unified_diff="print('hi')"))
    obs = env.step(repair_first)

    assert obs.done is True
    # Require a non-empty error, consistent with
    # test_step_before_reset_returns_error; an ``is not None`` check would
    # also accept an empty/falsy error value.
    assert obs.info.get("error")
|
|
|
|
def test_step_before_reset_returns_error():
    """Stepping an environment that was never reset terminates with an error."""
    never_reset = ForgeEnvironment(seed=0)
    action = ForgeAction(
        breakage=BreakageAction(primitive_type="RenameApiCall", params={})
    )
    result = never_reset.step(action)

    assert result.done is True
    assert result.info.get("error")
|
|
|
|
def test_state_property_is_dict():
    """``env.state`` is a dict exposing phase, library versions, and teacher entries."""
    env = ForgeEnvironment(seed=0)
    env.reset()
    state = env.state

    assert isinstance(state, dict)
    # A set difference reports exactly which keys are absent. A chained
    # ``a and b and c`` assertion would only report ``False``.
    missing = {"phase", "library_versions", "teacher"} - state.keys()
    assert not missing, f"state is missing keys: {sorted(missing)}"
|
|
|
|
def test_action_validation_rejects_both_or_neither():
    """``ForgeAction`` must be given exactly one of ``breakage`` or ``repair``."""
    # NOTE(review): ``Exception`` is very broad and would also accept unrelated
    # failures. Narrow it to the concrete validation error ForgeAction raises
    # once that type is confirmed.

    # Neither payload supplied.
    with pytest.raises(Exception):
        ForgeAction()

    # Both payloads supplied.
    with pytest.raises(Exception):
        ForgeAction(
            breakage=BreakageAction(primitive_type="RenameApiCall", params={}),
            repair=RepairAction(unified_diff="x"),
        )
|
|
|
|
def test_teacher_updates_after_episode():
    """Completing an episode records at least one attempt in the Teacher's state."""
    teacher = Teacher(categories=["RenameApiCall"])
    env = ForgeEnvironment(teacher=teacher, seed=0)
    env.reset()

    rename = BreakageAction(
        primitive_type="RenameApiCall",
        params={"old_name": "x", "new_name": "y"},
    )
    env.step(ForgeAction(breakage=rename))
    noop_repair = RepairAction(unified_diff="print('noop')")
    env.step(ForgeAction(repair=noop_repair))

    attempt_counts = [stats["attempts"] for stats in teacher.get_state().values()]
    assert any(count >= 1 for count in attempt_counts)
|
|
|
|
def test_unified_diff_round_trip():
    """A diff from ``make_unified_diff`` re-applies cleanly to its source text."""
    source, target = "hello\nworld\n", "hello\nplanet\n"
    patch = make_unified_diff(source, target)
    assert apply_unified_diff(source, patch) == target
|
|
|
|
def test_unified_diff_full_script_replacement():
    """A payload that is a whole script (not a diff) replaces the input wholesale."""
    replacement = (
        "import torch\n"
        "from transformers import Trainer\n"
        "trainer = Trainer()\n"
        "trainer.train()\n"
    )
    assert apply_unified_diff("broken stuff", replacement) == replacement
|
|