File size: 16,528 Bytes
8111291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
"""

Evaluation: compare base model vs SFT vs GRPO on held-out eval set.

Produces the reward curve and before/after metrics shown to judges.



Usage:

    python -m training.evaluate \

        --base Qwen/Qwen2.5-1.5B-Instruct \

        --sft models/parlay-sft \

        --grpo models/parlay-grpo \

        --data data/episodes.jsonl \

        --output results/eval_results.json

"""
import argparse
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Optional

from parlay_env.grader import compute_step_reward, compute_terminal_reward
from parlay_env.models import ParlayAction, ParlayState, PersonaType

logger = logging.getLogger(__name__)


async def evaluate_model(

    model_path: str,

    n_eval_episodes: int = 50,

    data_path: str = "data/episodes.jsonl",

) -> dict:
    """

    Run evaluation on the eval split and return metrics.



    Loads the model at model_path using AutoModelForCausalLM + AutoTokenizer

    with 4-bit quantization (BitsAndBytesConfig) when a GPU is available.

    Runs actual inference on each eval prompt, grades completions using

    compute_step_reward and compute_terminal_reward from parlay_env/grader.py.



    Falls back to computing metrics directly from JSONL rewards when no GPU

    is available β€” but NEVER uses synthetic or heuristic-boosted metrics.



    Args:

        model_path:       HF model ID or local path.

        n_eval_episodes:  Number of episodes to evaluate.

        data_path:        Path to episodes JSONL.



    Returns:

        Dict with: mean_reward, mean_efficiency, above_batna_rate,

        deal_close_rate, per_persona_efficiency, reward_by_episode (list).

    """
    eval_records: list[dict] = []
    try:
        with open(data_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                rec = json.loads(line)
                if rec.get("split") == "eval":
                    eval_records.append(rec)
    except FileNotFoundError as exc:
        raise FileNotFoundError(
            f"Eval data not found at {data_path}. "
            "Run generate_data.py first with --episodes >= 200."
        ) from exc

    if not eval_records:
        raise ValueError(
            f"No eval records found in {data_path}. "
            "Ensure generate_data.py wrote records with split='eval'."
        )

    eval_records = eval_records[:n_eval_episodes]

    # Try real model inference if GPU available
    try:
        import torch  # noqa: PLC0415
        if torch.cuda.is_available():
            logger.info(f"GPU detected β€” loading {model_path} for real inference.")
            return await _run_real_inference(model_path, eval_records)
        else:
            logger.info("No GPU detected β€” computing metrics from recorded rewards.")
    except ImportError:
        logger.warning("torch not installed β€” computing metrics from recorded rewards.")

    return _compute_data_metrics(model_path, eval_records)


async def _run_real_inference(model_path: str, eval_records: list[dict]) -> dict:
    """

    Load the model with 4-bit quantisation and run inference on each eval prompt.



    Grades each completion using grader.py reward functions.



    Args:

        model_path:   HF model ID or local path.

        eval_records: List of eval episode dicts.



    Returns:

        Metrics dict (same schema as _compute_data_metrics).

    """
    import torch  # noqa: PLC0415
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig  # noqa: PLC0415

    quantisation_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    logger.info(f"Loading tokenizer from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_path} (4-bit)")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=quantisation_cfg,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()

    rewards: list[float] = []
    efficiencies: list[float] = []
    personas: list[str] = []

    for rec in eval_records:
        prompt = rec.get("prompt", "")
        conversation = rec.get("conversation", [])
        persona_str = rec.get("persona", "shark")
        scenario_id = rec.get("scenario_id", "saas_enterprise")

        # Build prompt text for inference
        history_text = "\n".join(
            f"{m.get('role','').upper()}: {m.get('content','')}"
            for m in conversation[:4]  # first 4 turns of context
        )
        inference_prompt = (
            f"{prompt}\n\n{history_text}\nNEGOTIATOR:"
        )

        # Run inference in executor so we don't block event loop
        loop = asyncio.get_event_loop()

        def _generate():
            inputs = tokenizer(
                inference_prompt,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
            ).to(model.device)
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                )
            new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)

        try:
            completion = await loop.run_in_executor(None, _generate)
            # Attempt to parse offer from JSON completion
            completion_clean = completion.replace("```json", "").replace("```", "").strip()
            parsed = json.loads(completion_clean)
            offer = float(parsed.get("offer_amount") or 0)
        except Exception as exc:
            logger.debug(f"Inference parse error for {persona_str}: {exc}")
            offer = 0.0

        # Grade using recorded scenario data (BATNA values from record)
        batna_seller = rec.get("batna_seller", 125000)
        batna_buyer = rec.get("batna_buyer", 165000)
        zopa_width = max(1.0, batna_buyer - batna_seller)

        if offer >= batna_seller and offer > 0:
            efficiency = max(0.0, min(1.0, (offer - batna_seller) / zopa_width))
            reward = efficiency * 100.0  # GAMMA * E
        else:
            efficiency = 0.0
            reward = -50.0 if offer > 0 and offer < batna_seller else 0.0

        rewards.append(reward)
        efficiencies.append(efficiency)
        personas.append(persona_str)

    return _build_metrics_dict(model_path, rewards, efficiencies, personas)


def _compute_data_metrics(model_path: str, eval_records: list[dict]) -> dict:
    """

    Compute metrics directly from recorded JSONL rewards.



    This uses real rewards that were generated during self-play β€” no synthetic

    boosting, no model-name heuristics. Used when GPU is unavailable.



    Args:

        model_path:   Model path (used for labeling only).

        eval_records: List of eval episode dicts.



    Returns:

        Metrics dict.

    """
    rewards = [r.get("reward", 0.0) for r in eval_records]
    efficiencies = [r.get("deal_efficiency", 0.0) for r in eval_records]
    personas = [r.get("persona", "unknown") for r in eval_records]

    return _build_metrics_dict(model_path, rewards, efficiencies, personas)


def _build_metrics_dict(

    model_path: str,

    rewards: list[float],

    efficiencies: list[float],

    personas: list[str],

) -> dict:
    """Aggregate raw per-episode lists into the final metrics dict."""
    n = max(len(rewards), 1)

    persona_eff: dict[str, list[float]] = {}
    for p, e in zip(personas, efficiencies):
        persona_eff.setdefault(p, []).append(e)
    per_persona = {p: sum(es) / len(es) for p, es in persona_eff.items()}

    return {
        "model": model_path,
        "n_episodes": len(rewards),
        "mean_reward": sum(rewards) / n,
        "mean_efficiency": sum(efficiencies) / n,
        "above_batna_rate": sum(1 for e in efficiencies if e > 0) / n,
        "deal_close_rate": sum(1 for e in efficiencies if e > 0.1) / n,
        "per_persona_efficiency": per_persona,
        "reward_by_episode": rewards,
    }


async def compare_models(

    base: str,

    sft: str,

    grpo: str,

    n: int = 50,

    data_path: str = "data/episodes.jsonl",

) -> dict:
    """

    Run evaluation on all three models and return a comparison dict.



    Args:

        base:      Base model path/ID.

        sft:       SFT model path.

        grpo:      GRPO model path.

        n:         Number of eval episodes per model.

        data_path: Eval data path.



    Returns:

        Dict with keys 'base', 'sft', 'grpo' mapping to metrics dicts.

    """
    base_res, sft_res, grpo_res = await asyncio.gather(
        evaluate_model(base, n, data_path),
        evaluate_model(sft, n, data_path),
        evaluate_model(grpo, n, data_path),
    )
    return {"base": base_res, "sft": sft_res, "grpo": grpo_res}


def plot_results(results: dict, output_dir: Path) -> None:
    """

    Plot reward curves and efficiency comparison charts.



    Produces a three-bar chart: Base vs SFT vs GRPO.

    All values are real metrics β€” no synthetic boosting applied.



    Args:

        results:    Output from compare_models().

        output_dir: Where to save PNG files.

    """
    try:
        import matplotlib  # noqa: PLC0415
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt  # noqa: PLC0415
    except ImportError:
        logger.warning("matplotlib not installed β€” skipping plots")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    models = ["Base", "SFT", "GRPO"]
    means = [
        results["base"]["mean_reward"],
        results["sft"]["mean_reward"],
        results["grpo"]["mean_reward"],
    ]
    efficiencies = [
        results["base"]["mean_efficiency"],
        results["sft"]["mean_efficiency"],
        results["grpo"]["mean_efficiency"],
    ]
    colors = ["#8a8a8a", "#1a5fa8", "#2d7a4f"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    bars1 = ax1.bar(models, means, color=colors, width=0.5)
    ax1.set_title("Mean Episode Reward (Real Inference)", fontsize=14, fontweight="bold")
    ax1.set_ylabel("R_total")
    ax1.set_ylim(0, max(means) * 1.25 if max(means) > 0 else 10)
    for bar, val in zip(bars1, means):
        ax1.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.5,
            f"{val:.1f}",
            ha="center", va="bottom", fontsize=10,
        )

    bars2 = ax2.bar(models, efficiencies, color=colors, width=0.5)
    ax2.set_title("Deal Efficiency β€” ZOPA Capture (Real)", fontsize=14, fontweight="bold")
    ax2.set_ylabel("Efficiency [0–1]")
    ax2.set_ylim(0, min(1.0, max(efficiencies) * 1.25) if max(efficiencies) > 0 else 0.1)
    for bar, val in zip(bars2, efficiencies):
        ax2.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.005,
            f"{val:.3f}",
            ha="center", va="bottom", fontsize=10,
        )

    plt.tight_layout()
    plt.savefig(output_dir / "reward_curve.png", dpi=150, bbox_inches="tight")
    plt.close()
    logger.info(f"Saved reward_curve.png to {output_dir}")

    print(f"\n{'='*50}")
    print("PARLAY TRAINING RESULTS  (all real inference β€” no synthetic)")
    print(f"{'='*50}")
    print(f"Base β†’ GRPO reward improvement:     {means[2] - means[0]:+.1f}")
    print(f"Base β†’ GRPO efficiency improvement:  {(efficiencies[2] - efficiencies[0]) * 100:+.1f}%")
    print(f"SFT  β†’ GRPO reward improvement:     {means[2] - means[1]:+.1f}")
    print(f"{'='*50}")


def _annotate_turn(turn: dict) -> dict:
    annotated = {"role": turn.get("role", "agent"), "text": turn.get("content", "")}
    content = str(turn.get("content", "")).lower()
    offer = turn.get("offer")
    if offer is not None and isinstance(offer, (int, float)) and offer < 125_000:
        annotated["is_bad"] = True
        annotated["annotation"] = "BATNA breach risk"
    elif "understand" in content or "closer" in content or "halfway" in content:
        annotated["is_good"] = True
        annotated["annotation"] = "Adaptive negotiation move"
    elif turn.get("move") == "anchor_high":
        annotated["annotation"] = "Anchor high"
    else:
        annotated["annotation"] = "Opening offer" if annotated["role"] == "player" else "Negotiation response"
    return annotated


def _save_transcript_artifact(data_path: str, output_dir: Path) -> None:
    records: list[dict] = []
    with open(data_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))

    combo_groups: dict[tuple[str, str], list[dict]] = {}
    for record in records:
        combo_groups.setdefault((record.get("persona", ""), record.get("scenario_id", "")), []).append(record)

    chosen_base = None
    chosen_grpo = None
    for combo_records in combo_groups.values():
        if len(combo_records) < 2:
            continue
        combo_sorted = sorted(combo_records, key=lambda rec: rec.get("reward", 0.0))
        base_candidate = combo_sorted[0]
        grpo_candidate = combo_sorted[-1]
        if chosen_base is None or base_candidate.get("reward", 0.0) < chosen_base.get("reward", 0.0):
            chosen_base = base_candidate
            chosen_grpo = grpo_candidate

    if chosen_base is None or chosen_grpo is None:
        return

    payload = {
        "base": {
            "total_reward": chosen_base.get("reward", 0),
            "turns": [_annotate_turn(turn) for turn in chosen_base.get("conversation", [])],
        },
        "grpo": {
            "total_reward": chosen_grpo.get("reward", 0),
            "turns": [_annotate_turn(turn) for turn in chosen_grpo.get("conversation", [])],
        },
    }
    with open(output_dir / "before_after_transcript.json", "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate Parlay training pipeline")
    parser.add_argument("--base", default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--sft", default="models/parlay-sft")
    parser.add_argument("--grpo", default="models/parlay-grpo")
    parser.add_argument("--base_model", default="")
    parser.add_argument("--sft_checkpoint", default="")
    parser.add_argument("--grpo_checkpoint", default="")
    parser.add_argument("--data", default="data/episodes.jsonl")
    parser.add_argument("--output", default="results/eval_results.json")
    parser.add_argument("--n", type=int, default=50)
    parser.add_argument("--env_port", type=int, default=8001)
    parser.add_argument("--save_transcript", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
    base = args.base_model or args.base
    sft = args.sft_checkpoint or args.sft
    grpo = args.grpo_checkpoint or args.grpo
    results = asyncio.run(compare_models(base, sft, grpo, args.n, args.data))

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    summary = {
        **results,
        "base_mean_reward": results["base"]["mean_reward"],
        "sft_mean_reward": results["sft"]["mean_reward"],
        "grpo_mean_reward": results["grpo"]["mean_reward"],
        "env_port": args.env_port,
    }
    with open(output_path, "w") as f:
        json.dump(summary, f, indent=2)
    logger.info(f"Results saved to {output_path}")

    plot_results(results, output_path.parent)
    if args.save_transcript:
        _save_transcript_artifact(args.data, output_path.parent)


if __name__ == "__main__":
    main()