"""Generate quality-filtered negotiation episodes via Gemini self-play.""" import argparse import asyncio import json import logging import os import random import statistics import sys from collections import Counter, defaultdict from pathlib import Path # Repo root on path when run as `python training/generate_data.py` (script dir is training/) sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from dotenv import load_dotenv from agent.gemini_client import get_and_reset_counts, set_quiet from agent.runner import EpisodeResult, run_episode from game.scenarios import SCENARIOS from parlay_env.models import PersonaType logger = logging.getLogger(__name__) DIVERSITY_CONFIG = { "noise_injection_rate": 0.3, "drift_force_rate": 0.4, } REQUIRED_COMBINATIONS = [ (persona, scenario) for persona in ["shark", "diplomat", "veteran"] for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"] ] # Weighted to oversample historically low deal-rate combinations (total weight = 15) COMBO_WEIGHTS: dict[tuple[str, str], int] = { ("veteran", "hiring_package"): 3, ("veteran", "saas_enterprise"): 2, ("veteran", "acquisition_term_sheet"): 2, ("shark", "hiring_package"): 2, ("diplomat", "hiring_package"): 2, ("shark", "saas_enterprise"): 1, ("shark", "acquisition_term_sheet"): 1, ("diplomat", "saas_enterprise"): 1, ("diplomat", "acquisition_term_sheet"): 1, } WEIGHTED_COMBO_LIST: list[tuple[str, str]] = [] for _pair, _weight in COMBO_WEIGHTS.items(): WEIGHTED_COMBO_LIST.extend([_pair] * _weight) def _row_total_reward(record: dict) -> float | None: v = record.get("reward") if v is not None: return float(v) v2 = record.get("cumulative_reward") if v2 is not None: return float(v2) return None def is_quality_episode(grade, args) -> tuple[bool, str]: """ Returns (keep: bool, reason: str). """ if not args.quality_filter: return True, "no_filter" if grade.deal_efficiency >= args.min_efficiency: return True, "deal_efficiency" if grade.termination_reason == "walk_away" and grade.total_reward > -200: return True, "principled_walkaway" if grade.drift_adapted: return True, "drift_adapted" if grade.tom_accuracy_avg >= 0.5: return True, "good_tom" return False, f"low_quality (eff={grade.deal_efficiency:.2f}, tom={grade.tom_accuracy_avg:.2f})" def _grade_proxy_from_record(record: dict) -> object: return type( "GradeProxy", (), { "deal_efficiency": record["deal_efficiency"], "termination_reason": record["termination_reason"], "total_reward": record["reward"], "drift_adapted": record["drift_adapted"], "tom_accuracy_avg": record["tom_accuracy_avg"], }, )() def _record_from_result(persona: str, scenario_id: str, result: EpisodeResult) -> dict: return { "prompt": result.system_prompt, "conversation": [{k: v for k, v in msg.items()} for msg in result.conversation], "reward": result.grade.total_reward, "deal_efficiency": result.grade.deal_efficiency, "persona": persona, "scenario_id": scenario_id, "acts_completed": 1, "tom_accuracy": result.grade.tom_accuracy_avg, "tom_accuracy_avg": result.grade.tom_accuracy_avg, "drift_adapted": result.grade.drift_adapted, "split": "train" if random.random() < 0.9 else "eval", "deal_reached": result.final_price is not None, "episode_id": result.session.session_id, "termination_reason": result.grade.termination_reason, "batna_seller": result.session.hidden_state.walk_away_price, "batna_buyer": result.session.hidden_state.budget_ceiling, } async def _run_episode_full( persona: str, scenario_id: str, seed: int, max_turns: int ) -> tuple[dict | None, EpisodeResult | None]: try: result = await run_episode( persona=PersonaType(persona), scenario_id=scenario_id, inject_noise=random.random() < DIVERSITY_CONFIG["noise_injection_rate"], force_drift=random.random() < DIVERSITY_CONFIG["drift_force_rate"], seed=seed, max_turns=max_turns, ) except Exception as exc: logger.warning("Episode failed (%s, %s): %s", persona, scenario_id, exc) return None, None return _record_from_result(persona, scenario_id, result), result async def _run_one(persona: str, scenario_id: str, seed: int, max_turns: int) -> dict | None: record, _ = await _run_episode_full(persona, scenario_id, seed, max_turns) return record def _classify_discard(grade, args) -> str: """Single bucket per discarded episode (mutually exclusive).""" if grade.deal_efficiency < args.min_efficiency: return "low_efficiency_no_deal" if grade.tom_accuracy_avg < 0.5: return "tom_accuracy_below_threshold" return "other" def _conversation_mentions_market(conversation: list[dict]) -> bool: for msg in conversation: for v in msg.values(): if isinstance(v, str) and "market" in v.lower(): return True return False def _print_inspect_report( coverage: dict[tuple[str, str], int], total_pre: int, kept: int, discarded: int, keep_reason_counts: Counter[str], kept_records: list[dict], kept_tom: list[bool], kept_rewards: list[float], kept_eff: list[float], kept_tom_acc: list[float], kept_turns: list[float], div_drift: int, div_market: int, div_bluff: int, div_zopa: int, discard_by_label: Counter[str], ) -> None: n_k = max(kept, 1) pct = lambda x: 100.0 * x / n_k d_rate = 100.0 * discarded / max(total_pre, 1) def st(values: list[float]) -> str: if len(values) < 2: return "0.00" return f"{statistics.stdev(values):.2f}" def mean_t(values: list[float]) -> str: if not values: return "0.00" return f"{statistics.mean(values):.2f}" mean_turns = statistics.mean(kept_turns) if kept_turns else 0.0 deal_n = sum(1 for r in kept_records if r.get("deal_reached")) walk_n = keep_reason_counts.get("principled_walkaway", 0) drift_n = keep_reason_counts.get("drift_adapted", 0) tom5_n = sum(1 for t in kept_tom if t) r1, r2, r3 = ( discard_by_label.get("low_efficiency_no_deal", 0), discard_by_label.get("tom_accuracy_below_threshold", 0), discard_by_label.get("other", 0), ) _lw = 31 print() print("=== QUALITY REPORT (60 episodes) ===") print() print("Coverage (persona × scenario):") for persona, scenario_id in REQUIRED_COMBINATIONS: n = coverage.get((persona, scenario_id), 0) print(f" {persona:8s} × {scenario_id:30s} : {n} episodes") print() print("Quality filter:") print(f" {'Total generated (before filter)':<{_lw}}: {total_pre}") print(f" {'Kept after filter':<{_lw}}: {kept}") print(f" {'Discarded':<{_lw}}: {discarded}") print(f" {'Discard rate':<{_lw}}: {d_rate:.1f}%") print() print("Kept episode breakdown:") print(f" Deal reached : {deal_n:3d} ({pct(deal_n):.1f}%)") print(f" Principled walkaway : {walk_n:3d} ({pct(walk_n):.1f}%)") print(f" Drift adapted : {drift_n:3d} ({pct(drift_n):.1f}%)") print(f" ToM accuracy >= 0.5 : {tom5_n:3d} ({pct(tom5_n):.1f}%)") print() print("Reward stats (kept episodes only):") print(f" Mean cumulative reward : {mean_t(kept_rewards)}") print(f" Std cumulative reward : {st(kept_rewards)}") print(f" Min : {min(kept_rewards) if kept_rewards else 0.0:.2f}") print(f" Max : {max(kept_rewards) if kept_rewards else 0.0:.2f}") print(f" Mean deal efficiency : {mean_t(kept_eff)}") print(f" Mean ToM accuracy : {mean_t(kept_tom_acc)}") print(f" Mean turns to close : {mean_turns:.1f}") print() print("Diversity flags (kept episodes):") print( f" {'Episodes with drift event':<{_lw}}: {div_drift:3d} ({100.0 * div_drift / n_k:.1f}%)" ) print( f" {'Episodes with market event':<{_lw}}: {div_market:3d} ({100.0 * div_market / n_k:.1f}%)" ) print( f" {'Episodes with bluff caught':<{_lw}}: {div_bluff:3d} ({100.0 * div_bluff / n_k:.1f}%)" ) print( f" {'Episodes with ZOPA erosion':<{_lw}}: {div_zopa:3d} ({100.0 * div_zopa / n_k:.1f}%)" ) print() print("Top 3 discard reasons:") print(f" 1. low_efficiency_no_deal : {r1}") print(f" 2. tom_accuracy_below_threshold: {r2}") print(f" 3. other : {r3}") async def run_inspect_mode(args) -> None: out_path = Path(getattr(args, "inspect_output", "data/inspect_run.jsonl")) out_path.parent.mkdir(parents=True, exist_ok=True) coverage: dict[tuple[str, str], int] = defaultdict(int) keep_reason_counts: Counter[str] = Counter() kept_records: list[dict] = [] kept_tom: list[bool] = [] kept_rewards: list[float] = [] kept_eff: list[float] = [] kept_tom_acc: list[float] = [] kept_turns: list[float] = [] div_drift = div_market = div_bluff = div_zopa = 0 discard_by_label: Counter[str] = Counter() total_pre = 60 discarded = 0 seed = 0 n_inspect = 60 for i in range(n_inspect): persona, scenario_id = REQUIRED_COMBINATIONS[i % len(REQUIRED_COMBINATIONS)] record, res = await _run_episode_full( persona, scenario_id, seed=seed, max_turns=args.max_turns ) seed += 1 coverage[(persona, scenario_id)] += 1 if record is None or res is None: discarded += 1 discard_by_label["other"] += 1 continue g = res.grade proxy = _grade_proxy_from_record(record) keep, reason = is_quality_episode(proxy, args) if not keep: discarded += 1 discard_by_label[_classify_discard(g, args)] += 1 continue keep_reason_counts[reason] += 1 kept_rewards.append(record["reward"]) kept_eff.append(record["deal_efficiency"]) kept_tom_acc.append(record["tom_accuracy_avg"]) kept_turns.append(float(res.session.step_count)) kept_tom.append(record["tom_accuracy_avg"] >= 0.5) kept_records.append(record) s = res.session if record["drift_adapted"]: div_drift += 1 if _conversation_mentions_market(res.conversation): div_market += 1 if s.bluffs_caught > 0 or g.bluffs_caught > 0: div_bluff += 1 if s.zopa_erosion_ticks > 0: div_zopa += 1 with open(out_path, "w", encoding="utf-8") as out_f: for r in kept_records: out_f.write(json.dumps(r, ensure_ascii=False) + "\n") _print_inspect_report( coverage, total_pre=total_pre, kept=len(kept_records), discarded=discarded, keep_reason_counts=keep_reason_counts, kept_records=kept_records, kept_tom=kept_tom, kept_rewards=kept_rewards, kept_eff=kept_eff, kept_tom_acc=kept_tom_acc, kept_turns=kept_turns, div_drift=div_drift, div_market=div_market, div_bluff=div_bluff, div_zopa=div_zopa, discard_by_label=discard_by_label, ) print() print(f"Kept episodes written to: {out_path.resolve()}") async def run_diversity_pass(args, output_path: Path) -> None: """ Generate a quality-filtered dataset; persona x scenario is weighted-sampled (see COMBO_WEIGHTS / WEIGHTED_COMBO_LIST). """ output_path.parent.mkdir(parents=True, exist_ok=True) coverage: dict[tuple[str, str], int] = defaultdict(int) kept_reason_counts: Counter[str] = Counter() kept_records: list[dict] = [] generated = 0 discarded = 0 skipped_min_reward = 0 seed = 0 total_live_calls: int = 0 total_fallback_calls: int = 0 _verbose = not getattr(args, "quiet", False) _checkpoints = {20, 40, 60, 80, 100, 120, 140} def _emit_checkpoint(_ep_num: int) -> None: if not _verbose or _ep_num not in _checkpoints: return _all_rewards = [r.get("reward", 0.0) for r in kept_records] _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records] _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records}) print(f"\n{'━' * 40}", file=sys.stderr) print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr) print( f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)", file=sys.stderr, ) print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr) print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr) print(f" Combos covered : {_combos_covered}/9", file=sys.stderr) print(f" Min-reward skip : {skipped_min_reward}", file=sys.stderr) print(f" Live calls total: {total_live_calls}", file=sys.stderr) print(f" Fallback total : {total_fallback_calls}", file=sys.stderr) print(f"{'━' * 40}\n", file=sys.stderr) with open(output_path, "w", encoding="utf-8") as out_f: while len(kept_records) < args.episodes: persona, scenario_id = random.choice(WEIGHTED_COMBO_LIST) record = await _run_one(persona, scenario_id, seed=seed, max_turns=args.max_turns) seed += 1 generated += 1 if record is None: _live_n, _fall_n = get_and_reset_counts() total_live_calls += _live_n total_fallback_calls += _fall_n continue rw = _row_total_reward(record) if rw is not None and rw < args.min_reward: skipped_min_reward += 1 _live_m, _fall_m = get_and_reset_counts() total_live_calls += _live_m total_fallback_calls += _fall_m if _verbose: print( f"[min_reward skip #{skipped_min_reward}] {persona} x {scenario_id} " f"reward={rw:.2f} < {args.min_reward}", file=sys.stderr, ) continue keep, reason = is_quality_episode( _grade_proxy_from_record(record), args, ) if not keep: discarded += 1 _live_d, _fall_d = get_and_reset_counts() total_live_calls += _live_d total_fallback_calls += _fall_d if _verbose: print( f"[EP --/{args.episodes:02d}] " f"{persona}×{scenario_id:<27s} | " f"reward={record.get('reward', 0.0):+.2f} | " f"eff={record.get('deal_efficiency', 0.0):.3f} | " f"kept=NO | " f"total_kept={len(kept_records)}/{generated} | " f"gemini_live={_live_d} fallback={_fall_d}", file=sys.stderr, ) continue out_f.write(json.dumps(record, ensure_ascii=False) + "\n") out_f.flush() kept_records.append(record) _live, _fall = get_and_reset_counts() total_live_calls += _live total_fallback_calls += _fall _ep_num = len(kept_records) if _verbose: _reward = record.get("reward", 0.0) _eff = record.get("deal_efficiency", 0.0) _combo = f"{record['persona']}×{record['scenario_id']}" print( f"[EP {_ep_num:02d}/{args.episodes:02d}] " f"{_combo:<35s} | " f"reward={_reward:+.2f} | " f"eff={_eff:.3f} | " f"kept=YES | " f"total_kept={_ep_num}/{generated} | " f"gemini_live={_live} fallback={_fall}", file=sys.stderr, ) _emit_checkpoint(_ep_num) coverage[(persona, scenario_id)] += 1 kept_reason_counts[reason] += 1 discard_pct = (discarded / max(generated, 1)) * 100.0 print( f"Generated: {generated} episodes | Kept: {len(kept_records)} | " f"Discarded: {discarded} ({discard_pct:.0f}%) | " f"Skipped (min_reward < {args.min_reward}): {skipped_min_reward}" ) reasons_str = ", ".join(f"{reason}={count}" for reason, count in sorted(kept_reason_counts.items())) print(f"Reasons kept: {reasons_str or 'none'}") print("\nCoverage:") for persona, scenario_id in REQUIRED_COMBINATIONS: print(f" {persona:9s} x {scenario_id:24s} -> {coverage[(persona, scenario_id)]}") _fallback_rate = 100.0 * total_fallback_calls / max(total_live_calls + total_fallback_calls, 1) _verdict = ( "ALL CALLS LIVE - data is real" if _fallback_rate < 5.0 else "WARNING: fallback rate high - check API key and rate limits" ) print(f"\nGemini API health:") print(f" Total live calls : {total_live_calls}") print(f" Total fallback : {total_fallback_calls}") print(f" Fallback rate : {_fallback_rate:.1f}%") print(f" VERDICT: {_verdict}") def main() -> None: parser = argparse.ArgumentParser(description="Generate Parlay training data") parser.add_argument("--episodes", type=int, default=140) parser.add_argument("--output", type=str, default="data/episodes.jsonl") parser.add_argument( "--min-reward", type=float, default=-50.0, help="After grading, do not write episodes with total reward below this (default: -50.0)", ) parser.add_argument( "--quality_filter", action="store_true", help="Discard low-quality episodes instead of writing them", ) parser.add_argument( "--min_efficiency", type=float, default=0.25, help="Min deal_efficiency to keep episode (if quality_filter enabled)", ) parser.add_argument("--google_api_key", type=str, default="") parser.add_argument("--max-turns", type=int, default=14) parser.add_argument( "--inspect", action="store_true", help="Run a fixed 60-episode quality diagnostic; writes data/inspect_run.jsonl", ) parser.add_argument( "--inspect-output", type=str, default="data/inspect_run.jsonl", dest="inspect_output", help="Output path for --inspect mode (default: data/inspect_run.jsonl)", ) parser.add_argument( "--quiet", action="store_true", help="Suppress per-episode and per-call stderr output (final summary always shown)", ) args = parser.parse_args() load_dotenv() logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("google_genai").setLevel(logging.WARNING) logging.getLogger("google_genai.models").setLevel(logging.WARNING) if args.quiet: set_quiet(True) logging.disable(logging.WARNING) if args.google_api_key: os.environ["GOOGLE_API_KEY"] = args.google_api_key if args.inspect: asyncio.run(run_inspect_mode(args)) return output_path = Path(args.output) asyncio.run(run_diversity_pass(args, output_path)) records = [] with open(output_path, encoding="utf-8") as f: for line in f: line = line.strip() if line: records.append(json.loads(line)) total = len(records) deals = sum(1 for record in records if record.get("deal_efficiency", 0) > 0) avg_reward = sum(record.get("reward", 0.0) for record in records) / max(total, 1) print(f"\n{'=' * 50}") print(" GENERATION COMPLETE") print(f"{'=' * 50}") print(f" Episodes in file : {total}") print(f" Deal rate : {deals / max(total, 1) * 100:.1f}% ({deals}/{total})") print(f" Avg total reward : {avg_reward:.2f}") print(f" max_turns used : {args.max_turns}") print(f" Output file : {output_path.resolve()}") print(f"{'=' * 50}\n") if __name__ == "__main__": main()