"""Evaluate a checkpoint on the Prosody-Pivot Set (Option A: single-step CoT classification).
Same prompt format as `train_grpo.py`: the full briefing goes into a single
prompt, the model emits its reasoning followed by a final JSON answer, and we
parse the predicted label and compare it to gold. No env walking is required
during eval; the env is only needed for inference/demo.
Outputs the headline submission number: "X/50" (per-clip majority across seeds).
Usage:
# Eval a trained checkpoint
python train/eval_pivot_set.py \\
--checkpoint ./checkpoints/run1 \\
--pivot data/pivot_set.json \\
--out docs/plots/pivot_results_trained.json
# Eval the BASE model as the baseline
python train/eval_pivot_set.py \\
--checkpoint baseline-only \\
--pivot data/pivot_set.json \\
--out docs/plots/pivot_results_baseline.json
"""
from __future__ import annotations
import argparse
import json
import statistics
import sys
from pathlib import Path
from typing import Any, Dict, List
try:
from subtext_arena.server.scenarios import load_scenarios
from subtext_arena.train.train_grpo import (
SYSTEM_PROMPT,
build_full_observation,
parse_final,
reward_decomposition,
)
except ImportError:
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from server.scenarios import load_scenarios # type: ignore[no-redef]
from train.train_grpo import ( # type: ignore[no-redef]
SYSTEM_PROMPT,
build_full_observation,
parse_final,
reward_decomposition,
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", required=True,
help="Path to LoRA checkpoint, or 'baseline-only' to eval the base model.")
parser.add_argument("--base-model", default="unsloth/Qwen2.5-3B-Instruct")
parser.add_argument("--pivot", required=True)
parser.add_argument("--out", required=True)
parser.add_argument("--n-seeds", type=int, default=3)
parser.add_argument("--max-tokens", type=int, default=768)
parser.add_argument("--temperature", type=float, default=0.7)
args = parser.parse_args()
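    # The pivot file is expected to be JSON of the form {"clip_ids": [...]},
    # where each id is a key into the scenarios returned by load_scenarios().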
pivot = json.loads(Path(args.pivot).read_text())
pivot_ids = pivot.get("clip_ids", [])
if not pivot_ids:
print("[eval] pivot set is empty - run curate_pivot_set.py first", file=sys.stderr)
sys.exit(1)
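    # scenarios maps clip_id -> scenario metadata, including the gold "sarcasm" flag.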
scenarios = load_scenarios()
# Load model
print(f"[load] {args.checkpoint}")
    from unsloth import FastLanguageModel  # deferred import: heavy, only needed once args are validated
    import torch
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=(args.base_model if args.checkpoint == "baseline-only" else args.checkpoint),
max_seq_length=4096,
load_in_4bit=True,
)
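    # Switch the model to inference mode (no gradient tracking, unsloth's fast generation path).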
FastLanguageModel.for_inference(model)
results: List[Dict[str, Any]] = []
for clip_id in pivot_ids:
gold = "sarcastic" if scenarios[clip_id]["sarcasm"] else "sincere"
        for seed in range(args.n_seeds):
            torch.manual_seed(seed)  # seed the sampler so each rollout is distinct and reproducible
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": build_full_observation(clip_id, scenarios)},
]
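            # Depending on the tokenizers version, apply_chat_template returns either
            # a plain tensor or a BatchEncoding; both cases are handled below.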
encoded = tokenizer.apply_chat_template(
messages, return_tensors="pt", add_generation_prompt=True,
)
input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
input_ids = input_ids.to(model.device)
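            # Record the prompt length so only the generated continuation is decoded.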
prompt_len = input_ids.shape[1]
out = model.generate(
input_ids=input_ids,
max_new_tokens=args.max_tokens,
do_sample=True,
temperature=args.temperature,
pad_token_id=tokenizer.eos_token_id,
)
text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
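            # Score with the same reward decomposition used in training, so eval
            # numbers are directly comparable to the training curves.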
decomp = reward_decomposition(text, gold)
results.append({
"clip_id": clip_id,
"seed": seed,
"gold": gold,
"predicted": decomp["_predicted"],
"confidence": decomp["_confidence"],
"correct": decomp["_correct"],
"well_formed": decomp["_well_formed"],
"reward_total": decomp["_total"],
"reward_components": {
"correctness": decomp["correctness"],
"reasoning_length": decomp["reasoning_length"],
"format": decomp["format"],
},
"completion_text": text[:1500], # first 1500 chars for inspection
})
# Aggregate
    rewards = [r["reward_total"] for r in results]
    mean_reward = statistics.mean(rewards)
    stdev_reward = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
    n_well_formed = sum(1 for r in results if r["well_formed"])
    n_correct = sum(1 for r in results if r["correct"])
    accuracy = n_correct / len(results) if results else 0.0
# Per-clip mode (majority across seeds)
per_clip: Dict[str, List[bool]] = {}
for r in results:
per_clip.setdefault(r["clip_id"], []).append(r["correct"])
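    # A clip counts as solved only when a strict majority of its rollouts are
    # correct (e.g. at least 2 of 3 seeds).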
clip_accuracy = sum(1 for hits in per_clip.values() if sum(hits) > len(hits) / 2)
print()
print(f"[eval] checkpoint: {args.checkpoint}")
print(f"[eval] {len(results)} rollouts on {len(pivot_ids)} pivot clips")
print(f"[eval] mean reward: {statistics.mean(rewards):.3f} +/- {statistics.stdev(rewards) if len(rewards)>1 else 0.0:.3f}")
print(f"[eval] well-formed rate: {n_well_formed}/{len(results)} = {n_well_formed/len(results):.2%}")
print(f"[eval] rollout accuracy: {n_correct}/{len(results)} = {accuracy:.2%}")
print(f"[eval] clip accuracy (maj): {clip_accuracy}/{len(pivot_ids)} = {clip_accuracy/len(pivot_ids):.2%}")
print()
print(f"[eval] HEADLINE NUMBER: {clip_accuracy}/{len(pivot_ids)} on the Prosody-Pivot Set")
print()
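    # Persist the aggregates plus full per-rollout records for later inspection/plotting.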
payload = {
"checkpoint": args.checkpoint,
"n_pivot_clips": len(pivot_ids),
"n_rollouts": len(results),
"mean_reward": statistics.mean(rewards),
"stdev_reward": statistics.stdev(rewards) if len(rewards) > 1 else 0.0,
"well_formed_rate": n_well_formed / len(results) if results else 0.0,
"rollout_accuracy": accuracy,
"clip_accuracy": clip_accuracy / len(pivot_ids) if pivot_ids else 0.0,
"headline_n_correct": clip_accuracy,
"headline_n_total": len(pivot_ids),
"rollouts": results,
}
Path(args.out).parent.mkdir(parents=True, exist_ok=True)
Path(args.out).write_text(json.dumps(payload, indent=2))
print(f"[eval] wrote {args.out}")


if __name__ == "__main__":
main()