""" Re-evaluate existing teacher trajectories under the NEW grader (with belief_accuracy term). For each episode in the JSONL: 1. Replay the env from the recorded seed 2. Step with the recorded action sequence 3. Call env.record_belief(parsed_belief) at each step (using the LAST step's belief for the grader) 4. Read final_score (now under new grader) Also runs heuristic + random baselines on the same seeds for comparison, so we can directly answer: does the teacher beat heuristic+random under the new grader, and by how much? Usage: python scripts/reeval_teacher_trajectories.py --jsonl data/teacher_30ep_validation.jsonl """ import argparse import json import os import random import sys from collections import defaultdict from pathlib import Path from statistics import mean, median, stdev sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models import ActionType, RhythmAction from server.rhythm_environment import RhythmEnvironment, MAX_STEPS from training.dataset import heuristic_action def replay_with_beliefs(seed: int, action_seq: list[str], belief_seq: list[list[float]]) -> dict: """Replay an episode with given actions+beliefs, return final_score breakdown.""" env = RhythmEnvironment() obs = env.reset(seed=seed) for action_name, belief in zip(action_seq, belief_seq): if obs.done: break env.record_belief(belief) obs = env.step(RhythmAction(action_type=ActionType(action_name))) final_score = obs.reward_breakdown.get("final_score", 0.0) return { "final_score": final_score, "vitality": obs.vitality, "cognition": obs.cognition, "progress": obs.progress, "serenity": obs.serenity, "connection": obs.connection, } def play_heuristic(seed: int) -> float: env = RhythmEnvironment() obs = env.reset(seed=seed) for _ in range(MAX_STEPS): if obs.done: break action = heuristic_action(obs) obs = env.step(RhythmAction(action_type=action)) return obs.reward_breakdown.get("final_score", 0.0) def play_random(seed: int) -> float: env = RhythmEnvironment() obs = env.reset(seed=seed) rng = random.Random(seed + 12345) actions = list(ActionType) for _ in range(MAX_STEPS): if obs.done: break action = rng.choice(actions) obs = env.step(RhythmAction(action_type=action)) return obs.reward_breakdown.get("final_score", 0.0) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--jsonl", type=str, required=True, help="Teacher trajectories JSONL") args = parser.parse_args() # Group rows by seed by_seed: dict[int, list[dict]] = defaultdict(list) with open(args.jsonl) as f: for line in f: row = json.loads(line) by_seed[row["seed"]].append(row) for seed in by_seed: by_seed[seed].sort(key=lambda r: r["step"]) seeds = sorted(by_seed.keys()) print(f"Loaded {len(seeds)} episodes from {args.jsonl}\n") teacher_scores = [] heuristic_scores = [] random_scores = [] for seed in seeds: rows = by_seed[seed] action_seq = [r["parsed_action"] for r in rows] belief_seq = [r["parsed_belief"] for r in rows] result = replay_with_beliefs(seed, action_seq, belief_seq) teacher_scores.append(result["final_score"]) heuristic_scores.append(play_heuristic(seed)) random_scores.append(play_random(seed)) def stats(vs, label): return f"{label:<10} mean={mean(vs):.4f} median={median(vs):.4f} std={stdev(vs):.4f} min={min(vs):.4f} max={max(vs):.4f}" print("=" * 78) print("FINAL_SCORE UNDER NEW GRADER (belief_accuracy weighted 0.20)") print("=" * 78) print(stats(teacher_scores, "teacher:")) print(stats(heuristic_scores, "heuristic:")) print(stats(random_scores, "random:")) teacher_avg = mean(teacher_scores) heur_avg = mean(heuristic_scores) rand_avg = mean(random_scores) margin_h = teacher_avg - heur_avg margin_r = teacher_avg - rand_avg print() print(f"teacher - heuristic: {margin_h:+.4f}") print(f"teacher - random: {margin_r:+.4f}") print() # Per-episode comparison teacher_beats_heur = sum(1 for t, h in zip(teacher_scores, heuristic_scores) if t > h) print(f"Episodes where teacher > heuristic: {teacher_beats_heur} / {len(seeds)}") # Verdict print() print("=" * 78) print("VERDICT") print("=" * 78) GATE = 0.05 # teacher must beat heuristic by at least 5pts on average if margin_h >= GATE: print(f"PASS — teacher beats heuristic by {margin_h:.3f} (>= {GATE} required)") print(" New grader differentiates inference from reflex. Proceed to scale + SFT.") elif margin_h > 0: print(f"WEAK — teacher beats heuristic by only {margin_h:.3f} (need {GATE}+)") print(" Consider raising belief_accuracy weight or improving teacher prompt.") else: print(f"FAIL — teacher does NOT beat heuristic ({margin_h:.3f})") print(" Either grader weights are wrong or teacher is genuinely weak.") if __name__ == "__main__": main()