"""Evaluate a checkpoint on the Prosody-Pivot Set (Option A: single-step CoT classification).

Same prompt format as `train_grpo.py`: the full briefing goes in one prompt, the
model emits <think>...</think><final>{...}</final>, and we parse the label and
compare it to gold. No env walking is required during eval — the env is for
inference/demo. Outputs the headline submission number: "X/50" (per-clip
majority across seeds).

Usage:
    # Evaluate a trained checkpoint
    python train/eval_pivot_set.py \\
        --checkpoint ./checkpoints/run1 \\
        --pivot data/pivot_set.json \\
        --out docs/plots/pivot_results_trained.json

    # Evaluate the BASE model as the baseline
    python train/eval_pivot_set.py \\
        --checkpoint baseline-only \\
        --pivot data/pivot_set.json \\
        --out docs/plots/pivot_results_baseline.json
"""
from __future__ import annotations

import argparse
import json
import statistics
import sys
from pathlib import Path
from typing import Any, Dict, List

# Prefer the installed `subtext_arena` package; fall back to inserting the repo
# root on sys.path so the script also runs directly from a source checkout
# (where `server/` and `train/` sit two levels above this file).
try:
    from subtext_arena.server.scenarios import load_scenarios
    from subtext_arena.train.train_grpo import (
        SYSTEM_PROMPT,
        build_full_observation,
        parse_final,
        reward_decomposition,
    )
except ImportError:
    ROOT = Path(__file__).resolve().parent.parent
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    from server.scenarios import load_scenarios  # type: ignore[no-redef]
    from train.train_grpo import (  # type: ignore[no-redef]
        SYSTEM_PROMPT,
        build_full_observation,
        parse_final,
        reward_decomposition,
    )
def main():
    """CLI entry point: evaluate a checkpoint on the Prosody-Pivot Set.

    Samples ``--n-seeds`` rollouts per pivot clip (each seeded for
    reproducibility), scores every completion with ``reward_decomposition``,
    then reports rollout-level accuracy plus the headline per-clip
    majority-vote accuracy, and writes the full payload to ``--out``.

    Exits with status 1 if the pivot file contains no clip ids.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True,
                        help="Path to LoRA checkpoint, or 'baseline-only' to eval the base model.")
    parser.add_argument("--base-model", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--pivot", required=True,
                        help="JSON file with a 'clip_ids' list (produced by curate_pivot_set.py).")
    parser.add_argument("--out", required=True,
                        help="Path for the results JSON (parent dirs created as needed).")
    parser.add_argument("--n-seeds", type=int, default=3,
                        help="Sampled rollouts per clip; keep odd so the majority vote cannot tie.")
    parser.add_argument("--max-tokens", type=int, default=768)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    pivot = json.loads(Path(args.pivot).read_text())
    pivot_ids = pivot.get("clip_ids", [])
    if not pivot_ids:
        print("[eval] pivot set is empty - run curate_pivot_set.py first", file=sys.stderr)
        sys.exit(1)
    scenarios = load_scenarios()

    # Load the model. Imported lazily so `--help` and the empty-pivot error
    # path do not require a CUDA/unsloth environment.
    print(f"[load] {args.checkpoint}")
    import torch
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=(args.base_model if args.checkpoint == "baseline-only" else args.checkpoint),
        max_seq_length=4096,
        load_in_4bit=True,
    )
    FastLanguageModel.for_inference(model)

    results: List[Dict[str, Any]] = []
    for clip_id in pivot_ids:
        gold = "sarcastic" if scenarios[clip_id]["sarcasm"] else "sincere"
        for seed in range(args.n_seeds):
            # FIX: previously the recorded "seed" was never applied, so the
            # sampled rollouts were not reproducible and the per-seed field in
            # the output JSON was meaningless. Seed the RNG before each
            # generation so seed k always yields the same completion.
            torch.manual_seed(seed)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_full_observation(clip_id, scenarios)},
            ]
            encoded = tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True,
            )
            # Some tokenizers return a BatchEncoding, others a bare tensor.
            input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
            input_ids = input_ids.to(model.device)
            prompt_len = input_ids.shape[1]
            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_tokens,
                do_sample=True,
                temperature=args.temperature,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Decode only the completion (strip the echoed prompt tokens).
            text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
            decomp = reward_decomposition(text, gold)
            results.append({
                "clip_id": clip_id,
                "seed": seed,
                "gold": gold,
                "predicted": decomp["_predicted"],
                "confidence": decomp["_confidence"],
                "correct": decomp["_correct"],
                "well_formed": decomp["_well_formed"],
                "reward_total": decomp["_total"],
                "reward_components": {
                    "correctness": decomp["correctness"],
                    "reasoning_length": decomp["reasoning_length"],
                    "format": decomp["format"],
                },
                "completion_text": text[:1500],  # first 1500 chars for inspection
            })

    # Aggregate. `results` is non-empty here (pivot_ids checked above), so the
    # divisions below are safe.
    rewards = [r["reward_total"] for r in results]
    n_well_formed = sum(1 for r in results if r["well_formed"])
    n_correct = sum(1 for r in results if r["correct"])
    accuracy = n_correct / len(results) if results else 0.0

    # Per-clip strict majority across seeds: with an even --n-seeds a tie
    # counts as incorrect (conservative), hence the odd default of 3.
    per_clip: Dict[str, List[bool]] = {}
    for r in results:
        per_clip.setdefault(r["clip_id"], []).append(r["correct"])
    clip_accuracy = sum(1 for hits in per_clip.values() if sum(hits) > len(hits) / 2)

    print()
    print(f"[eval] checkpoint: {args.checkpoint}")
    print(f"[eval] {len(results)} rollouts on {len(pivot_ids)} pivot clips")
    print(f"[eval] mean reward: {statistics.mean(rewards):.3f} +/- {statistics.stdev(rewards) if len(rewards)>1 else 0.0:.3f}")
    print(f"[eval] well-formed rate: {n_well_formed}/{len(results)} = {n_well_formed/len(results):.2%}")
    print(f"[eval] rollout accuracy: {n_correct}/{len(results)} = {accuracy:.2%}")
    print(f"[eval] clip accuracy (maj): {clip_accuracy}/{len(pivot_ids)} = {clip_accuracy/len(pivot_ids):.2%}")
    print()
    print(f"[eval] HEADLINE NUMBER: {clip_accuracy}/{len(pivot_ids)} on the Prosody-Pivot Set")
    print()

    payload = {
        "checkpoint": args.checkpoint,
        "n_pivot_clips": len(pivot_ids),
        "n_rollouts": len(results),
        "mean_reward": statistics.mean(rewards),
        "stdev_reward": statistics.stdev(rewards) if len(rewards) > 1 else 0.0,
        "well_formed_rate": n_well_formed / len(results) if results else 0.0,
        "rollout_accuracy": accuracy,
        "clip_accuracy": clip_accuracy / len(pivot_ids) if pivot_ids else 0.0,
        "headline_n_correct": clip_accuracy,
        "headline_n_total": len(pivot_ids),
        "rollouts": results,
    }
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    Path(args.out).write_text(json.dumps(payload, indent=2))
    print(f"[eval] wrote {args.out}")


if __name__ == "__main__":
    main()