# Spaces: Running | File size: 1,729 Bytes | 3eae4cc
import numpy as np
from rl.gov_workflow_env import GovWorkflowGymEnv
def test_gym_wrapper_reset_step_and_core_env_access():
    """Smoke-test reset, action masks, and one step of the Gym wrapper.

    Builds the wrapper for the ``district_backlog_easy`` task with hard
    action masking enabled, then checks the types and shapes of the values
    returned by ``reset``, ``action_masks`` and ``step``, and that the step
    info dict carries the action-bookkeeping keys.
    """
    wrapper = GovWorkflowGymEnv(
        task_id="district_backlog_easy",
        seed=101,
        hard_action_mask=True,
    )

    def check_observation(observation):
        # The observation space is read at call time, once per observation.
        assert isinstance(observation, np.ndarray)
        assert observation.shape == wrapper.observation_space.shape

    # --- reset -----------------------------------------------------------
    first_obs, reset_info = wrapper.reset(seed=101)
    check_observation(first_obs)
    assert isinstance(reset_info, dict)
    assert wrapper.core_env is not None

    # --- action masks ----------------------------------------------------
    mask = wrapper.action_masks()
    assert isinstance(mask, np.ndarray)
    assert mask.dtype == bool
    assert mask.shape == (wrapper.action_space.n,)

    allowed = np.flatnonzero(mask)
    assert allowed.size > 0

    # --- step with the first action the mask allows ----------------------
    chosen_action = int(allowed[0])
    next_obs, reward, terminated, truncated, step_info = wrapper.step(chosen_action)
    check_observation(next_obs)
    assert isinstance(reward, float)
    for flag in (terminated, truncated):
        assert isinstance(flag, bool)
    assert isinstance(step_info, dict)
    for key in (
        "requested_action_idx",
        "executed_action_idx",
        "action_mask_applied",
    ):
        assert key in step_info
def test_gym_wrapper_hard_mask_sanitizes_invalid_action_when_available():
    """Check the step info reported when a masked-out action is requested.

    After reset, the first action index whose mask entry is False is passed
    to ``step``. The returned info dict must record that index as the
    requested action, report a different executed action, and set
    ``action_mask_applied`` to True.

    Raises:
        unittest.SkipTest: if the mask marks every action as valid after
            reset, so there is no invalid action to submit.
    """
    # Function-scope import keeps this block self-contained; pytest reports
    # a raised unittest.SkipTest as a skipped test.
    from unittest import SkipTest

    env = GovWorkflowGymEnv(
        task_id="district_backlog_easy",
        seed=202,
        hard_action_mask=True,
    )
    env.reset(seed=202)

    masks = env.action_masks()
    invalid_actions = np.flatnonzero(~masks)
    if invalid_actions.size == 0:
        # A bare return here would be reported as a pass even though nothing
        # was verified; skip explicitly so the gap shows up in test reports.
        raise SkipTest(
            "no masked-out action available after reset; nothing to sanitize"
        )

    invalid_idx = int(invalid_actions[0])
    _, _, _, _, info = env.step(invalid_idx)

    assert info["requested_action_idx"] == invalid_idx
    assert info["executed_action_idx"] != invalid_idx
    assert info["action_mask_applied"] is True