File size: 1,729 Bytes
3eae4cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np

from rl.gov_workflow_env import GovWorkflowGymEnv


def test_gym_wrapper_reset_step_and_core_env_access():
    """Reset the wrapper, take one mask-permitted step, and check the Gym API contract.

    Verifies, in order:
      * ``reset`` returns an ``np.ndarray`` observation matching
        ``observation_space.shape`` plus an ``info`` dict;
      * the wrapped ``core_env`` is exposed;
      * ``action_masks()`` is a boolean ``np.ndarray`` of shape ``(action_space.n,)``
        with at least one permitted action;
      * ``step`` with the first permitted action returns a well-typed
        5-tuple whose ``info`` carries the wrapper's action-bookkeeping keys.
    """
    gym_env = GovWorkflowGymEnv(
        task_id="district_backlog_easy",
        seed=101,
        hard_action_mask=True,
    )

    # --- reset ---------------------------------------------------------
    first_obs, reset_info = gym_env.reset(seed=101)
    assert isinstance(first_obs, np.ndarray)
    assert first_obs.shape == gym_env.observation_space.shape
    assert isinstance(reset_info, dict)
    assert gym_env.core_env is not None

    # --- action mask ---------------------------------------------------
    action_mask = gym_env.action_masks()
    assert isinstance(action_mask, np.ndarray)
    assert action_mask.dtype == bool
    assert action_mask.shape == (gym_env.action_space.n,)

    allowed = np.flatnonzero(action_mask)
    assert allowed.size > 0
    chosen_action = int(allowed[0])

    # --- step ----------------------------------------------------------
    next_obs, reward, terminated, truncated, step_info = gym_env.step(chosen_action)
    assert isinstance(next_obs, np.ndarray)
    assert next_obs.shape == gym_env.observation_space.shape
    assert isinstance(reward, float)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
    assert isinstance(step_info, dict)

    # Bookkeeping keys the wrapper adds to every step's info dict.
    for required_key in (
        "requested_action_idx",
        "executed_action_idx",
        "action_mask_applied",
    ):
        assert required_key in step_info


def test_gym_wrapper_hard_mask_sanitizes_invalid_action_when_available():
    """A masked-out action must be redirected when ``hard_action_mask=True``.

    After ``reset(seed=202)``, the test picks the first action that
    ``action_masks()`` disallows and steps with it. It then checks that the
    step ``info``:
      * records the originally requested index,
      * reports a different executed index, and
      * flags that the mask was applied.

    If the mask permits every action for this seed, the sanitization path
    cannot be exercised. The test is then reported as SKIPPED rather than
    silently passing, so lost coverage stays visible.
    """
    # Stdlib skip signal; pytest reports unittest.SkipTest from a plain test
    # function as a skip, and unittest runners honour it natively.
    import unittest

    env = GovWorkflowGymEnv(
        task_id="district_backlog_easy",
        seed=202,
        hard_action_mask=True,
    )
    env.reset(seed=202)
    masks = env.action_masks()

    invalid_actions = np.flatnonzero(~masks)
    if invalid_actions.size == 0:
        raise unittest.SkipTest(
            "no masked-out actions after reset(seed=202); "
            "hard-mask sanitization path not exercised"
        )

    invalid_idx = int(invalid_actions[0])
    _, _, _, _, info = env.step(invalid_idx)

    assert info["requested_action_idx"] == invalid_idx
    # The wrapper must not execute the disallowed action as-is.
    assert info["executed_action_idx"] != invalid_idx
    assert info["action_mask_applied"] is True