"""
Custom SB3 callbacks for Gov Workflow RL training.

GovWorkflowEvalCallback  -- MaskableEvalCallback + grader score logging
CostMonitorCallback      -- per-rollout cost constraint logging to TensorBoard
RecurrentEvalCallback    -- periodic eval + grader-score model selection for RecurrentPPO
"""

from __future__ import annotations

import os
from typing import Any

import numpy as np
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from stable_baselines3.common.callbacks import BaseCallback

from rl.cost_tracker import THRESHOLD_FAIRNESS, THRESHOLD_SLA
from rl.gov_workflow_env import GovWorkflowGymEnv


class GovWorkflowEvalCallback(MaskableEvalCallback):
    """
    Extends MaskableEvalCallback:
    1. Runs the deterministic grader after each eval.
    2. Logs grader score to TensorBoard.
    3. Saves best model by grader score (not just mean reward).
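
    Example (illustrative sketch -- assumes MaskablePPO from sb3_contrib and a
    task id registered in app.tasks)::

        from sb3_contrib import MaskablePPO

        train_env = GovWorkflowGymEnv(task_id="district_backlog_easy", seed=0)
        eval_env = GovWorkflowGymEnv(task_id="district_backlog_easy", seed=1)
        cb = GovWorkflowEvalCallback(eval_env=eval_env, task_id="district_backlog_easy")
        model = MaskablePPO("MlpPolicy", train_env)
        model.learn(total_timesteps=50_000, callback=cb)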
    """

    def __init__(
        self,
        eval_env:             GovWorkflowGymEnv,
        eval_freq:            int = 2048,
        n_eval_episodes:      int = 5,
        grader_eval_freq_multiplier: int = 4,
        grader_eval_max_steps: int | None = None,
        best_model_save_path: str = "results/best_model",
        log_path:             str = "results/eval_logs",
        task_id:              str = "district_backlog_easy",
        verbose:              int = 1,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            best_model_save_path=best_model_save_path,
            log_path=log_path,
            verbose=verbose,
            warn=False,
        )
        self.task_id = task_id
        self.grader_eval_freq_multiplier = max(1, int(grader_eval_freq_multiplier))
        self.grader_eval_max_steps = grader_eval_max_steps
        self._best_grader_score = -np.inf
        os.makedirs(best_model_save_path, exist_ok=True)
        os.makedirs(log_path, exist_ok=True)

    def _on_step(self) -> bool:
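        # super()._on_step() performs the actual evaluation when due; we
        # detect "due" ourselves so we can emit the parser-friendly line below.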
        eval_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        result = super()._on_step()
        if eval_due:
            mean_reward = float(getattr(self, "last_mean_reward", 0.0) or 0.0)
            std_reward = 0.0
            try:
                if self.evaluations_results and len(self.evaluations_results) > 0:
                    latest = self.evaluations_results[-1]
                    if latest is not None and len(latest) > 0:
                        std_reward = float(np.std(latest))
            except Exception:
                std_reward = 0.0
            # Stable line format for live parser in backend/frontend.
            print(
                f"Eval num_timesteps={int(self.num_timesteps)}, "
                f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}",
                flush=True,
            )

        grader_eval_freq = max(self.eval_freq * self.grader_eval_freq_multiplier, 1)
        if self.eval_freq > 0 and self.n_calls % grader_eval_freq == 0:
            grader_score = self._run_grader_eval()
            if self.logger:
                self.logger.record("eval/grader_score", grader_score)
            if grader_score > self._best_grader_score:
                self._best_grader_score = grader_score
                save_path = os.path.join(
                    self.best_model_save_path, f"best_grader_{self.task_id}"
                )
                self.model.save(save_path)
                if self.verbose:
                    print(f"[Eval] New best grader score: {grader_score:.4f} -> {save_path}")
        return result

    def _run_grader_eval(self) -> float:
        try:
            from app.graders import grade_episode
            from app.tasks import TASKS
            task_cfg = TASKS.get(self.task_id)
            if task_cfg is None:
                return 0.0
            max_steps = (
                int(self.grader_eval_max_steps)
                if self.grader_eval_max_steps is not None
                else max(1, int(task_cfg.max_days) * 10)
            )
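            # Roll out one deterministic, hard-masked episode and grade the
            # terminal state with the deterministic grader from app.graders.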
            env = GovWorkflowGymEnv(task_id=self.task_id, seed=task_cfg.seed, hard_action_mask=True)
            obs, _ = env.reset()
            done = False
            steps = 0
            while not done:
                masks = np.asarray(env.action_masks(), dtype=bool).reshape(-1)
                action, _ = self.model.predict(obs, action_masks=masks, deterministic=True)
                obs, _, terminated, truncated, _ = env.step(int(action))
                done = terminated or truncated
                steps += 1
                if steps >= max_steps and not done:
                    break
            result = grade_episode(env._core_env.state())
            return float(result.score)
        except Exception as e:
            if self.verbose:
                print(f"[Eval] Grader eval failed: {e}")
            return 0.0


class CostMonitorCallback(BaseCallback):
    """
    Monitors SLA and fairness cost signals per rollout.
    Phases 1-3: diagnostic only.
    Phase 4:    logged costs feed into Lagrangian multiplier updates.
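
    Example (illustrative -- SB3's learn() accepts a callback list; eval_cb
    stands for any eval callback from this module)::

        model.learn(total_timesteps=50_000,
                    callback=[eval_cb, CostMonitorCallback()])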
    """

    def __init__(self, verbose: int = 0):
        super().__init__(verbose)
        self._episode_costs: list[dict] = []
        self._ep_sla:  list[float] = []
        self._ep_fair: list[float] = []
        self._ep_mask_applied: list[float] = []

    def _on_step(self) -> bool:
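        # NOTE: the per-episode accumulators below are shared across vec-env
        # slots, so the logged per-episode means are exact only when training
        # with a single environment.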
        for info, done in zip(
            self.locals.get("infos", []),
            self.locals.get("dones", []),
        ):
            rb = info.get("reward_breakdown", {})
            self._ep_sla.append( abs(float(rb.get("sla_penalty",      0.0))))
            self._ep_fair.append(abs(float(rb.get("fairness_penalty",  0.0))))
            self._ep_mask_applied.append(float(bool(info.get("action_mask_applied", False))))
            if done:
                mean_sla  = float(np.mean(self._ep_sla))  if self._ep_sla  else 0.0
                mean_fair = float(np.mean(self._ep_fair)) if self._ep_fair else 0.0
                mask_rate = float(np.mean(self._ep_mask_applied)) if self._ep_mask_applied else 0.0
                self._episode_costs.append({"sla": mean_sla, "fairness": mean_fair})
                self.logger.record("costs/episode_mean_sla_penalty",     mean_sla)
                self.logger.record("costs/episode_mean_fairness_penalty", mean_fair)
                self.logger.record("costs/sla_threshold_violated",       float(mean_sla  > THRESHOLD_SLA))
                self.logger.record("costs/fairness_threshold_violated",  float(mean_fair > THRESHOLD_FAIRNESS))
                self.logger.record("costs/episode_action_mask_applied_rate", mask_rate)
                self._ep_sla.clear()
                self._ep_fair.clear()
                self._ep_mask_applied.clear()
        return True

    def _on_training_end(self) -> None:
        if not self._episode_costs:
            return
        all_sla  = [c["sla"]      for c in self._episode_costs]
        all_fair = [c["fairness"] for c in self._episode_costs]
        print(
            f"\n[CostMonitor] mean SLA penalty: {np.mean(all_sla):.4f} "
            f"(threshold={THRESHOLD_SLA}), "
            f"mean fairness penalty: {np.mean(all_fair):.4f} "
            f"(threshold={THRESHOLD_FAIRNESS})"
        )


class RecurrentEvalCallback(BaseCallback):
    """
    Periodic evaluation callback for RecurrentPPO.

    We evaluate with deterministic inference and enforce action masks at
    inference time before env.step().
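
    Example (illustrative sketch -- assumes RecurrentPPO from sb3_contrib)::

        from sb3_contrib import RecurrentPPO

        train_env = GovWorkflowGymEnv(task_id="mixed_urgency_medium", seed=0)
        eval_env = GovWorkflowGymEnv(task_id="mixed_urgency_medium", seed=1)
        cb = RecurrentEvalCallback(eval_env=eval_env, task_id="mixed_urgency_medium")
        model = RecurrentPPO("MlpLstmPolicy", train_env)
        model.learn(total_timesteps=50_000, callback=cb)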
    """

    def __init__(
        self,
        eval_env: GovWorkflowGymEnv,
        eval_freq: int = 2048,
        n_eval_episodes: int = 3,
        best_model_save_path: str = "results/best_model",
        log_path: str = "results/eval_logs",
        task_id: str = "mixed_urgency_medium",
        verbose: int = 1,
    ):
        super().__init__(verbose=verbose)
        self.eval_env = eval_env
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes
        self.best_model_save_path = best_model_save_path
        self.log_path = log_path
        self.task_id = task_id
        self._best_grader_score = -np.inf
        os.makedirs(best_model_save_path, exist_ok=True)
        os.makedirs(log_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            return True

        mean_reward, grader_score = self._run_eval()
        self.logger.record("eval/mean_reward", mean_reward)
        self.logger.record("eval/grader_score", grader_score)

        if grader_score > self._best_grader_score:
            self._best_grader_score = grader_score
            save_path = os.path.join(
                self.best_model_save_path, f"best_grader_recurrent_{self.task_id}"
            )
            self.model.save(save_path)
            if self.verbose:
                print(f"[Eval] New best recurrent grader score: {grader_score:.4f} -> {save_path}")
        return True

    def _run_eval(self) -> tuple[float, float]:
        from app.graders import grade_episode
        from app.tasks import TASKS

        task_cfg = TASKS.get(self.task_id)
        if task_cfg is None:
            return 0.0, 0.0

        rewards: list[float] = []
        scores: list[float] = []

        for ep in range(self.n_eval_episodes):
            env = GovWorkflowGymEnv(self.task_id, seed=task_cfg.seed + ep, hard_action_mask=True)
            obs, _ = env.reset()
            done = False
            ep_reward = 0.0
            lstm_state: Any = None
            episode_start = np.array([True], dtype=bool)

            while not done:
                action, lstm_state = self.model.predict(
                    obs,
                    state=lstm_state,
                    episode_start=episode_start,
                    deterministic=True,
                )
                action_idx = int(np.asarray(action).item())
                # RecurrentPPO has no native action masking, so enforce the
                # mask here: if the predicted action is invalid, fall back to
                # action 18 when available, else to the first valid action.
                masks = np.asarray(env.action_masks(), dtype=bool).reshape(-1)
                if action_idx < 0 or action_idx >= masks.shape[0] or not bool(masks[action_idx]):
                    if masks.shape[0] > 18 and bool(masks[18]):
                        action_idx = 18
                    else:
                        valid = np.flatnonzero(masks)
                        if valid.size > 0:
                            action_idx = int(valid[0])

                obs, reward, terminated, truncated, _ = env.step(action_idx)
                ep_reward += float(reward)
                done = bool(terminated or truncated)
                episode_start = np.array([done], dtype=bool)

            result = grade_episode(env._core_env.state())
            rewards.append(ep_reward)
            scores.append(float(result.score))

        return float(np.mean(rewards)), float(np.mean(scores))