# rhythm_env/scripts/verify_rubric_equivalence.py
"""
Numerical equivalence check for the Rubric refactor.
Replays each episode from a saved eval JSON (which was scored by the
OLD `_grade_episode`) under the LOCAL code (which has the NEW Rubric-
based grader). If `final_score` matches for every episode within float
precision, the refactor is functionally identical to the original.
Usage:
python scripts/verify_rubric_equivalence.py outputs/sft-v3/eval_results_v2.json
"""
import argparse
import json
import os
import random
import sys
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import ActionType, RhythmAction
from server.rhythm_environment import MAX_STEPS, RhythmEnvironment
def replay_episode(seed: int, actions: list[str], final_belief: list[float] | None,
                   profile: str | None) -> float:
    """Re-run a single recorded episode through the local (new) grader.

    Args:
        seed: RNG seed the episode was originally generated with.
        actions: Ordered action-type names taken during the episode.
        final_belief: Belief vector recorded on the episode's last step, or
            None for strategies (heuristic/random) that never report one.
        profile: Optional named profile to reset the environment with.

    Returns:
        The ``final_score`` entry of the last observation's reward
        breakdown (0.0 when the grader produced no such entry).
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed, profile=profile) if profile else env.reset(seed=seed)
    last_idx = len(actions) - 1
    for step_idx, action_name in enumerate(actions):
        if obs.done:
            break
        # Mirror inference_eval.py: there, record_belief ran every step, but
        # only the most recent belief matters to the grader — so recording it
        # once, on the final step, is equivalent.
        if final_belief is not None and (step_idx == last_idx or step_idx == MAX_STEPS - 1):
            env.record_belief(final_belief)
        obs = env.step(RhythmAction(action_type=ActionType(action_name)))
    return obs.reward_breakdown.get("final_score", 0.0)
def main() -> None:
    """Verify stored episode scores against locally replayed ones.

    Loads the eval-results JSON given on the command line, replays every
    episode under the local grader, and compares scores. Exits with status 1
    (after printing up to 10 offending rows) if any episode differs from its
    stored ``final_score`` by more than ``--tolerance``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("eval_json", help="Path to eval_results JSON to verify against")
    parser.add_argument("--tolerance", type=float, default=1e-4,
                        help="Allowed |new - old| difference (default 1e-4)")
    args = parser.parse_args()

    with open(args.eval_json) as fh:
        rows = json.load(fh)
    print(f"Loaded {len(rows)} episodes from {args.eval_json}")
    print()

    matches = 0
    mismatches = []
    for row in rows:
        seed = row["seed"]
        actions = row.get("actions", [])
        final_belief = row.get("final_belief")  # null for heuristic/random
        profile = row.get("profile_name") if row.get("profile_mode") == "discrete" else None
        # Only explicitly named reference profiles are passed through to the
        # replay; sampled_* profiles are reconstructed from the seed instead.
        replay_profile = profile if profile and not profile.startswith("sampled_") else None

        old_score = row["final_score"]
        new_score = replay_episode(seed, actions, final_belief, replay_profile)
        delta = abs(new_score - old_score)
        if delta <= args.tolerance:
            matches += 1
        else:
            mismatches.append({
                "seed": seed,
                "strategy": row["strategy"],
                "condition": row["condition"],
                "old": old_score,
                "new": new_score,
                "diff": delta,
            })

    print("=" * 60)
    print(f"RESULT: {matches}/{len(rows)} episodes match within ±{args.tolerance}")
    print("=" * 60)
    if not mismatches:
        print()
        print("REFACTOR IS NUMERICALLY EQUIVALENT to the old grader.")
        return

    print()
    print(f"MISMATCHES ({len(mismatches)}):")
    for m in mismatches[:10]:
        print(f" seed={m['seed']:>5} {m['strategy']:>10} "
              f"{m['condition']:<35} old={m['old']:.4f} "
              f"new={m['new']:.4f} diff={m['diff']:.4f}")
    if len(mismatches) > 10:
        print(f" ... and {len(mismatches) - 10} more")
    sys.exit(1)


if __name__ == "__main__":
    main()