"""Evaluate a checkpoint on the Prosody-Pivot Set (Option A: single-step CoT classification). Same prompt format as `train_grpo.py`: full briefing in one prompt, model emits ...{...}, we parse the label and compare to gold. No env walking required during eval — the env is for inference/demo. Outputs the headline submission number: "X/50" (per-clip majority across seeds). Usage: # Eval a trained checkpoint python train/eval_pivot_set.py \\ --checkpoint ./checkpoints/run1 \\ --pivot data/pivot_set.json \\ --out docs/plots/pivot_results_trained.json # Eval the BASE model as the baseline python train/eval_pivot_set.py \\ --checkpoint baseline-only \\ --pivot data/pivot_set.json \\ --out docs/plots/pivot_results_baseline.json """ from __future__ import annotations import argparse import json import statistics import sys from pathlib import Path from typing import Any, Dict, List try: from subtext_arena.server.scenarios import load_scenarios from subtext_arena.train.train_grpo import ( SYSTEM_PROMPT, build_full_observation, parse_final, reward_decomposition, ) except ImportError: ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from server.scenarios import load_scenarios # type: ignore[no-redef] from train.train_grpo import ( # type: ignore[no-redef] SYSTEM_PROMPT, build_full_observation, parse_final, reward_decomposition, ) def main(): parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True, help="Path to LoRA checkpoint, or 'baseline-only' to eval the base model.") parser.add_argument("--base-model", default="unsloth/Qwen2.5-3B-Instruct") parser.add_argument("--pivot", required=True) parser.add_argument("--out", required=True) parser.add_argument("--n-seeds", type=int, default=3) parser.add_argument("--max-tokens", type=int, default=768) parser.add_argument("--temperature", type=float, default=0.7) args = parser.parse_args() pivot = json.loads(Path(args.pivot).read_text()) pivot_ids = pivot.get("clip_ids", []) if not pivot_ids: print("[eval] pivot set is empty - run curate_pivot_set.py first", file=sys.stderr) sys.exit(1) scenarios = load_scenarios() # Load model print(f"[load] {args.checkpoint}") from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=(args.base_model if args.checkpoint == "baseline-only" else args.checkpoint), max_seq_length=4096, load_in_4bit=True, ) FastLanguageModel.for_inference(model) results: List[Dict[str, Any]] = [] for clip_id in pivot_ids: gold = "sarcastic" if scenarios[clip_id]["sarcasm"] else "sincere" for seed in range(args.n_seeds): messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": build_full_observation(clip_id, scenarios)}, ] encoded = tokenizer.apply_chat_template( messages, return_tensors="pt", add_generation_prompt=True, ) input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded input_ids = input_ids.to(model.device) prompt_len = input_ids.shape[1] out = model.generate( input_ids=input_ids, max_new_tokens=args.max_tokens, do_sample=True, temperature=args.temperature, pad_token_id=tokenizer.eos_token_id, ) text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True) decomp = reward_decomposition(text, gold) results.append({ "clip_id": clip_id, "seed": seed, "gold": gold, "predicted": decomp["_predicted"], "confidence": decomp["_confidence"], "correct": decomp["_correct"], "well_formed": decomp["_well_formed"], "reward_total": decomp["_total"], "reward_components": { "correctness": decomp["correctness"], "reasoning_length": decomp["reasoning_length"], "format": decomp["format"], }, "completion_text": text[:1500], # first 1500 chars for inspection }) # Aggregate rewards = [r["reward_total"] for r in results] n_well_formed = sum(1 for r in results if r["well_formed"]) n_correct = sum(1 for r in results if r["correct"]) accuracy = n_correct / len(results) if results else 0.0 # Per-clip mode (majority across seeds) per_clip: Dict[str, List[bool]] = {} for r in results: per_clip.setdefault(r["clip_id"], []).append(r["correct"]) clip_accuracy = sum(1 for hits in per_clip.values() if sum(hits) > len(hits) / 2) print() print(f"[eval] checkpoint: {args.checkpoint}") print(f"[eval] {len(results)} rollouts on {len(pivot_ids)} pivot clips") print(f"[eval] mean reward: {statistics.mean(rewards):.3f} +/- {statistics.stdev(rewards) if len(rewards)>1 else 0.0:.3f}") print(f"[eval] well-formed rate: {n_well_formed}/{len(results)} = {n_well_formed/len(results):.2%}") print(f"[eval] rollout accuracy: {n_correct}/{len(results)} = {accuracy:.2%}") print(f"[eval] clip accuracy (maj): {clip_accuracy}/{len(pivot_ids)} = {clip_accuracy/len(pivot_ids):.2%}") print() print(f"[eval] HEADLINE NUMBER: {clip_accuracy}/{len(pivot_ids)} on the Prosody-Pivot Set") print() payload = { "checkpoint": args.checkpoint, "n_pivot_clips": len(pivot_ids), "n_rollouts": len(results), "mean_reward": statistics.mean(rewards), "stdev_reward": statistics.stdev(rewards) if len(rewards) > 1 else 0.0, "well_formed_rate": n_well_formed / len(results) if results else 0.0, "rollout_accuracy": accuracy, "clip_accuracy": clip_accuracy / len(pivot_ids) if pivot_ids else 0.0, "headline_n_correct": clip_accuracy, "headline_n_total": len(pivot_ids), "rollouts": results, } Path(args.out).parent.mkdir(parents=True, exist_ok=True) Path(args.out).write_text(json.dumps(payload, indent=2)) print(f"[eval] wrote {args.out}") if __name__ == "__main__": main()