"""Numerical equivalence check for the Rubric refactor.

Replays each episode from a saved eval JSON (which was scored by the OLD
`_grade_episode`) under the LOCAL code (which has the NEW Rubric-based
grader).  If `final_score` matches for every episode within float
precision, the refactor is functionally identical to the original.

Usage:
    python scripts/verify_rubric_equivalence.py outputs/sft-v3/eval_results_v2.json
"""

import argparse
import json
import os
import random
import sys
from pathlib import Path

# Make the repository root importable when this is run as `python scripts/...`.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models import ActionType, RhythmAction
from server.rhythm_environment import MAX_STEPS, RhythmEnvironment


def replay_episode(
    seed: int,
    actions: list[str],
    final_belief: list[float] | None,
    profile: str | None,
) -> float:
    """Replay one episode locally under the new grader.

    Args:
        seed: RNG seed from the original episode; replaying with the same
            seed makes the environment evolve identically.
        actions: Action names recorded by the original episode, in order.
        final_belief: Last belief vector recorded by the original episode,
            or None (heuristic/random strategies record no belief).
        profile: Explicit profile name to reset with, or None to let the
            environment derive the profile from the seed.

    Returns:
        The `final_score` reported by the local (new) grader, or 0.0 if
        the environment's reward breakdown lacks that key.
    """
    env = RhythmEnvironment()
    if profile:
        obs = env.reset(seed=seed, profile=profile)
    else:
        obs = env.reset(seed=seed)

    for action_name in actions:
        if obs.done:
            break
        # Record the belief on EVERY step, exactly as inference_eval.py did
        # originally (only the latest value matters to the grader).  The
        # previous "last action index only" scheme silently skipped the
        # belief whenever the environment terminated early (obs.done broke
        # out before the last recorded action), feeding the grader different
        # inputs than the old pipeline saw — the very divergence this script
        # is meant to detect.
        if final_belief is not None:
            env.record_belief(final_belief)
        obs = env.step(RhythmAction(action_type=ActionType(action_name)))

    return obs.reward_breakdown.get("final_score", 0.0)


def main() -> None:
    """Compare each saved `final_score` against a local replay and report."""
    ap = argparse.ArgumentParser()
    ap.add_argument("eval_json", help="Path to eval_results JSON to verify against")
    ap.add_argument(
        "--tolerance",
        type=float,
        default=1e-4,
        help="Allowed |new - old| difference (default 1e-4)",
    )
    args = ap.parse_args()

    with open(args.eval_json) as f:
        rows = json.load(f)
    print(f"Loaded {len(rows)} episodes from {args.eval_json}")
    print()

    matches = 0
    mismatches = []
    for row in rows:
        seed = row["seed"]
        actions = row.get("actions", [])
        final_belief = row.get("final_belief")  # null for heuristic/random
        profile = row.get("profile_name") if row.get("profile_mode") == "discrete" else None

        # Some rows have profile_mode='discrete' with explicit profile
        # name (the 3 reference profiles). Pass it via kwarg; sampled
        # profiles are reconstructed from the seed instead.
        if profile and not profile.startswith("sampled_"):
            replay_profile = profile
        else:
            replay_profile = None

        old_score = row["final_score"]
        new_score = replay_episode(seed, actions, final_belief, replay_profile)

        diff = abs(new_score - old_score)
        if diff <= args.tolerance:
            matches += 1
        else:
            mismatches.append({
                "seed": seed,
                # .get(...) so one malformed row cannot crash the very
                # report that is supposed to surface it.
                "strategy": row.get("strategy", "?"),
                "condition": row.get("condition", "?"),
                "old": old_score,
                "new": new_score,
                "diff": diff,
            })

    print("=" * 60)
    print(f"RESULT: {matches}/{len(rows)} episodes match within ±{args.tolerance}")
    print("=" * 60)

    if mismatches:
        print()
        print(f"MISMATCHES ({len(mismatches)}):")
        for m in mismatches[:10]:
            print(f"  seed={m['seed']:>5} {m['strategy']:>10} "
                  f"{m['condition']:<35} old={m['old']:.4f} "
                  f"new={m['new']:.4f} diff={m['diff']:.4f}")
        if len(mismatches) > 10:
            print(f"  ... and {len(mismatches) - 10} more")
        # Non-zero exit so CI / calling scripts can detect the divergence.
        sys.exit(1)
    else:
        print()
        print("REFACTOR IS NUMERICALLY EQUIVALENT to the old grader.")


if __name__ == "__main__":
    main()