# OPENENV_RL_01 / rl/eval_grader.py
# Author: Siddharaj Shirke
# deploy: fresh snapshot to Hugging Face Space (commit 3eae4cc)
"""
Grader-based evaluation utility for trained RL checkpoints.
This complements `rl/evaluate.py`:
- `rl/evaluate.py` is batch-oriented and returns aggregate task rows.
- `rl/eval_grader.py` is phase/task-oriented and prints per-episode progress,
promotion guidance, and an optional score/reward plot.
"""
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
from typing import Any, Literal
import matplotlib
import numpy as np
from sb3_contrib import MaskablePPO, RecurrentPPO
# Allow running as `python rl/eval_grader.py ...` from repo root.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))
# These project imports require the repo root on sys.path (see above).
from app.graders import grade_episode
from rl.gov_workflow_env import GovWorkflowGymEnv
# Select the non-interactive Agg backend *before* importing pyplot so the
# script also runs on headless machines (CI, remote training boxes).
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Which SB3 algorithm to load; "auto" tries MaskablePPO first, then RecurrentPPO.
ModelType = Literal["auto", "maskable", "recurrent"]
# Minimum mean grader score a checkpoint must reach on each task before it is
# promoted to the next curriculum phase.
PROMOTION_THRESHOLDS = {
    "district_backlog_easy": 0.75,
    "mixed_urgency_medium": 0.65,
    "cross_department_hard": 0.55,
}
# Human-readable phase names used in printed headers and plot titles.
PHASE_LABELS = {
    "district_backlog_easy": "Phase 1",
    "mixed_urgency_medium": "Phase 2",
    "cross_department_hard": "Phase 3",
}
def _normalize_action(action: Any) -> int:
if isinstance(action, np.ndarray):
return int(action.item())
return int(action)
def _sanitize_action(action_idx: int, masks: np.ndarray) -> int:
if 0 <= action_idx < masks.shape[0] and bool(masks[action_idx]):
return int(action_idx)
if masks.shape[0] > 18 and bool(masks[18]):
return 18
valid = np.flatnonzero(masks)
return int(valid[0]) if valid.size > 0 else 18
def _load_model(model_path: str, model_type: ModelType) -> tuple[Any, str]:
    """Load a checkpoint, resolving the algorithm class.

    Returns a ``(model, resolved_type)`` pair where ``resolved_type`` is
    ``"maskable"`` or ``"recurrent"``.
    """
    explicit = {"maskable": MaskablePPO, "recurrent": RecurrentPPO}
    if model_type in explicit:
        return explicit[model_type].load(model_path), model_type
    # "auto": prefer MaskablePPO; fall back to RecurrentPPO on any load failure.
    try:
        return MaskablePPO.load(model_path), "maskable"
    except Exception:
        return RecurrentPPO.load(model_path), "recurrent"
def _run_one_episode(model: Any, resolved_type: str, env: Any, ep_seed: int) -> float:
    """Roll out one deterministic episode in ``env``; return the total reward."""
    obs, _ = env.reset(seed=ep_seed)
    done = False
    total_reward = 0.0
    lstm_state: Any = None
    episode_start = np.array([True], dtype=bool)
    while not done:
        masks = env.action_masks()
        if resolved_type == "recurrent":
            # RecurrentPPO has no built-in masking, so its raw prediction may
            # be illegal; sanitize it against the env's action mask.
            action, lstm_state = model.predict(
                obs,
                state=lstm_state,
                episode_start=episode_start,
                deterministic=True,
            )
            action_idx = _sanitize_action(_normalize_action(action), masks)
        else:
            action, _ = model.predict(obs, action_masks=masks, deterministic=True)
            action_idx = _normalize_action(action)
        obs, reward, terminated, truncated, _ = env.step(action_idx)
        total_reward += float(reward)
        done = bool(terminated or truncated)
        # RecurrentPPO needs to know when an episode boundary was crossed.
        episode_start = np.array([done], dtype=bool)
    return total_reward
def evaluate_with_grader(
    model_path: str,
    task_id: str,
    n_episodes: int = 20,
    seed: int = 42,
    model_type: ModelType = "auto",
    save_plot: bool = True,
) -> float:
    """Run deterministic rollouts, grade each final state, and print a report.

    Args:
        model_path: Path to the SB3 ``.zip`` checkpoint.
        task_id: One of the keys of ``PROMOTION_THRESHOLDS``.
        n_episodes: Number of evaluation episodes (must be >= 1).
        seed: Base seed; episode ``i`` uses ``seed + i`` for reproducibility.
        model_type: ``"maskable"``, ``"recurrent"``, or ``"auto"``.
        save_plot: When True, write the score/reward PNG via ``_save_plot``.

    Returns:
        Mean grader score across all episodes.

    Raises:
        ValueError: If ``task_id`` is unknown or ``n_episodes`` < 1.
    """
    if task_id not in PROMOTION_THRESHOLDS:
        allowed = ", ".join(PROMOTION_THRESHOLDS.keys())
        raise ValueError(f"Unknown task_id '{task_id}'. Allowed: {allowed}")
    if n_episodes < 1:
        # Guard early: np.min/np.max over an empty score list would raise a
        # far less readable error in the summary section below.
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")
    model, resolved_type = _load_model(model_path, model_type)
    # Hoisted out of the episode loop: the threshold is constant per task.
    threshold = float(PROMOTION_THRESHOLDS[task_id])
    print("\n" + "=" * 64)
    print(f"Track A Evaluation - {PHASE_LABELS.get(task_id, task_id)}")
    print(f"Model: {model_path}")
    print(f"Model type: {resolved_type}")
    print(f"Task: {task_id}")
    print(f"Episodes: {n_episodes}")
    print("=" * 64 + "\n")
    scores: list[float] = []
    rewards: list[float] = []
    for ep in range(n_episodes):
        env = GovWorkflowGymEnv(task_id=task_id, seed=seed + ep, hard_action_mask=True)
        try:
            ep_reward = _run_one_episode(model, resolved_type, env, seed + ep)
            result = grade_episode(env.core_env.state())
        finally:
            # Release env resources even if the rollout or grading raises.
            env.close()
        score = float(result.score)
        badge = "PASS" if score >= threshold else "FAIL"
        print(f" {badge:4} ep={ep + 1:02d} score={score:.4f} reward={ep_reward:.2f}")
        scores.append(score)
        rewards.append(ep_reward)
    mean_score = float(np.mean(scores))
    print("\n" + "-" * 64)
    print(f"Mean grader score: {mean_score:.4f}")
    print(f"Promotion target : {threshold:.2f}")
    print(f"Min / Max : {float(np.min(scores)):.4f} / {float(np.max(scores)):.4f}")
    print(f"Pass rate : {sum(s >= threshold for s in scores)}/{len(scores)}")
    if mean_score >= threshold:
        print("Decision : PROMOTE")
    else:
        print("Decision : CONTINUE TRAINING")
    print("=" * 64)
    if save_plot:
        _save_plot(scores=scores, rewards=rewards, task_id=task_id, mean_score=mean_score, threshold=threshold, model_path=model_path)
    return mean_score
def _save_plot(
    *,
    scores: list[float],
    rewards: list[float],
    task_id: str,
    mean_score: float,
    threshold: float,
    model_path: str,
) -> str:
    """Render and save a two-panel evaluation PNG (scores + rewards).

    Args:
        scores: Per-episode grader scores (expected in [0, 1]).
        rewards: Per-episode total env rewards, same length as ``scores``.
        task_id: Task identifier; also selects the output subdirectory.
        mean_score: Mean of ``scores`` (drawn as a reference line).
        threshold: Promotion threshold (drawn as a reference line).
        model_path: Checkpoint path, shown in the figure title.

    Returns:
        Filesystem path of the written PNG.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(
        f"Track A - {PHASE_LABELS.get(task_id, task_id)} Evaluation\n"
        f"Task: {task_id} | Model: {os.path.basename(model_path)}",
        fontsize=12,
        fontweight="bold",
    )
    episodes = list(range(1, len(scores) + 1))
    # Left panel: per-episode grader score, colored green/red by pass/fail.
    colors = ["#0e8a16" if s >= threshold else "#b60205" for s in scores]
    ax1.bar(episodes, scores, color=colors, alpha=0.85)
    ax1.axhline(y=threshold, color="#d97706", linestyle="--", linewidth=2, label=f"threshold={threshold:.2f}")
    ax1.axhline(y=mean_score, color="#1d4ed8", linestyle="-", linewidth=2, label=f"mean={mean_score:.3f}")
    ax1.set_ylim(0.0, 1.05)
    ax1.set_xlabel("Episode")
    ax1.set_ylabel("Grader Score")
    ax1.set_title("Per-Episode Grader Score")
    ax1.grid(True, alpha=0.3, axis="y")
    ax1.legend()
    # Right panel: total reward per episode, with its mean as a reference.
    ax2.plot(episodes, rewards, color="#0369a1", linewidth=2, marker="o", markersize=4)
    if rewards:
        mean_reward = float(np.mean(rewards))
        ax2.axhline(y=mean_reward, color="#d97706", linestyle="--", linewidth=2, label=f"mean={mean_reward:.2f}")
    ax2.set_xlabel("Episode")
    ax2.set_ylabel("Total Reward")
    ax2.set_title("Episode Reward")
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    fig.tight_layout()
    out_dir = os.path.join("results", "eval_logs", task_id)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{task_id}_grader_eval.png")
    # Save and close THIS figure explicitly instead of relying on pyplot's
    # global "current figure" (plt.savefig()/plt.close() could act on the
    # wrong figure if a caller has other figures open).
    fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor="white")
    plt.close(fig)
    print(f"Plot saved -> {out_path}")
    return out_path
def main() -> None:
    """CLI entry point: parse arguments and run one grader evaluation."""
    parser = argparse.ArgumentParser(description="Task-oriented grader evaluation for a trained checkpoint")
    parser.add_argument("--model", required=True, help="Path to .zip checkpoint (suffix optional)")
    parser.add_argument("--task", required=True, choices=list(PROMOTION_THRESHOLDS.keys()))
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--model-type", choices=["auto", "maskable", "recurrent"], default="auto")
    parser.add_argument("--no-plot", action="store_true", help="Disable PNG output")
    args = parser.parse_args()
    # Normalize the checkpoint path: append ".zip" when the suffix was omitted.
    model_path = args.model
    if not model_path.endswith(".zip"):
        model_path += ".zip"
    evaluate_with_grader(
        model_path=model_path,
        task_id=args.task,
        n_episodes=args.episodes,
        seed=args.seed,
        model_type=args.model_type,
        save_plot=not args.no_plot,
    )
if __name__ == "__main__":
    main()