Spaces:
Sleeping
Sleeping
| """ | |
| Re-evaluate existing teacher trajectories under the NEW grader (with belief_accuracy term). | |
| For each episode in the JSONL: | |
| 1. Replay the env from the recorded seed | |
| 2. Step with the recorded action sequence | |
| 3. Call env.record_belief(parsed_belief) at each step (using the LAST step's | |
| belief for the grader) | |
| 4. Read final_score (now under new grader) | |
| Also runs heuristic + random baselines on the same seeds for comparison, so | |
| we can directly answer: does the teacher beat heuristic+random under the new | |
| grader, and by how much? | |
| Usage: | |
| python scripts/reeval_teacher_trajectories.py --jsonl data/teacher_30ep_validation.jsonl | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import sys | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from statistics import mean, median, stdev | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models import ActionType, RhythmAction | |
| from server.rhythm_environment import RhythmEnvironment, MAX_STEPS | |
| from training.dataset import heuristic_action | |
def replay_with_beliefs(seed: int, action_seq: list[str], belief_seq: list[list[float]]) -> dict:
    """Re-run one episode from its seed, feeding the recorded actions and
    beliefs back into the environment, and return the final stat breakdown.

    Actions and beliefs are consumed in lockstep; replay stops early if the
    environment reports done before the recorded sequence is exhausted.
    """
    environment = RhythmEnvironment()
    observation = environment.reset(seed=seed)
    for name, belief in zip(action_seq, belief_seq):
        if observation.done:
            break
        # Record the belief before stepping, mirroring the original rollout order.
        environment.record_belief(belief)
        observation = environment.step(RhythmAction(action_type=ActionType(name)))
    return {
        "final_score": observation.reward_breakdown.get("final_score", 0.0),
        "vitality": observation.vitality,
        "cognition": observation.cognition,
        "progress": observation.progress,
        "serenity": observation.serenity,
        "connection": observation.connection,
    }
def play_heuristic(seed: int) -> float:
    """Play one full episode with the scripted heuristic policy on *seed*;
    return the final_score from the last observation's reward breakdown."""
    environment = RhythmEnvironment()
    observation = environment.reset(seed=seed)
    steps_taken = 0
    while steps_taken < MAX_STEPS and not observation.done:
        chosen = heuristic_action(observation)
        observation = environment.step(RhythmAction(action_type=chosen))
        steps_taken += 1
    return observation.reward_breakdown.get("final_score", 0.0)
def play_random(seed: int) -> float:
    """Play one full episode choosing uniformly random actions on *seed*;
    return the final_score from the last observation's reward breakdown."""
    environment = RhythmEnvironment()
    observation = environment.reset(seed=seed)
    # Offset the policy RNG seed so its stream differs from the env's own seed.
    rng = random.Random(seed + 12345)
    all_actions = list(ActionType)
    for _ in range(MAX_STEPS):
        if observation.done:
            break
        observation = environment.step(
            RhythmAction(action_type=rng.choice(all_actions))
        )
    return observation.reward_breakdown.get("final_score", 0.0)
def main() -> None:
    """Re-score recorded teacher episodes under the new grader and compare
    them against heuristic and random baselines on the same seeds.

    Reads the teacher-trajectory JSONL named by --jsonl, replays each episode
    with its recorded actions/beliefs, runs both baselines per seed, prints
    summary statistics, and emits a PASS/WEAK/FAIL verdict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", type=str, required=True,
                        help="Teacher trajectories JSONL")
    args = parser.parse_args()

    # Group rows by seed, then order each episode's rows by step index.
    by_seed: dict[int, list[dict]] = defaultdict(list)
    with open(args.jsonl) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank/trailing lines in the JSONL instead of
                # crashing in json.loads.
                continue
            row = json.loads(line)
            by_seed[row["seed"]].append(row)
    for seed in by_seed:
        by_seed[seed].sort(key=lambda r: r["step"])
    seeds = sorted(by_seed.keys())
    if not seeds:
        # mean()/min()/max() below would raise on empty input; fail fast
        # with a clear message instead.
        parser.error(f"no episodes found in {args.jsonl}")
    print(f"Loaded {len(seeds)} episodes from {args.jsonl}\n")

    teacher_scores = []
    heuristic_scores = []
    random_scores = []
    for seed in seeds:
        rows = by_seed[seed]
        action_seq = [r["parsed_action"] for r in rows]
        belief_seq = [r["parsed_belief"] for r in rows]
        result = replay_with_beliefs(seed, action_seq, belief_seq)
        teacher_scores.append(result["final_score"])
        heuristic_scores.append(play_heuristic(seed))
        random_scores.append(play_random(seed))

    def stats(vs, label):
        # One-line summary of a score list. stdev() raises StatisticsError
        # on fewer than two samples, so report 0.0 for a single episode.
        std = stdev(vs) if len(vs) > 1 else 0.0
        return f"{label:<10} mean={mean(vs):.4f} median={median(vs):.4f} std={std:.4f} min={min(vs):.4f} max={max(vs):.4f}"

    print("=" * 78)
    print("FINAL_SCORE UNDER NEW GRADER (belief_accuracy weighted 0.20)")
    print("=" * 78)
    print(stats(teacher_scores, "teacher:"))
    print(stats(heuristic_scores, "heuristic:"))
    print(stats(random_scores, "random:"))

    teacher_avg = mean(teacher_scores)
    heur_avg = mean(heuristic_scores)
    rand_avg = mean(random_scores)
    margin_h = teacher_avg - heur_avg
    margin_r = teacher_avg - rand_avg
    print()
    print(f"teacher - heuristic: {margin_h:+.4f}")
    print(f"teacher - random: {margin_r:+.4f}")
    print()

    # Per-episode comparison: count strict wins only.
    teacher_beats_heur = sum(1 for t, h in zip(teacher_scores, heuristic_scores) if t > h)
    print(f"Episodes where teacher > heuristic: {teacher_beats_heur} / {len(seeds)}")

    # Verdict
    print()
    print("=" * 78)
    print("VERDICT")
    print("=" * 78)
    GATE = 0.05  # teacher must beat heuristic by at least 5pts on average
    if margin_h >= GATE:
        print(f"PASS — teacher beats heuristic by {margin_h:.3f} (>= {GATE} required)")
        print("      New grader differentiates inference from reflex. Proceed to scale + SFT.")
    elif margin_h > 0:
        print(f"WEAK — teacher beats heuristic by only {margin_h:.3f} (need {GATE}+)")
        print("      Consider raising belief_accuracy weight or improving teacher prompt.")
    else:
        print(f"FAIL — teacher does NOT beat heuristic ({margin_h:.3f})")
        print("      Either grader weights are wrong or teacher is genuinely weak.")
| if __name__ == "__main__": | |
| main() | |