"""
Numerical equivalence check for the Rubric refactor.
Replays each episode from a saved eval JSON (which was scored by the
OLD `_grade_episode`) under the LOCAL code (which has the NEW Rubric-
based grader). If `final_score` matches for every episode within float
precision, the refactor is functionally identical to the original.
Usage:
python scripts/verify_rubric_equivalence.py outputs/sft-v3/eval_results_v2.json
"""
import argparse
import json
import os
import random
import sys
from pathlib import Path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import ActionType, RhythmAction
from server.rhythm_environment import MAX_STEPS, RhythmEnvironment
def replay_episode(seed: int, actions: list[str], final_belief: list[float] | None,
                   profile: str | None) -> float:
    """Re-run one recorded episode through the local (new) grader and return its score."""
    env = RhythmEnvironment()
    reset_kwargs: dict = {"seed": seed}
    if profile:
        reset_kwargs["profile"] = profile
    obs = env.reset(**reset_kwargs)
    last_index = len(actions) - 1
    for step_idx, name in enumerate(actions):
        if obs.done:
            break
        # Mirror inference_eval.py: record_belief was called every step there,
        # but only the latest belief matters to the grader — so it suffices to
        # inject the saved final belief just before the last step executes.
        if final_belief is not None and step_idx in (last_index, MAX_STEPS - 1):
            env.record_belief(final_belief)
        obs = env.step(RhythmAction(action_type=ActionType(name)))
    return obs.reward_breakdown.get("final_score", 0.0)
def main() -> None:
    """Replay every episode from a saved eval JSON and compare scores.

    Reads the eval JSON (scored by the old grader), replays each episode
    under the local new grader via ``replay_episode``, and reports how many
    final scores agree within ``--tolerance``. Exits with status 1 if any
    episode mismatches.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("eval_json", help="Path to eval_results JSON to verify against")
    ap.add_argument("--tolerance", type=float, default=1e-4,
                    help="Allowed |new - old| difference (default 1e-4)")
    args = ap.parse_args()

    with open(args.eval_json) as f:
        rows = json.load(f)
    print(f"Loaded {len(rows)} episodes from {args.eval_json}")
    print()

    matches = 0
    mismatches = []
    for row in rows:
        seed = row["seed"]
        actions = row.get("actions", [])
        final_belief = row.get("final_belief")  # null for heuristic/random strategies
        profile = row.get("profile_name") if row.get("profile_mode") == "discrete" else None
        # Some rows have profile_mode='discrete' with an explicit profile
        # name (the 3 reference profiles). Only those are passed through;
        # sampled profiles are reproduced from the seed alone.
        if profile and not profile.startswith("sampled_"):
            replay_profile = profile
        else:
            replay_profile = None
        old_score = row["final_score"]
        new_score = replay_episode(seed, actions, final_belief, replay_profile)
        diff = abs(new_score - old_score)
        if diff <= args.tolerance:
            matches += 1
        else:
            mismatches.append({
                "seed": seed,
                "strategy": row["strategy"],
                "condition": row["condition"],
                "old": old_score,
                "new": new_score,
                "diff": diff,
            })

    # Plain strings here — the originals were placeholder-free f-strings
    # (ruff F541); output is byte-identical.
    print("=" * 60)
    print(f"RESULT: {matches}/{len(rows)} episodes match within ±{args.tolerance}")
    print("=" * 60)
    if mismatches:
        print()
        print(f"MISMATCHES ({len(mismatches)}):")
        for m in mismatches[:10]:
            print(f"  seed={m['seed']:>5} {m['strategy']:>10} "
                  f"{m['condition']:<35} old={m['old']:.4f} "
                  f"new={m['new']:.4f} diff={m['diff']:.4f}")
        if len(mismatches) > 10:
            print(f"  ... and {len(mismatches) - 10} more")
        sys.exit(1)
    else:
        print()
        print("REFACTOR IS NUMERICALLY EQUIVALENT to the old grader.")


if __name__ == "__main__":
    main()