Spaces:
Sleeping
Sleeping
| import sys | |
| import os | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from env.gym_wrapper import MahoragaGymEnv | |
| PASS = 0 | |
| FAIL = 0 | |
| def check(name, condition): | |
| global PASS, FAIL | |
| if condition: | |
| PASS += 1 | |
| print(f" [PASS] {name}") | |
| else: | |
| FAIL += 1 | |
| print(f" [FAIL] {name}") | |
| def test_reset_returns_valid_observation(): | |
| print("\n--- Test: Reset Returns Valid Observation ---") | |
| env = MahoragaGymEnv() | |
| obs, info = env.reset() | |
| check("Reset returns tuple of 2", isinstance(obs, dict) and isinstance(info, dict)) | |
| check("Observation has agent_hp", "agent_hp" in obs) | |
| check("Observation has enemy_hp", "enemy_hp" in obs) | |
| check("Observation has resistances", "resistances" in obs) | |
| check("Observation has last_enemy_attack_type", "last_enemy_attack_type" in obs) | |
| check("Observation has last_enemy_subtype", "last_enemy_subtype" in obs) | |
| check("Observation has last_action", "last_action" in obs) | |
| check("Observation has turn_number", "turn_number" in obs) | |
| def test_observation_matches_space(): | |
| print("\n--- Test: Observation Matches Observation Space ---") | |
| env = MahoragaGymEnv() | |
| obs, _ = env.reset() | |
| check("Initial obs is in observation_space", env.observation_space.contains(obs)) | |
| # Take a step and check again | |
| obs2, _, _, _, _ = env.step(0) | |
| check("Post-step obs is in observation_space", env.observation_space.contains(obs2)) | |
| # Judgment Strike | |
| obs3, _, _, _, _ = env.step(3) | |
| check("Post-judgment obs is in observation_space", env.observation_space.contains(obs3)) | |
| def test_step_returns_correct_format(): | |
| print("\n--- Test: Step Returns Correct Tuple Format ---") | |
| env = MahoragaGymEnv() | |
| env.reset() | |
| result = env.step(0) | |
| check("Step returns tuple of 5", len(result) == 5) | |
| obs, reward, terminated, truncated, info = result | |
| check("obs is dict", isinstance(obs, dict)) | |
| check("reward is float", isinstance(reward, (int, float))) | |
| check("terminated is bool", isinstance(terminated, bool)) | |
| check("truncated is bool", isinstance(truncated, bool)) | |
| check("truncated is False", truncated is False) | |
| check("info is dict", isinstance(info, dict)) | |
| check("info has reward_breakdown", "reward_breakdown" in info) | |
| def test_action_space(): | |
| print("\n--- Test: Action Space ---") | |
| env = MahoragaGymEnv() | |
| check("Action space is Discrete(5)", env.action_space.n == 5) | |
| # All valid actions should work | |
| for a in range(5): | |
| check(f"Action {a} is in action_space", env.action_space.contains(a)) | |
| check("Action 5 is NOT in action_space", not env.action_space.contains(5)) | |
| check("Action -1 is NOT in action_space", not env.action_space.contains(-1)) | |
| def test_multiple_steps(): | |
| print("\n--- Test: Multiple Steps Run Without Crash ---") | |
| env = MahoragaGymEnv() | |
| obs, _ = env.reset() | |
| steps = 0 | |
| done = False | |
| while not done: | |
| action = env.action_space.sample() | |
| obs, reward, terminated, truncated, info = env.step(action) | |
| done = terminated or truncated | |
| steps += 1 | |
| # Verify observation stays valid every step | |
| if not env.observation_space.contains(obs): | |
| check(f"Obs valid at step {steps}", False) | |
| return | |
| check(f"Ran {steps} steps without crash", True) | |
| check("Episode terminated", terminated is True) | |
| check("All observations valid throughout episode", True) | |
| def test_reset_after_episode(): | |
| print("\n--- Test: Reset After Episode ---") | |
| env = MahoragaGymEnv() | |
| # Run full episode | |
| env.reset() | |
| done = False | |
| while not done: | |
| _, _, done, _, _ = env.step(env.action_space.sample()) | |
| # Reset and run again | |
| obs, info = env.reset() | |
| check("Reset after episode returns valid obs", env.observation_space.contains(obs)) | |
| check("Agent HP reset to max", obs["agent_hp"][0] == 1200) | |
| check("Turn number reset to 0", obs["turn_number"][0] == 0) | |
| if __name__ == "__main__": | |
| print("=" * 50) | |
| print(" MahoragaGymEnv Wrapper Tests") | |
| print("=" * 50) | |
| test_reset_returns_valid_observation() | |
| test_observation_matches_space() | |
| test_step_returns_correct_format() | |
| test_action_space() | |
| test_multiple_steps() | |
| test_reset_after_episode() | |
| print("\n" + "=" * 50) | |
| print(f" Results: {PASS} passed, {FAIL} failed") | |
| print("=" * 50) | |
| if FAIL > 0: | |
| sys.exit(1) | |