""" Custom SB3 callbacks for Gov Workflow RL training. GovWorkflowEvalCallback -- MaskableEvalCallback + grader score logging CostMonitorCallback -- per-rollout cost constraint logging to TensorBoard """ from __future__ import annotations import os import numpy as np from stable_baselines3.common.callbacks import BaseCallback from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback from typing import Any from rl.gov_workflow_env import GovWorkflowGymEnv from rl.cost_tracker import THRESHOLD_SLA, THRESHOLD_FAIRNESS class GovWorkflowEvalCallback(MaskableEvalCallback): """ Extends MaskableEvalCallback: 1. Runs the deterministic grader after each eval. 2. Logs grader score to TensorBoard. 3. Saves best model by grader score (not just mean reward). """ def __init__( self, eval_env: GovWorkflowGymEnv, eval_freq: int = 2048, n_eval_episodes: int = 5, grader_eval_freq_multiplier: int = 4, grader_eval_max_steps: int | None = None, best_model_save_path: str = "results/best_model", log_path: str = "results/eval_logs", task_id: str = "district_backlog_easy", verbose: int = 1, ): super().__init__( eval_env=eval_env, n_eval_episodes=n_eval_episodes, eval_freq=eval_freq, best_model_save_path=best_model_save_path, log_path=log_path, verbose=verbose, warn=False, ) self.task_id = task_id self.grader_eval_freq_multiplier = max(1, int(grader_eval_freq_multiplier)) self.grader_eval_max_steps = grader_eval_max_steps self._best_grader_score = -np.inf os.makedirs(best_model_save_path, exist_ok=True) os.makedirs(log_path, exist_ok=True) def _on_step(self) -> bool: eval_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0 result = super()._on_step() if eval_due: mean_reward = float(getattr(self, "last_mean_reward", 0.0) or 0.0) std_reward = 0.0 try: if self.evaluations_results and len(self.evaluations_results) > 0: latest = self.evaluations_results[-1] if latest is not None and len(latest) > 0: std_reward = float(np.std(latest)) except Exception: std_reward = 0.0 # Stable line format for live parser in backend/frontend. print( f"Eval num_timesteps={int(self.num_timesteps)}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}", flush=True, ) grader_eval_freq = max(self.eval_freq * self.grader_eval_freq_multiplier, 1) if self.eval_freq > 0 and self.n_calls % grader_eval_freq == 0: grader_score = self._run_grader_eval() if self.logger: self.logger.record("eval/grader_score", grader_score) if grader_score > self._best_grader_score: self._best_grader_score = grader_score save_path = os.path.join( self.best_model_save_path, f"best_grader_{self.task_id}" ) self.model.save(save_path) if self.verbose: print(f"[Eval] New best grader score: {grader_score:.4f} -> {save_path}") return result def _run_grader_eval(self) -> float: try: from app.graders import grade_episode from app.tasks import TASKS task_cfg = TASKS.get(self.task_id) if task_cfg is None: return 0.0 max_steps = ( int(self.grader_eval_max_steps) if self.grader_eval_max_steps is not None else max(1, int(task_cfg.max_days) * 10) ) env = GovWorkflowGymEnv(task_id=self.task_id, seed=task_cfg.seed, hard_action_mask=True) obs, _ = env.reset() done = False steps = 0 while not done: masks = np.asarray(env.action_masks(), dtype=bool).reshape(-1) action, _ = self.model.predict(obs, action_masks=masks, deterministic=True) obs, _, terminated, truncated, _ = env.step(int(action)) done = terminated or truncated steps += 1 if steps >= max_steps and not done: break result = grade_episode(env._core_env.state()) return float(result.score) except Exception as e: if self.verbose: print(f"[Eval] Grader eval failed: {e}") return 0.0 class CostMonitorCallback(BaseCallback): """ Monitors SLA and fairness cost signals per rollout. Phase 1-3: diagnostic only. Phase 4: feeds into Lagrangian multiplier updates. """ def __init__(self, verbose: int = 0): super().__init__(verbose) self._episode_costs: list[dict] = [] self._ep_sla: list[float] = [] self._ep_fair: list[float] = [] self._ep_mask_applied: list[float] = [] def _on_step(self) -> bool: for info, done in zip( self.locals.get("infos", []), self.locals.get("dones", []), ): rb = info.get("reward_breakdown", {}) self._ep_sla.append( abs(float(rb.get("sla_penalty", 0.0)))) self._ep_fair.append(abs(float(rb.get("fairness_penalty", 0.0)))) self._ep_mask_applied.append(float(bool(info.get("action_mask_applied", False)))) if done: mean_sla = float(np.mean(self._ep_sla)) if self._ep_sla else 0.0 mean_fair = float(np.mean(self._ep_fair)) if self._ep_fair else 0.0 mask_rate = float(np.mean(self._ep_mask_applied)) if self._ep_mask_applied else 0.0 self._episode_costs.append({"sla": mean_sla, "fairness": mean_fair}) self.logger.record("costs/episode_mean_sla_penalty", mean_sla) self.logger.record("costs/episode_mean_fairness_penalty", mean_fair) self.logger.record("costs/sla_threshold_violated", float(mean_sla > THRESHOLD_SLA)) self.logger.record("costs/fairness_threshold_violated", float(mean_fair > THRESHOLD_FAIRNESS)) self.logger.record("costs/episode_action_mask_applied_rate", mask_rate) self._ep_sla.clear() self._ep_fair.clear() self._ep_mask_applied.clear() return True def _on_training_end(self) -> None: if not self._episode_costs: return all_sla = [c["sla"] for c in self._episode_costs] all_fair = [c["fairness"] for c in self._episode_costs] print( f"\n[CostMonitor] mean SLA penalty: {np.mean(all_sla):.4f} " f"(threshold={THRESHOLD_SLA}), " f"mean fairness penalty: {np.mean(all_fair):.4f} " f"(threshold={THRESHOLD_FAIRNESS})" ) class RecurrentEvalCallback(BaseCallback): """ Periodic evaluation callback for RecurrentPPO. We evaluate with deterministic inference and enforce action masks at inference time before env.step(). """ def __init__( self, eval_env: GovWorkflowGymEnv, eval_freq: int = 2048, n_eval_episodes: int = 3, best_model_save_path: str = "results/best_model", log_path: str = "results/eval_logs", task_id: str = "mixed_urgency_medium", verbose: int = 1, ): super().__init__(verbose=verbose) self.eval_env = eval_env self.eval_freq = eval_freq self.n_eval_episodes = n_eval_episodes self.best_model_save_path = best_model_save_path self.log_path = log_path self.task_id = task_id self._best_grader_score = -np.inf os.makedirs(best_model_save_path, exist_ok=True) os.makedirs(log_path, exist_ok=True) def _on_step(self) -> bool: if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0: return True mean_reward, grader_score = self._run_eval() self.logger.record("eval/mean_reward", mean_reward) self.logger.record("eval/grader_score", grader_score) if grader_score > self._best_grader_score: self._best_grader_score = grader_score save_path = os.path.join( self.best_model_save_path, f"best_grader_recurrent_{self.task_id}" ) self.model.save(save_path) if self.verbose: print(f"[Eval] New best recurrent grader score: {grader_score:.4f} -> {save_path}") return True def _run_eval(self) -> tuple[float, float]: from app.graders import grade_episode from app.tasks import TASKS task_cfg = TASKS.get(self.task_id) if task_cfg is None: return 0.0, 0.0 rewards: list[float] = [] scores: list[float] = [] for ep in range(self.n_eval_episodes): env = GovWorkflowGymEnv(self.task_id, seed=task_cfg.seed + ep, hard_action_mask=True) obs, _ = env.reset() done = False ep_reward = 0.0 lstm_state: Any = None episode_start = np.array([True], dtype=bool) while not done: action, lstm_state = self.model.predict( obs, state=lstm_state, episode_start=episode_start, deterministic=True, ) action_idx = int(np.asarray(action).item()) masks = env.action_masks() if action_idx < 0 or action_idx >= masks.shape[0] or not bool(masks[action_idx]): if masks.shape[0] > 18 and bool(masks[18]): action_idx = 18 else: valid = np.flatnonzero(masks) if valid.size > 0: action_idx = int(valid[0]) obs, reward, terminated, truncated, _ = env.step(action_idx) ep_reward += float(reward) done = bool(terminated or truncated) episode_start = np.array([done], dtype=bool) result = grade_episode(env._core_env.state()) rewards.append(ep_reward) scores.append(float(result.score)) return float(np.mean(rewards)), float(np.mean(scores))