File size: 5,313 Bytes
ece0bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Re-evaluate existing teacher trajectories under the NEW grader (with belief_accuracy term).

For each episode in the JSONL:
  1. Replay the env from the recorded seed
  2. Step with the recorded action sequence
  3. Call env.record_belief(parsed_belief) at each step (using the LAST step's
     belief for the grader)
  4. Read final_score (now under new grader)

Also runs heuristic + random baselines on the same seeds for comparison, so
we can directly answer: does the teacher beat heuristic+random under the new
grader, and by how much?

Usage:
    python scripts/reeval_teacher_trajectories.py --jsonl data/teacher_30ep_validation.jsonl
"""

import argparse
import json
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from statistics import mean, median, stdev

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models import ActionType, RhythmAction
from server.rhythm_environment import RhythmEnvironment, MAX_STEPS
from training.dataset import heuristic_action


def replay_with_beliefs(seed: int, action_seq: list[str], belief_seq: list[list[float]]) -> dict:
    """Re-run one recorded episode and report how it ends.

    The environment is reset with ``seed``; recorded actions are then applied
    in order, and the belief paired with each action is passed to
    ``env.record_belief`` right before that action is stepped. Replay stops
    early once the observation reports ``done``; pairs beyond the shorter of
    the two sequences are ignored (``zip`` semantics).

    Returns a dict with ``final_score`` (read from the last observation's
    ``reward_breakdown``, 0.0 if the key is absent) plus the ``vitality``,
    ``cognition``, ``progress``, ``serenity`` and ``connection`` attributes of
    that same observation.
    """
    environment = RhythmEnvironment()
    observation = environment.reset(seed=seed)

    for name, belief in zip(action_seq, belief_seq):
        if observation.done:
            break
        # Record the belief first, then take the step it was paired with.
        environment.record_belief(belief)
        action = RhythmAction(action_type=ActionType(name))
        observation = environment.step(action)

    summary = {"final_score": observation.reward_breakdown.get("final_score", 0.0)}
    for field in ("vitality", "cognition", "progress", "serenity", "connection"):
        summary[field] = getattr(observation, field)
    return summary


def play_heuristic(seed: int) -> float:
    """Play one episode from ``seed`` with ``heuristic_action`` and return its final_score.

    At most ``MAX_STEPS`` actions are taken; play ends sooner if the
    observation reports ``done``. The score is read from the last
    observation's ``reward_breakdown`` and falls back to 0.0 when the key is
    missing.
    """
    environment = RhythmEnvironment()
    observation = environment.reset(seed=seed)

    remaining = MAX_STEPS
    while remaining > 0 and not observation.done:
        chosen = heuristic_action(observation)
        observation = environment.step(RhythmAction(action_type=chosen))
        remaining -= 1

    return observation.reward_breakdown.get("final_score", 0.0)


def play_random(seed: int) -> float:
    """Play one episode from ``seed`` with uniformly random actions and return its final_score.

    Actions are drawn from ``list(ActionType)`` by a private
    ``random.Random(seed + 12345)`` generator, so the same seed always yields
    the same action draws. At most ``MAX_STEPS`` actions are taken; play ends
    sooner if the observation reports ``done``. The score is read from the
    last observation's ``reward_breakdown`` and falls back to 0.0 when the key
    is missing.
    """
    environment = RhythmEnvironment()
    observation = environment.reset(seed=seed)

    picker = random.Random(seed + 12345)
    choices = list(ActionType)

    remaining = MAX_STEPS
    while remaining > 0 and not observation.done:
        chosen = picker.choice(choices)
        observation = environment.step(RhythmAction(action_type=chosen))
        remaining -= 1

    return observation.reward_breakdown.get("final_score", 0.0)


def _format_stats(values: list[float], label: str) -> str:
    """Return a one-line mean/median/std/min/max summary of *values*.

    ``statistics.stdev`` requires at least two data points, so for a single
    value the spread is reported as ``nan`` instead of raising
    ``StatisticsError``. *values* must be non-empty.
    """
    spread = stdev(values) if len(values) > 1 else float("nan")
    return (
        f"{label:<10} mean={mean(values):.4f}  median={median(values):.4f}  "
        f"std={spread:.4f}  min={min(values):.4f}  max={max(values):.4f}"
    )


def main() -> None:
    """Re-score recorded teacher episodes and compare them with the baselines.

    Reads the JSONL file given by ``--jsonl``; every non-blank line must be a
    JSON object with ``seed``, ``step``, ``parsed_action`` and
    ``parsed_belief`` keys. Rows are grouped by seed and ordered by step, each
    episode is replayed via ``replay_with_beliefs``, and the heuristic and
    random baselines are played on the same seeds. Summary statistics, the
    mean margins and a PASS/WEAK/FAIL verdict are printed to stdout.

    Raises:
        SystemExit: if the file contains no episodes, since no statistics or
            verdict can be computed from an empty sample.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", type=str, required=True,
                        help="Teacher trajectories JSONL")
    args = parser.parse_args()

    # Group rows by seed
    by_seed: dict[int, list[dict]] = defaultdict(list)
    with open(args.jsonl, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. trailing newlines) instead of
            # failing inside json.loads.
            if not line.strip():
                continue
            row = json.loads(line)
            by_seed[row["seed"]].append(row)
    for rows in by_seed.values():
        rows.sort(key=lambda r: r["step"])

    seeds = sorted(by_seed)
    if not seeds:
        # mean()/median() of an empty list would raise StatisticsError below.
        raise SystemExit(f"No episodes found in {args.jsonl}")
    print(f"Loaded {len(seeds)} episodes from {args.jsonl}\n")

    teacher_scores = []
    heuristic_scores = []
    random_scores = []

    for seed in seeds:
        rows = by_seed[seed]
        action_seq = [r["parsed_action"] for r in rows]
        belief_seq = [r["parsed_belief"] for r in rows]

        result = replay_with_beliefs(seed, action_seq, belief_seq)
        teacher_scores.append(result["final_score"])
        # Baselines run on the same seed so the comparison is per-episode.
        heuristic_scores.append(play_heuristic(seed))
        random_scores.append(play_random(seed))

    print("=" * 78)
    print("FINAL_SCORE UNDER NEW GRADER (belief_accuracy weighted 0.20)")
    print("=" * 78)
    print(_format_stats(teacher_scores, "teacher:"))
    print(_format_stats(heuristic_scores, "heuristic:"))
    print(_format_stats(random_scores, "random:"))

    teacher_avg = mean(teacher_scores)
    heur_avg = mean(heuristic_scores)
    rand_avg = mean(random_scores)
    margin_h = teacher_avg - heur_avg
    margin_r = teacher_avg - rand_avg

    print()
    print(f"teacher - heuristic: {margin_h:+.4f}")
    print(f"teacher - random:    {margin_r:+.4f}")
    print()

    # Per-episode comparison
    teacher_beats_heur = sum(1 for t, h in zip(teacher_scores, heuristic_scores) if t > h)
    print(f"Episodes where teacher > heuristic: {teacher_beats_heur} / {len(seeds)}")

    # Verdict
    print()
    print("=" * 78)
    print("VERDICT")
    print("=" * 78)
    GATE = 0.05  # teacher must beat heuristic by at least 5pts on average
    if margin_h >= GATE:
        print(f"PASS — teacher beats heuristic by {margin_h:.3f} (>= {GATE} required)")
        print("       New grader differentiates inference from reflex. Proceed to scale + SFT.")
    elif margin_h > 0:
        print(f"WEAK — teacher beats heuristic by only {margin_h:.3f} (need {GATE}+)")
        print("       Consider raising belief_accuracy weight or improving teacher prompt.")
    else:
        print(f"FAIL — teacher does NOT beat heuristic ({margin_h:.3f})")
        print("       Either grader weights are wrong or teacher is genuinely weak.")


# Run the re-evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()