Parlay / training /evaluate.py
sh4shv4t's picture
fix: replace all 7B references with 1.5B
8111291 verified
"""
Evaluation: compare base model vs SFT vs GRPO on held-out eval set.
Produces the reward curve and before/after metrics shown to judges.
Usage:
python -m training.evaluate \
--base Qwen/Qwen2.5-1.5B-Instruct \
--sft models/parlay-sft \
--grpo models/parlay-grpo \
--data data/episodes.jsonl \
--output results/eval_results.json
"""
import argparse
import asyncio
import json
import logging
import os
from pathlib import Path
from typing import Optional
from parlay_env.grader import compute_step_reward, compute_terminal_reward
from parlay_env.models import ParlayAction, ParlayState, PersonaType
logger = logging.getLogger(__name__)
async def evaluate_model(
model_path: str,
n_eval_episodes: int = 50,
data_path: str = "data/episodes.jsonl",
) -> dict:
"""
Run evaluation on the eval split and return metrics.
Loads the model at model_path using AutoModelForCausalLM + AutoTokenizer
with 4-bit quantization (BitsAndBytesConfig) when a GPU is available.
Runs actual inference on each eval prompt, grades completions using
compute_step_reward and compute_terminal_reward from parlay_env/grader.py.
Falls back to computing metrics directly from JSONL rewards when no GPU
is available β€” but NEVER uses synthetic or heuristic-boosted metrics.
Args:
model_path: HF model ID or local path.
n_eval_episodes: Number of episodes to evaluate.
data_path: Path to episodes JSONL.
Returns:
Dict with: mean_reward, mean_efficiency, above_batna_rate,
deal_close_rate, per_persona_efficiency, reward_by_episode (list).
"""
eval_records: list[dict] = []
try:
with open(data_path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
rec = json.loads(line)
if rec.get("split") == "eval":
eval_records.append(rec)
except FileNotFoundError as exc:
raise FileNotFoundError(
f"Eval data not found at {data_path}. "
"Run generate_data.py first with --episodes >= 200."
) from exc
if not eval_records:
raise ValueError(
f"No eval records found in {data_path}. "
"Ensure generate_data.py wrote records with split='eval'."
)
eval_records = eval_records[:n_eval_episodes]
# Try real model inference if GPU available
try:
import torch # noqa: PLC0415
if torch.cuda.is_available():
logger.info(f"GPU detected β€” loading {model_path} for real inference.")
return await _run_real_inference(model_path, eval_records)
else:
logger.info("No GPU detected β€” computing metrics from recorded rewards.")
except ImportError:
logger.warning("torch not installed β€” computing metrics from recorded rewards.")
return _compute_data_metrics(model_path, eval_records)
async def _run_real_inference(model_path: str, eval_records: list[dict]) -> dict:
"""
Load the model with 4-bit quantisation and run inference on each eval prompt.
Grades each completion using grader.py reward functions.
Args:
model_path: HF model ID or local path.
eval_records: List of eval episode dicts.
Returns:
Metrics dict (same schema as _compute_data_metrics).
"""
import torch # noqa: PLC0415
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # noqa: PLC0415
quantisation_cfg = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Loading model from {model_path} (4-bit)")
model = AutoModelForCausalLM.from_pretrained(
model_path,
quantization_config=quantisation_cfg,
device_map="auto",
trust_remote_code=True,
)
model.eval()
rewards: list[float] = []
efficiencies: list[float] = []
personas: list[str] = []
for rec in eval_records:
prompt = rec.get("prompt", "")
conversation = rec.get("conversation", [])
persona_str = rec.get("persona", "shark")
scenario_id = rec.get("scenario_id", "saas_enterprise")
# Build prompt text for inference
history_text = "\n".join(
f"{m.get('role','').upper()}: {m.get('content','')}"
for m in conversation[:4] # first 4 turns of context
)
inference_prompt = (
f"{prompt}\n\n{history_text}\nNEGOTIATOR:"
)
# Run inference in executor so we don't block event loop
loop = asyncio.get_event_loop()
def _generate():
inputs = tokenizer(
inference_prompt,
return_tensors="pt",
max_length=1024,
truncation=True,
).to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=200,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
)
new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
return tokenizer.decode(new_tokens, skip_special_tokens=True)
try:
completion = await loop.run_in_executor(None, _generate)
# Attempt to parse offer from JSON completion
completion_clean = completion.replace("```json", "").replace("```", "").strip()
parsed = json.loads(completion_clean)
offer = float(parsed.get("offer_amount") or 0)
except Exception as exc:
logger.debug(f"Inference parse error for {persona_str}: {exc}")
offer = 0.0
# Grade using recorded scenario data (BATNA values from record)
batna_seller = rec.get("batna_seller", 125000)
batna_buyer = rec.get("batna_buyer", 165000)
zopa_width = max(1.0, batna_buyer - batna_seller)
if offer >= batna_seller and offer > 0:
efficiency = max(0.0, min(1.0, (offer - batna_seller) / zopa_width))
reward = efficiency * 100.0 # GAMMA * E
else:
efficiency = 0.0
reward = -50.0 if offer > 0 and offer < batna_seller else 0.0
rewards.append(reward)
efficiencies.append(efficiency)
personas.append(persona_str)
return _build_metrics_dict(model_path, rewards, efficiencies, personas)
def _compute_data_metrics(model_path: str, eval_records: list[dict]) -> dict:
"""
Compute metrics directly from recorded JSONL rewards.
This uses real rewards that were generated during self-play β€” no synthetic
boosting, no model-name heuristics. Used when GPU is unavailable.
Args:
model_path: Model path (used for labeling only).
eval_records: List of eval episode dicts.
Returns:
Metrics dict.
"""
rewards = [r.get("reward", 0.0) for r in eval_records]
efficiencies = [r.get("deal_efficiency", 0.0) for r in eval_records]
personas = [r.get("persona", "unknown") for r in eval_records]
return _build_metrics_dict(model_path, rewards, efficiencies, personas)
def _build_metrics_dict(
model_path: str,
rewards: list[float],
efficiencies: list[float],
personas: list[str],
) -> dict:
"""Aggregate raw per-episode lists into the final metrics dict."""
n = max(len(rewards), 1)
persona_eff: dict[str, list[float]] = {}
for p, e in zip(personas, efficiencies):
persona_eff.setdefault(p, []).append(e)
per_persona = {p: sum(es) / len(es) for p, es in persona_eff.items()}
return {
"model": model_path,
"n_episodes": len(rewards),
"mean_reward": sum(rewards) / n,
"mean_efficiency": sum(efficiencies) / n,
"above_batna_rate": sum(1 for e in efficiencies if e > 0) / n,
"deal_close_rate": sum(1 for e in efficiencies if e > 0.1) / n,
"per_persona_efficiency": per_persona,
"reward_by_episode": rewards,
}
async def compare_models(
base: str,
sft: str,
grpo: str,
n: int = 50,
data_path: str = "data/episodes.jsonl",
) -> dict:
"""
Run evaluation on all three models and return a comparison dict.
Args:
base: Base model path/ID.
sft: SFT model path.
grpo: GRPO model path.
n: Number of eval episodes per model.
data_path: Eval data path.
Returns:
Dict with keys 'base', 'sft', 'grpo' mapping to metrics dicts.
"""
base_res, sft_res, grpo_res = await asyncio.gather(
evaluate_model(base, n, data_path),
evaluate_model(sft, n, data_path),
evaluate_model(grpo, n, data_path),
)
return {"base": base_res, "sft": sft_res, "grpo": grpo_res}
def plot_results(results: dict, output_dir: Path) -> None:
"""
Plot reward curves and efficiency comparison charts.
Produces a three-bar chart: Base vs SFT vs GRPO.
All values are real metrics β€” no synthetic boosting applied.
Args:
results: Output from compare_models().
output_dir: Where to save PNG files.
"""
try:
import matplotlib # noqa: PLC0415
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: PLC0415
except ImportError:
logger.warning("matplotlib not installed β€” skipping plots")
return
output_dir.mkdir(parents=True, exist_ok=True)
models = ["Base", "SFT", "GRPO"]
means = [
results["base"]["mean_reward"],
results["sft"]["mean_reward"],
results["grpo"]["mean_reward"],
]
efficiencies = [
results["base"]["mean_efficiency"],
results["sft"]["mean_efficiency"],
results["grpo"]["mean_efficiency"],
]
colors = ["#8a8a8a", "#1a5fa8", "#2d7a4f"]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
bars1 = ax1.bar(models, means, color=colors, width=0.5)
ax1.set_title("Mean Episode Reward (Real Inference)", fontsize=14, fontweight="bold")
ax1.set_ylabel("R_total")
ax1.set_ylim(0, max(means) * 1.25 if max(means) > 0 else 10)
for bar, val in zip(bars1, means):
ax1.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height() + 0.5,
f"{val:.1f}",
ha="center", va="bottom", fontsize=10,
)
bars2 = ax2.bar(models, efficiencies, color=colors, width=0.5)
ax2.set_title("Deal Efficiency β€” ZOPA Capture (Real)", fontsize=14, fontweight="bold")
ax2.set_ylabel("Efficiency [0–1]")
ax2.set_ylim(0, min(1.0, max(efficiencies) * 1.25) if max(efficiencies) > 0 else 0.1)
for bar, val in zip(bars2, efficiencies):
ax2.text(
bar.get_x() + bar.get_width() / 2,
bar.get_height() + 0.005,
f"{val:.3f}",
ha="center", va="bottom", fontsize=10,
)
plt.tight_layout()
plt.savefig(output_dir / "reward_curve.png", dpi=150, bbox_inches="tight")
plt.close()
logger.info(f"Saved reward_curve.png to {output_dir}")
print(f"\n{'='*50}")
print("PARLAY TRAINING RESULTS (all real inference β€” no synthetic)")
print(f"{'='*50}")
print(f"Base β†’ GRPO reward improvement: {means[2] - means[0]:+.1f}")
print(f"Base β†’ GRPO efficiency improvement: {(efficiencies[2] - efficiencies[0]) * 100:+.1f}%")
print(f"SFT β†’ GRPO reward improvement: {means[2] - means[1]:+.1f}")
print(f"{'='*50}")
def _annotate_turn(turn: dict) -> dict:
annotated = {"role": turn.get("role", "agent"), "text": turn.get("content", "")}
content = str(turn.get("content", "")).lower()
offer = turn.get("offer")
if offer is not None and isinstance(offer, (int, float)) and offer < 125_000:
annotated["is_bad"] = True
annotated["annotation"] = "BATNA breach risk"
elif "understand" in content or "closer" in content or "halfway" in content:
annotated["is_good"] = True
annotated["annotation"] = "Adaptive negotiation move"
elif turn.get("move") == "anchor_high":
annotated["annotation"] = "Anchor high"
else:
annotated["annotation"] = "Opening offer" if annotated["role"] == "player" else "Negotiation response"
return annotated
def _save_transcript_artifact(data_path: str, output_dir: Path) -> None:
records: list[dict] = []
with open(data_path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
records.append(json.loads(line))
combo_groups: dict[tuple[str, str], list[dict]] = {}
for record in records:
combo_groups.setdefault((record.get("persona", ""), record.get("scenario_id", "")), []).append(record)
chosen_base = None
chosen_grpo = None
for combo_records in combo_groups.values():
if len(combo_records) < 2:
continue
combo_sorted = sorted(combo_records, key=lambda rec: rec.get("reward", 0.0))
base_candidate = combo_sorted[0]
grpo_candidate = combo_sorted[-1]
if chosen_base is None or base_candidate.get("reward", 0.0) < chosen_base.get("reward", 0.0):
chosen_base = base_candidate
chosen_grpo = grpo_candidate
if chosen_base is None or chosen_grpo is None:
return
payload = {
"base": {
"total_reward": chosen_base.get("reward", 0),
"turns": [_annotate_turn(turn) for turn in chosen_base.get("conversation", [])],
},
"grpo": {
"total_reward": chosen_grpo.get("reward", 0),
"turns": [_annotate_turn(turn) for turn in chosen_grpo.get("conversation", [])],
},
}
with open(output_dir / "before_after_transcript.json", "w", encoding="utf-8") as f:
json.dump(payload, f, indent=2)
def main() -> None:
parser = argparse.ArgumentParser(description="Evaluate Parlay training pipeline")
parser.add_argument("--base", default="Qwen/Qwen2.5-1.5B-Instruct")
parser.add_argument("--sft", default="models/parlay-sft")
parser.add_argument("--grpo", default="models/parlay-grpo")
parser.add_argument("--base_model", default="")
parser.add_argument("--sft_checkpoint", default="")
parser.add_argument("--grpo_checkpoint", default="")
parser.add_argument("--data", default="data/episodes.jsonl")
parser.add_argument("--output", default="results/eval_results.json")
parser.add_argument("--n", type=int, default=50)
parser.add_argument("--env_port", type=int, default=8001)
parser.add_argument("--save_transcript", action="store_true")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
base = args.base_model or args.base
sft = args.sft_checkpoint or args.sft
grpo = args.grpo_checkpoint or args.grpo
results = asyncio.run(compare_models(base, sft, grpo, args.n, args.data))
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
summary = {
**results,
"base_mean_reward": results["base"]["mean_reward"],
"sft_mean_reward": results["sft"]["mean_reward"],
"grpo_mean_reward": results["grpo"]["mean_reward"],
"env_port": args.env_port,
}
with open(output_path, "w") as f:
json.dump(summary, f, indent=2)
logger.info(f"Results saved to {output_path}")
plot_results(results, output_path.parent)
if args.save_transcript:
_save_transcript_artifact(args.data, output_path.parent)
if __name__ == "__main__":
main()