# Gov_Workflow_RL/tests/test_rl_evaluate.py
# Siddharaj Shirke
# deploy: clean code-only snapshot for HF Space
# df97e68
"""Phase 3 tests for recurrent evaluation helpers."""
from __future__ import annotations
import numpy as np
import rl.evaluate as eval_mod
from rl.evaluate import TaskEvalResult, compare_recurrent_vs_flat, predict_recurrent_action
class _FakeRecurrentModel:
    """Deterministic stand-in for a recurrent policy with a ``predict`` method.

    ``predict`` always proposes action 18 and models the hidden state as an
    integer counter. The counter restarts at 0 when ``episode_start`` is truthy
    or no state is supplied; otherwise it is the incoming state plus one.
    ``obs`` and ``deterministic`` are accepted for signature compatibility but
    ignored. Every ``state`` argument is recorded in ``seen_states``.
    """

    def __init__(self):
        # Each ``state`` passed to ``predict``, in call order.
        self.seen_states = []

    def predict(self, obs, state=None, episode_start=None, deterministic=True):
        self.seen_states.append(state)
        # The episode-start flag is inspected first; a boundary always wins
        # over any carried-in state.
        resetting = episode_start is not None and bool(
            np.asarray(episode_start).item()
        )
        if resetting or state is None:
            return np.array([18]), 0
        return np.array([18]), int(state) + 1
def test_recurrent_policy_hidden_state_persists_across_steps() -> None:
    """The state returned by one step must be threaded into the next step.

    Starting from no state, two consecutive non-boundary steps should yield
    counter states 0 then 1. The fake model should have been handed ``None``
    followed by the state produced by the first step.
    """
    fake = _FakeRecurrentModel()
    observation = np.zeros(4, dtype=np.float32)
    action_mask = np.ones(28, dtype=bool)

    actions, states = [], []
    hidden = None
    for _ in range(2):
        action, hidden = predict_recurrent_action(
            model=fake,
            obs=observation,
            lstm_state=hidden,
            episode_start=np.array([False], dtype=bool),
            masks=action_mask,
        )
        actions.append(action)
        states.append(hidden)

    assert actions[0] == 18
    assert actions[1] == 18
    assert states[0] == 0
    assert states[1] == 1
    assert fake.seen_states == [None, 0]
def test_lstm_reset_on_episode_boundary() -> None:
    """An episode-start flag must reset the hidden state mid-rollout.

    A carried state of 5 advances to 6 on a normal step. The next step,
    flagged as an episode start, must come back with the reset value 0.
    """
    fake = _FakeRecurrentModel()
    observation = np.zeros(4, dtype=np.float32)
    action_mask = np.ones(28, dtype=bool)

    def _advance(hidden, starting):
        # Run one prediction step and keep only the returned state.
        _, new_hidden = predict_recurrent_action(
            model=fake,
            obs=observation,
            lstm_state=hidden,
            episode_start=np.array([starting], dtype=bool),
            masks=action_mask,
        )
        return new_hidden

    continued = _advance(5, False)
    restarted = _advance(continued, True)

    assert continued == 6
    assert restarted == 0
def test_score_recurrent_geq_flat_ppo_on_medium(monkeypatch) -> None:
    """compare_recurrent_vs_flat should report recurrent >= flat on medium.

    ``evaluate_model`` is replaced by a stub. The stub scores the "maskable"
    model type (presumably the flat PPO baseline) at 0.60 and any other type
    at 0.63. This therefore exercises how the comparison dispatches the two
    evaluations and reports their scores, not real training quality.
    """

    def _stub_evaluate_model(model_path, task_ids, n_episodes, verbose, model_type):
        # The comparison must restrict evaluation to the requested task.
        assert task_ids == ["mixed_urgency_medium"]
        stub_score = 0.60 if model_type == "maskable" else 0.63
        single_result = TaskEvalResult(
            task_id="mixed_urgency_medium",
            seed=22,
            grader_score=stub_score,
            total_reward=200.0,
            total_steps=50,
            total_completed=40,
            total_sla_breaches=20,
            fairness_gap=0.1,
        )
        return [single_result]

    monkeypatch.setattr(eval_mod, "evaluate_model", _stub_evaluate_model)

    result = compare_recurrent_vs_flat(
        flat_model_path="flat.zip",
        recurrent_model_path="recurrent.zip",
        task_id="mixed_urgency_medium",
        n_episodes=3,
    )
    assert result["recurrent"] >= result["flat"]
def test_invalid_action_prefers_advance_time_fallback() -> None:
    """A masked-out proposed action must be replaced by the fallback action 18.

    The shared ``_FakeRecurrentModel`` always proposes action 18. Under a mask
    that only allows 18, that proposal is already legal, so the fallback path
    in ``predict_recurrent_action`` would never run and the test would pass
    vacuously. A local subclass therefore proposes a masked-out action (3).
    Index 18 is presumably the advance-time action, per this test's name;
    TODO confirm against the action-space definition.
    """

    class _InvalidActionModel(_FakeRecurrentModel):
        # Same state bookkeeping as the shared fake, but proposes action 3.
        def predict(self, obs, state=None, episode_start=None, deterministic=True):
            _, next_state = super().predict(
                obs,
                state=state,
                episode_start=episode_start,
                deterministic=deterministic,
            )
            return np.array([3]), next_state

    model = _InvalidActionModel()
    obs = np.zeros(4, dtype=np.float32)
    masks = np.zeros(28, dtype=bool)
    masks[18] = True  # only index 18 is legal

    # Precondition: the model's proposal really is illegal under the mask.
    assert not masks[3]

    action, _ = predict_recurrent_action(
        model=model,
        obs=obs,
        lstm_state=None,
        episode_start=np.array([False], dtype=bool),
        masks=masks,
    )
    assert action == 18