"""Evaluate a checkpoint on the Prosody-Pivot Set (Option A: single-step CoT classification).

Same prompt format as `train_grpo.py`: full briefing in one prompt, model
emits <think>...</think><final>{...}</final>, we parse the label and compare
to gold. No env walking is required during eval; the env is for inference/demo.

Outputs the headline submission number: "X/50" (per-clip majority across seeds).

Usage:
    # Eval a trained checkpoint
    python train/eval_pivot_set.py \\
        --checkpoint ./checkpoints/run1 \\
        --pivot data/pivot_set.json \\
        --out docs/plots/pivot_results_trained.json

    # Eval the BASE model as the baseline
    python train/eval_pivot_set.py \\
        --checkpoint baseline-only \\
        --pivot data/pivot_set.json \\
        --out docs/plots/pivot_results_baseline.json
"""
from __future__ import annotations
import argparse
import json
import statistics
import sys
from pathlib import Path
from typing import Any, Dict, List

try:
    from subtext_arena.server.scenarios import load_scenarios
    from subtext_arena.train.train_grpo import (
        SYSTEM_PROMPT,
        build_full_observation,
        parse_final,
        reward_decomposition,
    )
except ImportError:
    # Running as a loose script: put the repo root on sys.path and retry.
    ROOT = Path(__file__).resolve().parent.parent
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    from server.scenarios import load_scenarios  # type: ignore[no-redef]
    from train.train_grpo import (  # type: ignore[no-redef]
        SYSTEM_PROMPT,
        build_full_observation,
        parse_final,
        reward_decomposition,
    )
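
# These helpers keep eval aligned with training: SYSTEM_PROMPT and
# build_full_observation rebuild the exact prompt used in train_grpo.py,
# and reward_decomposition scores one completion against the gold label.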


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True,
                        help="Path to LoRA checkpoint, or 'baseline-only' to eval the base model.")
    parser.add_argument("--base-model", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--pivot", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--n-seeds", type=int, default=3)
    parser.add_argument("--max-tokens", type=int, default=768)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    pivot = json.loads(Path(args.pivot).read_text())
    pivot_ids = pivot.get("clip_ids", [])
    if not pivot_ids:
        print("[eval] pivot set is empty - run curate_pivot_set.py first", file=sys.stderr)
        sys.exit(1)

    scenarios = load_scenarios()

    # Load model (base weights only when --checkpoint is 'baseline-only').
    print(f"[load] {args.checkpoint}")
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=(args.base_model if args.checkpoint == "baseline-only" else args.checkpoint),
        max_seq_length=4096,
        load_in_4bit=True,
    )
    FastLanguageModel.for_inference(model)
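
    # Protocol: sample args.n_seeds rollouts per clip at args.temperature;
    # the headline number is the per-clip majority vote computed after the loop.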
    results: List[Dict[str, Any]] = []
    for clip_id in pivot_ids:
        gold = "sarcastic" if scenarios[clip_id]["sarcasm"] else "sincere"
        for seed in range(args.n_seeds):
            # 'seed' is only a rollout index; variation comes from sampling.
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_full_observation(clip_id, scenarios)},
            ]
            encoded = tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True,
            )
            # Some tokenizers return a BatchEncoding, others a bare tensor.
            input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
            input_ids = input_ids.to(model.device)
            prompt_len = input_ids.shape[1]

            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_tokens,
                do_sample=True,
                temperature=args.temperature,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens, not the prompt.
            text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

            decomp = reward_decomposition(text, gold)
            results.append({
                "clip_id": clip_id,
                "seed": seed,
                "gold": gold,
                "predicted": decomp["_predicted"],
                "confidence": decomp["_confidence"],
                "correct": decomp["_correct"],
                "well_formed": decomp["_well_formed"],
                "reward_total": decomp["_total"],
                "reward_components": {
                    "correctness": decomp["correctness"],
                    "reasoning_length": decomp["reasoning_length"],
                    "format": decomp["format"],
                },
                "completion_text": text[:1500],  # first 1500 chars for inspection
            })

    # Aggregate
    rewards = [r["reward_total"] for r in results]
    n_well_formed = sum(1 for r in results if r["well_formed"])
    n_correct = sum(1 for r in results if r["correct"])
    accuracy = n_correct / len(results) if results else 0.0

    # Per-clip majority vote across seeds.
    per_clip: Dict[str, List[bool]] = {}
    for r in results:
        per_clip.setdefault(r["clip_id"], []).append(r["correct"])
    clip_accuracy = sum(1 for hits in per_clip.values() if sum(hits) > len(hits) / 2)
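    # Example: with --n-seeds 3, a clip counts as correct only when at least
    # 2 of its 3 rollouts match gold; with an even seed count an exact tie
    # fails the strict '>' test and the clip counts as wrong.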

    mean_r = statistics.mean(rewards)
    stdev_r = statistics.stdev(rewards) if len(rewards) > 1 else 0.0

    print()
    print(f"[eval] checkpoint: {args.checkpoint}")
    print(f"[eval] {len(results)} rollouts on {len(pivot_ids)} pivot clips")
    print(f"[eval] mean reward: {mean_r:.3f} +/- {stdev_r:.3f}")
    print(f"[eval] well-formed rate: {n_well_formed}/{len(results)} = {n_well_formed/len(results):.2%}")
    print(f"[eval] rollout accuracy: {n_correct}/{len(results)} = {accuracy:.2%}")
    print(f"[eval] clip accuracy (maj): {clip_accuracy}/{len(pivot_ids)} = {clip_accuracy/len(pivot_ids):.2%}")
    print()
    print(f"[eval] HEADLINE NUMBER: {clip_accuracy}/{len(pivot_ids)} on the Prosody-Pivot Set")
    print()
    payload = {
        "checkpoint": args.checkpoint,
        "n_pivot_clips": len(pivot_ids),
        "n_rollouts": len(results),
        "mean_reward": mean_r,
        "stdev_reward": stdev_r,
        "well_formed_rate": n_well_formed / len(results) if results else 0.0,
        "rollout_accuracy": accuracy,
        "clip_accuracy": clip_accuracy / len(pivot_ids) if pivot_ids else 0.0,
        "headline_n_correct": clip_accuracy,
        "headline_n_total": len(pivot_ids),
        "rollouts": results,
    }
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    Path(args.out).write_text(json.dumps(payload, indent=2))
    print(f"[eval] wrote {args.out}")


if __name__ == "__main__":
    main()
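
# Quick comparison of two result files (optional, assumes jq is installed):
#   jq '.headline_n_correct' docs/plots/pivot_results_baseline.json
#   jq '.headline_n_correct' docs/plots/pivot_results_trained.json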