File size: 3,861 Bytes
666b4ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Numerical equivalence check for the Rubric refactor.

Replays each episode from a saved eval JSON (which was scored by the
OLD `_grade_episode`) under the LOCAL code (which has the NEW Rubric-
based grader). If `final_score` matches for every episode within float
precision, the refactor is functionally identical to the original.

Usage:
    python scripts/verify_rubric_equivalence.py outputs/sft-v3/eval_results_v2.json
"""

import argparse
import json
import os
import random
import sys
from pathlib import Path

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models import ActionType, RhythmAction
from server.rhythm_environment import MAX_STEPS, RhythmEnvironment


def replay_episode(seed: int, actions: list[str], final_belief: list[float] | None,
                   profile: str | None) -> float:
    """Re-run one recorded episode through the local (new) grader.

    Feeds the recorded action names back into a fresh RhythmEnvironment and
    returns the resulting ``final_score`` from the reward breakdown (0.0 if
    the grader did not emit one).
    """
    env = RhythmEnvironment()
    # Only pass `profile` through when one was recorded; otherwise let
    # reset() fall back to its default behavior for this seed.
    obs = env.reset(seed=seed, profile=profile) if profile else env.reset(seed=seed)

    n_actions = len(actions)
    for step_idx, name in enumerate(actions):
        if obs.done:
            break
        # Mirror inference_eval.py: record_belief was called every step
        # there, but only the most recent belief matters to the grader, so
        # recording it once on the final step is equivalent.
        on_last_step = step_idx == n_actions - 1 or step_idx == MAX_STEPS - 1
        if on_last_step and final_belief is not None:
            env.record_belief(final_belief)
        obs = env.step(RhythmAction(action_type=ActionType(name)))

    return obs.reward_breakdown.get("final_score", 0.0)


def main() -> None:
    """Verify every episode in an eval JSON replays to the same final_score.

    Loads the rows, replays each one via `replay_episode`, and compares the
    locally-computed score against the recorded `final_score`. Prints a
    summary; if any episode differs by more than --tolerance, prints up to
    10 mismatches and exits with status 1.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("eval_json", help="Path to eval_results JSON to verify against")
    ap.add_argument("--tolerance", type=float, default=1e-4,
                    help="Allowed |new - old| difference (default 1e-4)")
    args = ap.parse_args()

    with open(args.eval_json) as f:
        rows = json.load(f)
    print(f"Loaded {len(rows)} episodes from {args.eval_json}")
    print()

    matches = 0
    mismatches = []
    for row in rows:
        seed = row["seed"]
        actions = row.get("actions", [])
        final_belief = row.get("final_belief")  # null for heuristic/random
        # Only discrete-profile rows carry a replayable profile name (the
        # reference profiles); sampled ones ("sampled_*") cannot be
        # re-instantiated by name, so they replay with the default profile.
        profile = row.get("profile_name") if row.get("profile_mode") == "discrete" else None
        if profile and not profile.startswith("sampled_"):
            replay_profile = profile
        else:
            replay_profile = None

        old_score = row["final_score"]
        new_score = replay_episode(seed, actions, final_belief, replay_profile)

        diff = abs(new_score - old_score)
        if diff <= args.tolerance:
            matches += 1
        else:
            mismatches.append({
                "seed": seed,
                "strategy": row["strategy"],
                "condition": row["condition"],
                "old": old_score,
                "new": new_score,
                "diff": diff,
            })

    # Plain string literals here — the originals carried pointless f-string
    # prefixes with no placeholders (ruff F541).
    print("=" * 60)
    print(f"RESULT: {matches}/{len(rows)} episodes match within ±{args.tolerance}")
    print("=" * 60)
    if mismatches:
        print()
        print(f"MISMATCHES ({len(mismatches)}):")
        for m in mismatches[:10]:
            print(f"  seed={m['seed']:>5}  {m['strategy']:>10}  "
                  f"{m['condition']:<35}  old={m['old']:.4f}  "
                  f"new={m['new']:.4f}  diff={m['diff']:.4f}")
        if len(mismatches) > 10:
            print(f"  ... and {len(mismatches) - 10} more")
        sys.exit(1)
    else:
        print()
        print("REFACTOR IS NUMERICALLY EQUIVALENT to the old grader.")


if __name__ == "__main__":
    main()