File size: 4,451 Bytes
c9d1b27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from env.gym_wrapper import MahoragaGymEnv

PASS = 0  # running count of checks whose condition was truthy
FAIL = 0  # running count of checks whose condition was falsy


def check(name, condition):
    """Record and print the outcome of a single named check.

    A truthy ``condition`` increments the module-level ``PASS`` counter and
    prints a ``[PASS]`` line; a falsy one increments ``FAIL`` and prints a
    ``[FAIL]`` line. Nothing is raised, so later checks still run.
    """
    global PASS, FAIL
    if not condition:
        FAIL += 1
        print(f"  [FAIL] {name}")
        return
    PASS += 1
    print(f"  [PASS] {name}")


def test_reset_returns_valid_observation():
    """Check that ``reset()`` returns an ``(obs, info)`` pair with the expected keys.

    Verifies the return value really is a 2-tuple before unpacking it (so a
    malformed return is reported as a failed check instead of raising), that
    both elements are dicts, and that the observation dict contains every
    expected key.
    """
    print("\n--- Test: Reset Returns Valid Observation ---")
    env = MahoragaGymEnv()
    result = env.reset()

    is_pair = isinstance(result, tuple) and len(result) == 2
    check("Reset returns tuple of 2", is_pair)
    if not is_pair:
        # Unpacking would raise and abort every remaining test.
        return

    obs, info = result
    check("Observation is dict", isinstance(obs, dict))
    check("Info is dict", isinstance(info, dict))

    expected_keys = (
        "agent_hp",
        "enemy_hp",
        "resistances",
        "last_enemy_attack_type",
        "last_enemy_subtype",
        "last_action",
        "turn_number",
    )
    for key in expected_keys:
        check(f"Observation has {key}", key in obs)


def test_observation_matches_space():
    """Check that observations stay inside ``observation_space``.

    Validates the observation from ``reset()``, then the one produced by
    stepping with action 0, then the one produced by stepping with action 3
    (Judgment Strike).
    """
    print("\n--- Test: Observation Matches Observation Space ---")
    env = MahoragaGymEnv()
    first_obs, _ = env.reset()
    check("Initial obs is in observation_space", env.observation_space.contains(first_obs))

    # (action, check-label prefix) pairs, stepped in this order;
    # action 3 is the Judgment Strike.
    scripted_steps = ((0, "Post-step"), (3, "Post-judgment"))
    for action, prefix in scripted_steps:
        next_obs, _, _, _, _ = env.step(action)
        check(
            f"{prefix} obs is in observation_space",
            env.observation_space.contains(next_obs),
        )


def test_step_returns_correct_format():
    """Check that ``step()`` returns the 5-tuple ``(obs, reward, terminated, truncated, info)``.

    After one step with action 0, verifies the result is a 5-tuple, then that
    ``obs`` and ``info`` are dicts, ``reward`` is an int or float, both flags
    are ``bool``, ``truncated`` is ``False``, and ``info`` contains a
    ``reward_breakdown`` entry. A malformed result is reported as a failed
    check instead of raising.
    """
    print("\n--- Test: Step Returns Correct Tuple Format ---")
    env = MahoragaGymEnv()
    env.reset()
    result = env.step(0)

    is_five_tuple = isinstance(result, tuple) and len(result) == 5
    check("Step returns tuple of 5", is_five_tuple)
    if not is_five_tuple:
        # Unpacking would raise and abort every remaining test.
        return

    obs, reward, terminated, truncated, info = result
    check("obs is dict", isinstance(obs, dict))
    # ints are accepted as well as floats.
    check("reward is float", isinstance(reward, (int, float)))
    check("terminated is bool", isinstance(terminated, bool))
    check("truncated is bool", isinstance(truncated, bool))
    check("truncated is False", truncated is False)
    check("info is dict", isinstance(info, dict))
    check("info has reward_breakdown", "reward_breakdown" in info)


def test_action_space():
    """Check the action space: ``n == 5``, actions 0-4 accepted, 5 and -1 rejected."""
    print("\n--- Test: Action Space ---")
    env = MahoragaGymEnv()
    check("Action space is Discrete(5)", env.action_space.n == 5)

    # Every in-range action index must be accepted...
    for action in range(5):
        check(f"Action {action} is in action_space", env.action_space.contains(action))

    # ...and the indices just past either end must be rejected.
    for action in (5, -1):
        check(
            f"Action {action} is NOT in action_space",
            not env.action_space.contains(action),
        )


def test_multiple_steps():
    """Play one episode with random actions, validating every observation.

    Steps the environment with ``env.action_space.sample()`` until either
    ``terminated`` or ``truncated`` is set. The first post-step observation
    that falls outside ``env.observation_space`` is reported as a failed
    check and the test stops there. The episode is expected to end via
    ``terminated``.

    A step cap turns an episode that never ends into a failed check instead
    of hanging the whole test run.
    """
    print("\n--- Test: Multiple Steps Run Without Crash ---")
    max_steps = 10_000  # safety cap; NOTE(review): generous guess, confirm vs. real episode length
    env = MahoragaGymEnv()
    env.reset()

    steps = 0
    terminated = truncated = False
    while not (terminated or truncated):
        if steps >= max_steps:
            check(f"Episode ended within {max_steps} steps", False)
            return
        obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
        steps += 1
        # Verify observation stays valid every step
        if not env.observation_space.contains(obs):
            check(f"Obs valid at step {steps}", False)
            return

    check(f"Ran {steps} steps without crash", True)
    check("Episode terminated", terminated is True)
    # Only reachable when every observation above passed validation.
    check("All observations valid throughout episode", True)


def test_reset_after_episode():
    """Check that ``reset()`` restores initial state after a finished episode.

    Plays one episode with random actions until ``terminated`` or
    ``truncated`` is set. It then resets and verifies three things:
    the new observation lies in ``observation_space``, agent HP is back to
    its maximum of 1200, and the turn counter is back to 0.
    """
    print("\n--- Test: Reset After Episode ---")
    expected_max_hp = 1200  # agent's full HP, as checked after reset below
    max_steps = 10_000  # safety cap so a never-ending episode cannot hang the run
    env = MahoragaGymEnv()

    # Run full episode. Stop on truncation as well as termination: once
    # either flag is set the episode is over and must not be stepped further.
    env.reset()
    for _ in range(max_steps):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            break
    else:
        check(f"Episode ended within {max_steps} steps", False)
        return

    # Reset and verify state was restored
    obs, _ = env.reset()
    check("Reset after episode returns valid obs", env.observation_space.contains(obs))
    check("Agent HP reset to max", obs["agent_hp"][0] == expected_max_hp)
    check("Turn number reset to 0", obs["turn_number"][0] == 0)


if __name__ == "__main__":
    import traceback

    print("=" * 50)
    print("  MahoragaGymEnv Wrapper Tests")
    print("=" * 50)

    tests = (
        test_reset_returns_valid_observation,
        test_observation_matches_space,
        test_step_returns_correct_format,
        test_action_space,
        test_multiple_steps,
        test_reset_after_episode,
    )
    for test in tests:
        # Isolate each test: an unexpected exception is printed and counted as
        # a failure rather than aborting the remaining tests and the summary.
        try:
            test()
        except Exception:
            traceback.print_exc()
            check(f"{test.__name__} completed without raising", False)

    print("\n" + "=" * 50)
    print(f"  Results: {PASS} passed, {FAIL} failed")
    print("=" * 50)

    # Non-zero exit status signals failure to CI / shell callers.
    if FAIL > 0:
        sys.exit(1)