# rl/callbacks.py
"""
Custom SB3 callbacks for Gov Workflow RL training.
GovWorkflowEvalCallback -- MaskableEvalCallback + grader score logging
CostMonitorCallback -- per-rollout cost constraint logging to TensorBoard
"""
from __future__ import annotations

import os
from typing import Any

import numpy as np
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
from stable_baselines3.common.callbacks import BaseCallback

from rl.cost_tracker import THRESHOLD_SLA, THRESHOLD_FAIRNESS
from rl.gov_workflow_env import GovWorkflowGymEnv

class GovWorkflowEvalCallback(MaskableEvalCallback):
"""
Extends MaskableEvalCallback:
1. Runs the deterministic grader after each eval.
2. Logs grader score to TensorBoard.
3. Saves best model by grader score (not just mean reward).
"""
def __init__(
self,
eval_env: GovWorkflowGymEnv,
eval_freq: int = 2048,
n_eval_episodes: int = 5,
grader_eval_freq_multiplier: int = 4,
grader_eval_max_steps: int | None = None,
best_model_save_path: str = "results/best_model",
log_path: str = "results/eval_logs",
task_id: str = "district_backlog_easy",
verbose: int = 1,
):
super().__init__(
eval_env=eval_env,
n_eval_episodes=n_eval_episodes,
eval_freq=eval_freq,
best_model_save_path=best_model_save_path,
log_path=log_path,
verbose=verbose,
warn=False,
)
self.task_id = task_id
self.grader_eval_freq_multiplier = max(1, int(grader_eval_freq_multiplier))
self.grader_eval_max_steps = grader_eval_max_steps
self._best_grader_score = -np.inf
os.makedirs(best_model_save_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

    def _on_step(self) -> bool:
        # Mirror the parent's scheduling check so that, after super() runs,
        # we know whether an evaluation just happened on this step.
        eval_due = self.eval_freq > 0 and self.n_calls % self.eval_freq == 0
        result = super()._on_step()
if eval_due:
mean_reward = float(getattr(self, "last_mean_reward", 0.0) or 0.0)
std_reward = 0.0
try:
                if self.evaluations_results:
latest = self.evaluations_results[-1]
if latest is not None and len(latest) > 0:
std_reward = float(np.std(latest))
except Exception:
std_reward = 0.0
# Stable line format for live parser in backend/frontend.
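            # For reference, a downstream parser could match it with something
            # like this pattern (hypothetical -- the actual backend parser is
            # not defined in this module):
            #   r"Eval num_timesteps=(\d+), episode_reward=(-?[\d.]+) \+/- ([\d.]+)"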
print(
f"Eval num_timesteps={int(self.num_timesteps)}, "
f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}",
flush=True,
)
        # Grader evaluation is costlier than the reward eval, so run it only
        # every `grader_eval_freq_multiplier`-th evaluation.
        grader_eval_freq = max(self.eval_freq * self.grader_eval_freq_multiplier, 1)
        if self.eval_freq > 0 and self.n_calls % grader_eval_freq == 0:
grader_score = self._run_grader_eval()
if self.logger:
self.logger.record("eval/grader_score", grader_score)
if grader_score > self._best_grader_score:
self._best_grader_score = grader_score
save_path = os.path.join(
self.best_model_save_path, f"best_grader_{self.task_id}"
)
self.model.save(save_path)
if self.verbose:
print(f"[Eval] New best grader score: {grader_score:.4f} -> {save_path}")
return result

    def _run_grader_eval(self) -> float:
try:
from app.graders import grade_episode
from app.tasks import TASKS
task_cfg = TASKS.get(self.task_id)
if task_cfg is None:
return 0.0
max_steps = (
int(self.grader_eval_max_steps)
if self.grader_eval_max_steps is not None
else max(1, int(task_cfg.max_days) * 10)
)
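            # Note: the fallback cap of max_days * 10 assumes roughly ten
            # decision steps per simulated day.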
env = GovWorkflowGymEnv(task_id=self.task_id, seed=task_cfg.seed, hard_action_mask=True)
obs, _ = env.reset()
done = False
steps = 0
            while not done:
                masks = np.asarray(env.action_masks(), dtype=bool).reshape(-1)
                action, _ = self.model.predict(obs, action_masks=masks, deterministic=True)
                obs, _, terminated, truncated, _ = env.step(int(action))
                done = terminated or truncated
                steps += 1
                if steps >= max_steps:  # Safety cap against non-terminating episodes.
                    break
result = grade_episode(env._core_env.state())
return float(result.score)
except Exception as e:
if self.verbose:
print(f"[Eval] Grader eval failed: {e}")
return 0.0

class CostMonitorCallback(BaseCallback):
"""
Monitors SLA and fairness cost signals per rollout.
Phase 1-3: diagnostic only.
Phase 4: feeds into Lagrangian multiplier updates.
"""
def __init__(self, verbose: int = 0):
super().__init__(verbose)
self._episode_costs: list[dict] = []
self._ep_sla: list[float] = []
self._ep_fair: list[float] = []
self._ep_mask_applied: list[float] = []

    def _on_step(self) -> bool:
        # NOTE: these accumulators assume a single (non-vectorized) env; with
        # n_envs > 1 the per-episode statistics would blend across envs.
        for info, done in zip(
            self.locals.get("infos", []),
            self.locals.get("dones", []),
        ):
            rb = info.get("reward_breakdown", {})
            self._ep_sla.append(abs(float(rb.get("sla_penalty", 0.0))))
            self._ep_fair.append(abs(float(rb.get("fairness_penalty", 0.0))))
            self._ep_mask_applied.append(float(bool(info.get("action_mask_applied", False))))
if done:
mean_sla = float(np.mean(self._ep_sla)) if self._ep_sla else 0.0
mean_fair = float(np.mean(self._ep_fair)) if self._ep_fair else 0.0
mask_rate = float(np.mean(self._ep_mask_applied)) if self._ep_mask_applied else 0.0
self._episode_costs.append({"sla": mean_sla, "fairness": mean_fair})
self.logger.record("costs/episode_mean_sla_penalty", mean_sla)
self.logger.record("costs/episode_mean_fairness_penalty", mean_fair)
self.logger.record("costs/sla_threshold_violated", float(mean_sla > THRESHOLD_SLA))
self.logger.record("costs/fairness_threshold_violated", float(mean_fair > THRESHOLD_FAIRNESS))
self.logger.record("costs/episode_action_mask_applied_rate", mask_rate)
self._ep_sla.clear()
self._ep_fair.clear()
self._ep_mask_applied.clear()
return True

    def _on_training_end(self) -> None:
if not self._episode_costs:
return
all_sla = [c["sla"] for c in self._episode_costs]
all_fair = [c["fairness"] for c in self._episode_costs]
print(
f"\n[CostMonitor] mean SLA penalty: {np.mean(all_sla):.4f} "
f"(threshold={THRESHOLD_SLA}), "
f"mean fairness penalty: {np.mean(all_fair):.4f} "
f"(threshold={THRESHOLD_FAIRNESS})"
)
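
# The CostMonitorCallback docstring points at Phase 4 Lagrangian multiplier
# updates. As an illustration only -- this helper is not called anywhere in
# this module, and the real Phase 4 update may differ -- a projected
# dual-ascent step on one cost constraint could look like:
def _illustrative_lagrangian_step(
    lmbda: float, mean_cost: float, threshold: float, lr: float = 0.01
) -> float:
    """Hypothetical dual-ascent step: raise the multiplier when the mean
    episode cost exceeds its threshold, decay it otherwise, and clamp at
    zero so the penalty never flips sign."""
    return max(0.0, lmbda + lr * (mean_cost - threshold))
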
class RecurrentEvalCallback(BaseCallback):
"""
Periodic evaluation callback for RecurrentPPO.
We evaluate with deterministic inference and enforce action masks at
inference time before env.step().
"""
def __init__(
self,
eval_env: GovWorkflowGymEnv,
eval_freq: int = 2048,
n_eval_episodes: int = 3,
best_model_save_path: str = "results/best_model",
log_path: str = "results/eval_logs",
task_id: str = "mixed_urgency_medium",
verbose: int = 1,
):
super().__init__(verbose=verbose)
        self.eval_env = eval_env  # Unused during eval: _run_eval builds fresh seeded envs.
self.eval_freq = eval_freq
self.n_eval_episodes = n_eval_episodes
self.best_model_save_path = best_model_save_path
self.log_path = log_path
self.task_id = task_id
self._best_grader_score = -np.inf
os.makedirs(best_model_save_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

    def _on_step(self) -> bool:
if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
return True
mean_reward, grader_score = self._run_eval()
self.logger.record("eval/mean_reward", mean_reward)
self.logger.record("eval/grader_score", grader_score)
if grader_score > self._best_grader_score:
self._best_grader_score = grader_score
save_path = os.path.join(
self.best_model_save_path, f"best_grader_recurrent_{self.task_id}"
)
self.model.save(save_path)
if self.verbose:
print(f"[Eval] New best recurrent grader score: {grader_score:.4f} -> {save_path}")
return True

    def _run_eval(self) -> tuple[float, float]:
from app.graders import grade_episode
from app.tasks import TASKS
task_cfg = TASKS.get(self.task_id)
if task_cfg is None:
return 0.0, 0.0
rewards: list[float] = []
scores: list[float] = []
for ep in range(self.n_eval_episodes):
env = GovWorkflowGymEnv(self.task_id, seed=task_cfg.seed + ep, hard_action_mask=True)
obs, _ = env.reset()
done = False
ep_reward = 0.0
lstm_state: Any = None
episode_start = np.array([True], dtype=bool)
while not done:
action, lstm_state = self.model.predict(
obs,
state=lstm_state,
episode_start=episode_start,
deterministic=True,
)
                action_idx = int(np.asarray(action).item())
                masks = np.asarray(env.action_masks(), dtype=bool).reshape(-1)
                # RecurrentPPO.predict() is mask-unaware, so repair invalid
                # choices before stepping: prefer the hard-coded fallback
                # action 18 (presumably a safe no-op/wait in this env), else
                # take the first valid action.
                if action_idx < 0 or action_idx >= masks.shape[0] or not bool(masks[action_idx]):
                    if masks.shape[0] > 18 and bool(masks[18]):
                        action_idx = 18
                    else:
                        valid = np.flatnonzero(masks)
                        if valid.size > 0:
                            action_idx = int(valid[0])
obs, reward, terminated, truncated, _ = env.step(action_idx)
ep_reward += float(reward)
done = bool(terminated or truncated)
episode_start = np.array([done], dtype=bool)
result = grade_episode(env._core_env.state())
rewards.append(ep_reward)
scores.append(float(result.score))
return float(np.mean(rewards)), float(np.mean(scores))
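
# Usage sketch (illustrative; the actual training entry point lives elsewhere
# in this repo, and details such as the policy name and timestep budget are
# assumptions, not the project's configuration):
#
#     from sb3_contrib import MaskablePPO
#     from stable_baselines3.common.callbacks import CallbackList
#
#     train_env = GovWorkflowGymEnv(task_id="district_backlog_easy", hard_action_mask=True)
#     eval_env = GovWorkflowGymEnv(task_id="district_backlog_easy", hard_action_mask=True)
#     model = MaskablePPO("MlpPolicy", train_env, verbose=1)
#     model.learn(
#         total_timesteps=200_000,
#         callback=CallbackList([
#             GovWorkflowEvalCallback(eval_env=eval_env, task_id="district_backlog_easy"),
#             CostMonitorCallback(),
#         ]),
#     )
#
# RecurrentEvalCallback wires up the same way for a RecurrentPPO model.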