# rhythm_env/scripts/reeval_teacher_trajectories.py
# Commit ece0bbe: "Algorithm Distillation: grader v2 with belief_accuracy + SFT pipeline"
"""
Re-evaluate existing teacher trajectories under the NEW grader (with belief_accuracy term).
For each episode in the JSONL:
1. Replay the env from the recorded seed
2. Step with the recorded action sequence
3. Call env.record_belief(parsed_belief) at each step (using the LAST step's
belief for the grader)
4. Read final_score (now under new grader)
Also runs heuristic + random baselines on the same seeds for comparison, so
we can directly answer: does the teacher beat heuristic+random under the new
grader, and by how much?
Usage:
python scripts/reeval_teacher_trajectories.py --jsonl data/teacher_30ep_validation.jsonl
"""
import argparse
import json
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
from statistics import mean, median, stdev
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import ActionType, RhythmAction
from server.rhythm_environment import RhythmEnvironment, MAX_STEPS
from training.dataset import heuristic_action
def replay_with_beliefs(seed: int, action_seq: list[str], belief_seq: list[list[float]]) -> dict:
    """Replay one recorded episode on a fresh env, feeding the recorded belief
    before each recorded action, and return the final_score plus ending stats.

    Stops early if the env reports done before the sequences are exhausted.
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    for name, belief in zip(action_seq, belief_seq):
        if obs.done:
            break
        # Belief must be recorded before stepping so the grader sees it.
        env.record_belief(belief)
        obs = env.step(RhythmAction(action_type=ActionType(name)))
    breakdown = obs.reward_breakdown
    return {
        "final_score": breakdown.get("final_score", 0.0),
        "vitality": obs.vitality,
        "cognition": obs.cognition,
        "progress": obs.progress,
        "serenity": obs.serenity,
        "connection": obs.connection,
    }
def play_heuristic(seed: int) -> float:
    """Run the heuristic baseline policy on *seed*; return its final_score."""
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    steps_left = MAX_STEPS
    while steps_left > 0 and not obs.done:
        obs = env.step(RhythmAction(action_type=heuristic_action(obs)))
        steps_left -= 1
    return obs.reward_breakdown.get("final_score", 0.0)
def play_random(seed: int) -> float:
    """Run a uniformly-random policy on *seed*; return its final_score.

    The RNG is seeded deterministically (seed + 12345) so the baseline is
    reproducible per episode.
    """
    env = RhythmEnvironment()
    obs = env.reset(seed=seed)
    rng = random.Random(seed + 12345)
    pool = list(ActionType)
    steps_left = MAX_STEPS
    while steps_left > 0 and not obs.done:
        obs = env.step(RhythmAction(action_type=rng.choice(pool)))
        steps_left -= 1
    return obs.reward_breakdown.get("final_score", 0.0)
def main() -> None:
    """Re-score recorded teacher episodes under the current grader and compare
    against heuristic and random baselines replayed on the same seeds.

    Reads a JSONL of per-step rows (keys: seed, step, parsed_action,
    parsed_belief), replays each episode, prints summary statistics and a
    PASS/WEAK/FAIL verdict based on the teacher-vs-heuristic margin.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl", type=str, required=True,
                        help="Teacher trajectories JSONL")
    args = parser.parse_args()
    # Group rows by seed: each row is one recorded step of an episode.
    by_seed: dict[int, list[dict]] = defaultdict(list)
    with open(args.jsonl) as f:
        for line in f:
            row = json.loads(line)
            by_seed[row["seed"]].append(row)
    # Replay order matters — sort each episode's rows by step index.
    for seed in by_seed:
        by_seed[seed].sort(key=lambda r: r["step"])
    seeds = sorted(by_seed.keys())
    print(f"Loaded {len(seeds)} episodes from {args.jsonl}\n")
    teacher_scores = []
    heuristic_scores = []
    random_scores = []
    for seed in seeds:
        rows = by_seed[seed]
        action_seq = [r["parsed_action"] for r in rows]
        belief_seq = [r["parsed_belief"] for r in rows]
        result = replay_with_beliefs(seed, action_seq, belief_seq)
        teacher_scores.append(result["final_score"])
        heuristic_scores.append(play_heuristic(seed))
        random_scores.append(play_random(seed))

    def stats(vs, label):
        # stdev() raises StatisticsError on fewer than two data points;
        # report 0.0 so a single-episode JSONL still summarizes cleanly.
        std = stdev(vs) if len(vs) > 1 else 0.0
        return f"{label:<10} mean={mean(vs):.4f} median={median(vs):.4f} std={std:.4f} min={min(vs):.4f} max={max(vs):.4f}"

    print("=" * 78)
    print("FINAL_SCORE UNDER NEW GRADER (belief_accuracy weighted 0.20)")
    print("=" * 78)
    print(stats(teacher_scores, "teacher:"))
    print(stats(heuristic_scores, "heuristic:"))
    print(stats(random_scores, "random:"))
    teacher_avg = mean(teacher_scores)
    heur_avg = mean(heuristic_scores)
    rand_avg = mean(random_scores)
    margin_h = teacher_avg - heur_avg
    margin_r = teacher_avg - rand_avg
    print()
    print(f"teacher - heuristic: {margin_h:+.4f}")
    print(f"teacher - random: {margin_r:+.4f}")
    print()
    # Per-episode comparison: how often the teacher strictly wins.
    teacher_beats_heur = sum(1 for t, h in zip(teacher_scores, heuristic_scores) if t > h)
    print(f"Episodes where teacher > heuristic: {teacher_beats_heur} / {len(seeds)}")
    # Verdict
    print()
    print("=" * 78)
    print("VERDICT")
    print("=" * 78)
    GATE = 0.05  # teacher must beat heuristic by at least 5pts on average
    if margin_h >= GATE:
        print(f"PASS — teacher beats heuristic by {margin_h:.3f} (>= {GATE} required)")
        print(" New grader differentiates inference from reflex. Proceed to scale + SFT.")
    elif margin_h > 0:
        print(f"WEAK — teacher beats heuristic by only {margin_h:.3f} (need {GATE}+)")
        print(" Consider raising belief_accuracy weight or improving teacher prompt.")
    else:
        print(f"FAIL — teacher does NOT beat heuristic ({margin_h:.3f})")
        print(" Either grader weights are wrong or teacher is genuinely weak.")
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()