"""
Re-evaluate existing teacher trajectories under the NEW grader (with belief_accuracy term).
For each episode in the JSONL:
1. Replay the env from the recorded seed
2. Step with the recorded action sequence
3. Call env.record_belief(parsed_belief) at each step (using the LAST step's
belief for the grader)
4. Read final_score (now under new grader)
Also runs heuristic + random baselines on the same seeds for comparison, so
we can directly answer: does the teacher beat heuristic+random under the new
grader, and by how much?
Usage:
python scripts/reeval_teacher_trajectories.py --jsonl data/teacher_30ep_validation.jsonl
"""
import argparse
import json
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from statistics import mean, median, stdev
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import ActionType, RhythmAction
from server.rhythm_environment import RhythmEnvironment, MAX_STEPS
from training.dataset import heuristic_action
def replay_with_beliefs(seed: int, action_seq: list[str], belief_seq: list[list[float]]) -> dict:
    """Re-run one episode from its seed, feeding the recorded actions and beliefs.

    Each recorded belief is registered via ``env.record_belief`` *before* the
    matching action is stepped, so the grader sees it. Returns a dict holding
    the final_score plus the five terminal stat values.
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    for name, belief in zip(action_seq, belief_seq):
        if obs.done:
            break
        env.record_belief(belief)
        obs = env.step(RhythmAction(action_type=ActionType(name)))
    breakdown = obs.reward_breakdown
    return {
        "final_score": breakdown.get("final_score", 0.0),
        "vitality": obs.vitality,
        "cognition": obs.cognition,
        "progress": obs.progress,
        "serenity": obs.serenity,
        "connection": obs.connection,
    }
def play_heuristic(seed: int) -> float:
    """Play one full episode with the heuristic policy; return its final_score."""
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    steps_taken = 0
    while steps_taken < MAX_STEPS and not obs.done:
        obs = env.step(RhythmAction(action_type=heuristic_action(obs)))
        steps_taken += 1
    return obs.reward_breakdown.get("final_score", 0.0)
def play_random(seed: int) -> float:
    """Play one full episode with uniformly random actions; return its final_score."""
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    # Derive a deterministic per-episode RNG; the fixed offset keeps it
    # distinct from the env's own seeding while staying reproducible.
    rng = random.Random(seed + 12345)
    choices = list(ActionType)
    steps_taken = 0
    while steps_taken < MAX_STEPS and not obs.done:
        obs = env.step(RhythmAction(action_type=rng.choice(choices)))
        steps_taken += 1
    return obs.reward_breakdown.get("final_score", 0.0)
def main() -> None:
    """Re-score recorded teacher episodes under the new grader and compare them
    against heuristic and random baselines on the same seeds."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", type=str, required=True,
                        help="Teacher trajectories JSONL")
    args = parser.parse_args()

    # Bucket the JSONL rows by episode seed, then order each bucket by step
    # so replay sees actions/beliefs in recorded order.
    by_seed: dict[int, list[dict]] = defaultdict(list)
    with open(args.jsonl) as f:
        for line in f:
            row = json.loads(line)
            by_seed[row["seed"]].append(row)
    for rows in by_seed.values():
        rows.sort(key=lambda r: r["step"])

    seeds = sorted(by_seed)
    print(f"Loaded {len(seeds)} episodes from {args.jsonl}\n")

    teacher_scores: list[float] = []
    heuristic_scores: list[float] = []
    random_scores: list[float] = []
    for seed in seeds:
        episode = by_seed[seed]
        result = replay_with_beliefs(
            seed,
            [r["parsed_action"] for r in episode],
            [r["parsed_belief"] for r in episode],
        )
        teacher_scores.append(result["final_score"])
        heuristic_scores.append(play_heuristic(seed))
        random_scores.append(play_random(seed))

    def stats(vs, label):
        # One-line summary of a score distribution.
        return f"{label:<10} mean={mean(vs):.4f} median={median(vs):.4f} std={stdev(vs):.4f} min={min(vs):.4f} max={max(vs):.4f}"

    print("=" * 78)
    print("FINAL_SCORE UNDER NEW GRADER (belief_accuracy weighted 0.20)")
    print("=" * 78)
    print(stats(teacher_scores, "teacher:"))
    print(stats(heuristic_scores, "heuristic:"))
    print(stats(random_scores, "random:"))

    teacher_avg = mean(teacher_scores)
    heur_avg = mean(heuristic_scores)
    rand_avg = mean(random_scores)
    margin_h = teacher_avg - heur_avg
    margin_r = teacher_avg - rand_avg
    print()
    print(f"teacher - heuristic: {margin_h:+.4f}")
    print(f"teacher - random: {margin_r:+.4f}")
    print()

    # Head-to-head: count episodes the teacher wins outright.
    teacher_beats_heur = sum(1 for t, h in zip(teacher_scores, heuristic_scores) if t > h)
    print(f"Episodes where teacher > heuristic: {teacher_beats_heur} / {len(seeds)}")

    print()
    print("=" * 78)
    print("VERDICT")
    print("=" * 78)
    GATE = 0.05  # teacher must beat heuristic by at least 5pts on average
    if margin_h >= GATE:
        print(f"PASS — teacher beats heuristic by {margin_h:.3f} (>= {GATE} required)")
        print(" New grader differentiates inference from reflex. Proceed to scale + SFT.")
    elif margin_h > 0:
        print(f"WEAK — teacher beats heuristic by only {margin_h:.3f} (need {GATE}+)")
        print(" Consider raising belief_accuracy weight or improving teacher prompt.")
    else:
        print(f"FAIL — teacher does NOT beat heuristic ({margin_h:.3f})")
        print(" Either grader weights are wrong or teacher is genuinely weak.")
# Standard script guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|