Gov_Workflow_RL / tests /test_gym_wrapper_integration.py
Siddharaj Shirke
deploy: clean code-only snapshot for HF Space
df97e68
import numpy as np
from rl.gov_workflow_env import GovWorkflowGymEnv
def test_gym_wrapper_reset_step_and_core_env_access():
env = GovWorkflowGymEnv(
task_id="district_backlog_easy",
seed=101,
hard_action_mask=True,
)
obs, info = env.reset(seed=101)
assert isinstance(obs, np.ndarray)
assert obs.shape == env.observation_space.shape
assert isinstance(info, dict)
assert env.core_env is not None
masks = env.action_masks()
assert isinstance(masks, np.ndarray)
assert masks.dtype == bool
assert masks.shape == (env.action_space.n,)
valid_actions = np.flatnonzero(masks)
assert valid_actions.size > 0
obs2, reward, terminated, truncated, info2 = env.step(int(valid_actions[0]))
assert isinstance(obs2, np.ndarray)
assert obs2.shape == env.observation_space.shape
assert isinstance(reward, float)
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
assert isinstance(info2, dict)
assert "requested_action_idx" in info2
assert "executed_action_idx" in info2
assert "action_mask_applied" in info2
def test_gym_wrapper_hard_mask_sanitizes_invalid_action_when_available():
env = GovWorkflowGymEnv(
task_id="district_backlog_easy",
seed=202,
hard_action_mask=True,
)
env.reset(seed=202)
masks = env.action_masks()
invalid_actions = np.flatnonzero(~masks)
if invalid_actions.size == 0:
return
invalid_idx = int(invalid_actions[0])
_, _, _, _, info = env.step(invalid_idx)
assert info["requested_action_idx"] == invalid_idx
assert info["executed_action_idx"] != invalid_idx
assert info["action_mask_applied"] is True