"""Phase 3 tests for recurrent evaluation helpers."""

from __future__ import annotations

import numpy as np

import rl.evaluate as eval_mod
from rl.evaluate import TaskEvalResult, compare_recurrent_vs_flat, predict_recurrent_action


class _FakeRecurrentModel:
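    """Minimal stand-in for a recurrent policy's ``predict`` interface.

    The integer hidden state acts as a step counter: it resets to 0 on an
    episode start (or when no state is provided) and otherwise increments by
    one, so the tests can assert that the LSTM state is threaded between calls.
    """
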
    def __init__(self):
        self.seen_states = []

    def predict(self, obs, state=None, episode_start=None, deterministic=True):
        self.seen_states.append(state)
        if episode_start is not None and bool(np.asarray(episode_start).item()):
            next_state = 0
        elif state is None:
            next_state = 0
        else:
            next_state = int(state) + 1
        return np.array([18]), next_state


def test_recurrent_policy_hidden_state_persists_across_steps() -> None:
    model = _FakeRecurrentModel()
    obs = np.zeros(4, dtype=np.float32)
    masks = np.array([True] * 28)

    action_1, state_1 = predict_recurrent_action(
        model=model,
        obs=obs,
        lstm_state=None,
        episode_start=np.array([False], dtype=bool),
        masks=masks,
    )
    action_2, state_2 = predict_recurrent_action(
        model=model,
        obs=obs,
        lstm_state=state_1,
        episode_start=np.array([False], dtype=bool),
        masks=masks,
    )

    assert action_1 == 18
    assert action_2 == 18
    assert state_1 == 0
    assert state_2 == 1
    assert model.seen_states == [None, 0]


def test_lstm_reset_on_episode_boundary() -> None:
    model = _FakeRecurrentModel()
    obs = np.zeros(4, dtype=np.float32)
    masks = np.array([True] * 28)

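    # Seed an arbitrary non-zero hidden state; the second call marks an
    # episode start, which should reset the state to 0.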
    _, state_1 = predict_recurrent_action(
        model=model,
        obs=obs,
        lstm_state=5,
        episode_start=np.array([False], dtype=bool),
        masks=masks,
    )
    _, state_2 = predict_recurrent_action(
        model=model,
        obs=obs,
        lstm_state=state_1,
        episode_start=np.array([True], dtype=bool),
        masks=masks,
    )

    assert state_1 == 6
    assert state_2 == 0


def test_score_recurrent_geq_flat_ppo_on_medium(monkeypatch) -> None:
    def _fake_evaluate_model(model_path, task_ids, n_episodes, verbose, model_type):
        assert task_ids == ["mixed_urgency_medium"]
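        # Stub scores: 0.60 for the maskable (flat PPO) path and 0.63 otherwise,
        # so the recurrent score should come out at least as high as the flat one.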
        if model_type == "maskable":
            score = 0.60
        else:
            score = 0.63
        return [
            TaskEvalResult(
                task_id="mixed_urgency_medium",
                seed=22,
                grader_score=score,
                total_reward=200.0,
                total_steps=50,
                total_completed=40,
                total_sla_breaches=20,
                fairness_gap=0.1,
            )
        ]

    monkeypatch.setattr(eval_mod, "evaluate_model", _fake_evaluate_model)

    comparison = compare_recurrent_vs_flat(
        flat_model_path="flat.zip",
        recurrent_model_path="recurrent.zip",
        task_id="mixed_urgency_medium",
        n_episodes=3,
    )

    assert comparison["recurrent"] >= comparison["flat"]


def test_invalid_action_prefers_advance_time_fallback() -> None:
    model = _FakeRecurrentModel()
    obs = np.zeros(4, dtype=np.float32)
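    # Mask out everything except index 18, which this test treats as the
    # advance-time fallback action.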
    masks = np.array([False] * 28)
    masks[18] = True

    action, _ = predict_recurrent_action(
        model=model,
        obs=obs,
        lstm_state=None,
        episode_start=np.array([False], dtype=bool),
        masks=masks,
    )
    assert action == 18