"""Evaluate a checkpoint on the Prosody-Pivot Set (Option A: single-step CoT classification).

Uses the same prompt format as `train_grpo.py`: the full briefing goes into a
single prompt, the model emits <think>...</think><final>{...}</final>, and we
parse the predicted label and compare it to gold. No env stepping is required
during eval; the env is only for inference/demo.
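
A well-formed completion looks roughly like this (illustrative sketch; the
exact JSON schema inside <final> is defined in train_grpo.py, so the keys
shown here are assumptions):

    <think>The flat delivery cuts against the literal praise ...</think>
    <final>{"label": "sarcastic", "confidence": 0.8}</final>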

Outputs the headline submission number: "X/50" (per-clip majority across seeds).

Usage:
    # Eval a trained checkpoint
    python train/eval_pivot_set.py \\
        --checkpoint ./checkpoints/run1 \\
        --pivot data/pivot_set.json \\
        --out docs/plots/pivot_results_trained.json

    # Eval the BASE model as the baseline
    python train/eval_pivot_set.py \\
        --checkpoint baseline-only \\
        --pivot data/pivot_set.json \\
        --out docs/plots/pivot_results_baseline.json
"""
from __future__ import annotations

import argparse
import json
import statistics
import sys
from pathlib import Path
from typing import Any, Dict, List

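# Prefer the installed package; fall back to repo-relative imports so this
# script also runs directly from a source checkout.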
try:
    from subtext_arena.server.scenarios import load_scenarios
    from subtext_arena.train.train_grpo import (
        SYSTEM_PROMPT,
        build_full_observation,
        parse_final,
        reward_decomposition,
    )
except ImportError:
    ROOT = Path(__file__).resolve().parent.parent
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    from server.scenarios import load_scenarios  # type: ignore[no-redef]
    from train.train_grpo import (  # type: ignore[no-redef]
        SYSTEM_PROMPT,
        build_full_observation,
        parse_final,
        reward_decomposition,
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True,
                        help="Path to LoRA checkpoint, or 'baseline-only' to eval the base model.")
    parser.add_argument("--base-model", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--pivot", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--n-seeds", type=int, default=3)
    parser.add_argument("--max-tokens", type=int, default=768)
    parser.add_argument("--temperature", type=float, default=0.7)
    args = parser.parse_args()

    pivot = json.loads(Path(args.pivot).read_text())
    pivot_ids = pivot.get("clip_ids", [])
    if not pivot_ids:
        print("[eval] pivot set is empty - run curate_pivot_set.py first", file=sys.stderr)
        sys.exit(1)

    scenarios = load_scenarios()

    # Load model
    print(f"[load] {args.checkpoint}")
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=(args.base_model if args.checkpoint == "baseline-only" else args.checkpoint),
        max_seq_length=4096,
        load_in_4bit=True,
    )
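    # Switch to Unsloth's inference mode (enables its faster generation path).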
    FastLanguageModel.for_inference(model)

    results: List[Dict[str, Any]] = []
    for clip_id in pivot_ids:
        gold = "sarcastic" if scenarios[clip_id]["sarcasm"] else "sincere"
        for seed in range(args.n_seeds):
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_full_observation(clip_id, scenarios)},
            ]
            encoded = tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True,
            )
            input_ids = encoded.input_ids if hasattr(encoded, "input_ids") else encoded
            input_ids = input_ids.to(model.device)
            prompt_len = input_ids.shape[1]
            out = model.generate(
                input_ids=input_ids,
                max_new_tokens=args.max_tokens,
                do_sample=True,
                temperature=args.temperature,
                pad_token_id=tokenizer.eos_token_id,
            )
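            # Decode only the newly generated tokens; slicing off the prompt
            # leaves just the completion for reward parsing.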
            text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

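            # Score the completion with the shared reward function from
            # train_grpo; its underscore-prefixed keys carry the parsed
            # prediction and correctness metadata read below.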
            decomp = reward_decomposition(text, gold)
            results.append({
                "clip_id": clip_id,
                "seed": seed,
                "gold": gold,
                "predicted": decomp["_predicted"],
                "confidence": decomp["_confidence"],
                "correct": decomp["_correct"],
                "well_formed": decomp["_well_formed"],
                "reward_total": decomp["_total"],
                "reward_components": {
                    "correctness": decomp["correctness"],
                    "reasoning_length": decomp["reasoning_length"],
                    "format": decomp["format"],
                },
                "completion_text": text[:1500],  # first 1500 chars for inspection
            })

    # Aggregate rollout-level metrics.
    rewards = [r["reward_total"] for r in results]
    mean_reward = statistics.mean(rewards)
    stdev_reward = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
    n_well_formed = sum(1 for r in results if r["well_formed"])
    n_correct = sum(1 for r in results if r["correct"])
    accuracy = n_correct / len(results) if results else 0.0

    # Per-clip majority vote across seeds: a clip counts as correct only when a
    # strict majority of its rollouts are correct (e.g. >=2 of 3 with the
    # default --n-seeds 3; ties on an even seed count score as incorrect).
    per_clip: Dict[str, List[bool]] = {}
    for r in results:
        per_clip.setdefault(r["clip_id"], []).append(r["correct"])
    clip_accuracy = sum(1 for hits in per_clip.values() if sum(hits) > len(hits) / 2)

    print()
    print(f"[eval] checkpoint:        {args.checkpoint}")
    print(f"[eval] {len(results)} rollouts on {len(pivot_ids)} pivot clips")
    print(f"[eval] mean reward:        {statistics.mean(rewards):.3f} +/- {statistics.stdev(rewards) if len(rewards)>1 else 0.0:.3f}")
    print(f"[eval] well-formed rate:   {n_well_formed}/{len(results)} = {n_well_formed/len(results):.2%}")
    print(f"[eval] rollout accuracy:   {n_correct}/{len(results)} = {accuracy:.2%}")
    print(f"[eval] clip accuracy (maj): {clip_accuracy}/{len(pivot_ids)} = {clip_accuracy/len(pivot_ids):.2%}")
    print()
    print(f"[eval] HEADLINE NUMBER: {clip_accuracy}/{len(pivot_ids)} on the Prosody-Pivot Set")
    print()

    payload = {
        "checkpoint": args.checkpoint,
        "n_pivot_clips": len(pivot_ids),
        "n_rollouts": len(results),
        "mean_reward": statistics.mean(rewards),
        "stdev_reward": statistics.stdev(rewards) if len(rewards) > 1 else 0.0,
        "well_formed_rate": n_well_formed / len(results) if results else 0.0,
        "rollout_accuracy": accuracy,
        "clip_accuracy": clip_accuracy / len(pivot_ids) if pivot_ids else 0.0,
        "headline_n_correct": clip_accuracy,
        "headline_n_total": len(pivot_ids),
        "rollouts": results,
    }
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    Path(args.out).write_text(json.dumps(payload, indent=2))
    print(f"[eval] wrote {args.out}")


if __name__ == "__main__":
    main()