feat: flash-lite for data-gen and flash for UI; remove training page; card tests; --quiet data gen; data/ inspect path; random baseline; GRPO env wrapper; reward fixes (buyer ZOPA, ToM signals); drift + Brier metrics; Bayesian ToM module
Browse files- agent/gemini_client.py +21 -12
- agent/runner.py +15 -3
- agent/tom_tracker.py +41 -1
- agent/tom_tracker_bayesian.py +296 -0
- dashboard/api.py +18 -4
- dashboard/index.html +0 -1
- main.py +0 -9
- parlay_env/grader.py +24 -0
- tests/test_tactical_cards.py +213 -0
- training/generate_data.py +96 -75
- training/grpo_env_wrapper.py +188 -0
- training/grpo_train.py +4 -1
- training/random_baseline.py +126 -0
- training/reward_fn.py +49 -19
agent/gemini_client.py
CHANGED
|
@@ -5,9 +5,11 @@ All errors return SYNTHETIC_RESPONSE.
|
|
| 5 |
When GOOGLE_API_KEY is absent, MOCK_RESPONSES are returned so the full game
|
| 6 |
loop works without any API key.
|
| 7 |
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
import asyncio
|
| 13 |
import json
|
|
@@ -54,15 +56,21 @@ SCENARIO_ROLE_CONTEXT: dict[str, dict[str, str]] = {
|
|
| 54 |
},
|
| 55 |
}
|
| 56 |
|
| 57 |
-
GEMINI_MODEL
|
| 58 |
-
|
| 59 |
-
MODEL_ID_DEMO =
|
| 60 |
-
MODEL_ID_DATA =
|
| 61 |
-
MODEL_ID = GEMINI_MODEL
|
| 62 |
|
| 63 |
_client = None
|
| 64 |
_mock_warned: bool = False
|
| 65 |
_gemini_model_logged: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# ── Mock responses (keyless dev / CI) ────────────────────────────────────────
|
| 68 |
# Offer amounts are realistic for the default SaaS enterprise scenario
|
|
@@ -336,10 +344,11 @@ async def call_gemini(
|
|
| 336 |
|
| 337 |
_turn_count += 1
|
| 338 |
_live_calls += 1
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
|
|
|
| 343 |
|
| 344 |
text = (response.text or "").strip()
|
| 345 |
text = text.replace("```json", "").replace("```", "").strip()
|
|
|
|
| 5 |
When GOOGLE_API_KEY is absent, MOCK_RESPONSES are returned so the full game
|
| 6 |
loop works without any API key.
|
| 7 |
|
| 8 |
+
Model routing:
|
| 9 |
+
- MODEL_ID_DATA (gemini-2.5-flash-lite) — data generation, self-play, ToM inference.
|
| 10 |
+
Low-latency, high-throughput; used by runner.py and generate_data.py.
|
| 11 |
+
- MODEL_ID_DEMO (gemini-2.5-flash) — web UI, dashboard API, MCP tools.
|
| 12 |
+
Higher quality responses for live user interaction.
|
| 13 |
"""
|
| 14 |
import asyncio
|
| 15 |
import json
|
|
|
|
| 56 |
},
|
| 57 |
}
|
| 58 |
|
| 59 |
+
GEMINI_MODEL = "gemini-2.5-flash-lite" # kept for backward compat; equals MODEL_ID_DATA
|
| 60 |
+
MODEL_ID_DATA = "gemini-2.5-flash-lite" # data generation, self-play, ToM inference
|
| 61 |
+
MODEL_ID_DEMO = "gemini-2.5-flash" # web UI, dashboard API, MCP tools
|
| 62 |
+
MODEL_ID = MODEL_ID_DATA # stable alias (runner.py omits model= → flash-lite)
|
|
|
|
| 63 |
|
| 64 |
_client = None
|
| 65 |
_mock_warned: bool = False
|
| 66 |
_gemini_model_logged: bool = False
|
| 67 |
+
_quiet: bool = False # suppresses per-call [Gemini LIVE] prints when True
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def set_quiet(flag: bool) -> None:
|
| 71 |
+
"""Suppress [Gemini LIVE] per-call stderr prints (e.g. during test runs)."""
|
| 72 |
+
global _quiet
|
| 73 |
+
_quiet = flag
|
| 74 |
|
| 75 |
# ── Mock responses (keyless dev / CI) ────────────────────────────────────────
|
| 76 |
# Offer amounts are realistic for the default SaaS enterprise scenario
|
|
|
|
| 344 |
|
| 345 |
_turn_count += 1
|
| 346 |
_live_calls += 1
|
| 347 |
+
if not _quiet:
|
| 348 |
+
print(
|
| 349 |
+
f"[Gemini LIVE] model={mid} chars={len(response.text or '')} turn={_turn_count}",
|
| 350 |
+
file=sys.stderr,
|
| 351 |
+
)
|
| 352 |
|
| 353 |
text = (response.text or "").strip()
|
| 354 |
text = text.replace("```json", "").replace("```", "").strip()
|
agent/runner.py
CHANGED
|
@@ -131,8 +131,12 @@ async def run_episode(
|
|
| 131 |
for event in scenario.drift_events:
|
| 132 |
if event.trigger_turn == turn or (forced_drift_turn == turn):
|
| 133 |
drift_turn = turn
|
| 134 |
-
tom.drift_event(
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
break
|
| 137 |
|
| 138 |
if inject_noise and turn < 3 and rng.random() < 0.3:
|
|
@@ -211,8 +215,16 @@ async def run_episode(
|
|
| 211 |
|
| 212 |
if drift_turn is not None and not drift_adapted and turn <= drift_turn + 2:
|
| 213 |
adaptation_signals = ["understand", "noted", "given that", "considering"]
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
| 215 |
drift_adapted = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
new_offers = list(state.offer_history)
|
| 218 |
if action.offer_amount:
|
|
|
|
| 131 |
for event in scenario.drift_events:
|
| 132 |
if event.trigger_turn == turn or (forced_drift_turn == turn):
|
| 133 |
drift_turn = turn
|
| 134 |
+
tom.drift_event(
|
| 135 |
+
event.effect_on_urgency,
|
| 136 |
+
event.effect_on_has_alternative,
|
| 137 |
+
event_description=event.event,
|
| 138 |
+
)
|
| 139 |
+
logger.info(f"Drift event at turn {turn}: {event.event!r}")
|
| 140 |
break
|
| 141 |
|
| 142 |
if inject_noise and turn < 3 and rng.random() < 0.3:
|
|
|
|
| 215 |
|
| 216 |
if drift_turn is not None and not drift_adapted and turn <= drift_turn + 2:
|
| 217 |
adaptation_signals = ["understand", "noted", "given that", "considering"]
|
| 218 |
+
matched = next(
|
| 219 |
+
(s for s in adaptation_signals if s in action.utterance.lower()), None
|
| 220 |
+
)
|
| 221 |
+
if matched:
|
| 222 |
drift_adapted = True
|
| 223 |
+
logger.info(
|
| 224 |
+
f"drift_adapted=True at turn={turn} "
|
| 225 |
+
f"matched_phrase={matched!r} "
|
| 226 |
+
f"utterance_snippet={action.utterance[:80]!r}"
|
| 227 |
+
)
|
| 228 |
|
| 229 |
new_offers = list(state.offer_history)
|
| 230 |
if action.offer_amount:
|
agent/tom_tracker.py
CHANGED
|
@@ -129,6 +129,7 @@ class ToMTracker:
|
|
| 129 |
self,
|
| 130 |
effect_on_urgency: float,
|
| 131 |
effect_on_has_alternative: bool,
|
|
|
|
| 132 |
) -> BeliefState:
|
| 133 |
"""
|
| 134 |
Apply a drift event to beliefs.
|
|
@@ -136,6 +137,8 @@ class ToMTracker:
|
|
| 136 |
Args:
|
| 137 |
effect_on_urgency: Signed delta to urgency estimate.
|
| 138 |
effect_on_has_alternative: Override for has_alternative belief.
|
|
|
|
|
|
|
| 139 |
|
| 140 |
Returns:
|
| 141 |
Updated BeliefState post-drift.
|
|
@@ -150,12 +153,49 @@ class ToMTracker:
|
|
| 150 |
confidence=max(0.0, last.confidence - 0.15), # drift reduces confidence
|
| 151 |
)
|
| 152 |
self.history.append(updated)
|
|
|
|
| 153 |
logger.info(
|
| 154 |
-
f"ToM drift applied
|
|
|
|
| 155 |
f"alt={effect_on_has_alternative}"
|
| 156 |
)
|
| 157 |
return updated
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
def accuracy_against(self, hidden: HiddenState) -> float:
|
| 160 |
"""
|
| 161 |
Compute current belief accuracy against true hidden state.
|
|
|
|
| 129 |
self,
|
| 130 |
effect_on_urgency: float,
|
| 131 |
effect_on_has_alternative: bool,
|
| 132 |
+
event_description: str = "",
|
| 133 |
) -> BeliefState:
|
| 134 |
"""
|
| 135 |
Apply a drift event to beliefs.
|
|
|
|
| 137 |
Args:
|
| 138 |
effect_on_urgency: Signed delta to urgency estimate.
|
| 139 |
effect_on_has_alternative: Override for has_alternative belief.
|
| 140 |
+
event_description: Human-readable scenario event string
|
| 141 |
+
(e.g. "Competitor drops price 15%").
|
| 142 |
|
| 143 |
Returns:
|
| 144 |
Updated BeliefState post-drift.
|
|
|
|
| 153 |
confidence=max(0.0, last.confidence - 0.15), # drift reduces confidence
|
| 154 |
)
|
| 155 |
self.history.append(updated)
|
| 156 |
+
desc_part = f" | event={event_description!r}" if event_description else ""
|
| 157 |
logger.info(
|
| 158 |
+
f"ToM drift applied{desc_part}: "
|
| 159 |
+
f"urgency_delta={effect_on_urgency:+.2f} → {new_urgency:.2f}, "
|
| 160 |
f"alt={effect_on_has_alternative}"
|
| 161 |
)
|
| 162 |
return updated
|
| 163 |
|
| 164 |
+
def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
|
| 165 |
+
"""
|
| 166 |
+
Compute per-field Brier scores over the full belief history.
|
| 167 |
+
|
| 168 |
+
Brier score = (1/N) Σ (predicted - actual)²
|
| 169 |
+
Lower is better; 0 = perfect.
|
| 170 |
+
|
| 171 |
+
Fields scored:
|
| 172 |
+
- urgency: est_urgency (continuous 0–1) vs hidden.urgency_score
|
| 173 |
+
- has_alt: est_has_alternative (0/1 probability) vs hidden.has_alternative
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
hidden: The true hidden state revealed at episode end.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
Dict with keys "urgency" and "has_alt", each a float in [0, 1].
|
| 180 |
+
"""
|
| 181 |
+
if not self.history:
|
| 182 |
+
return {"urgency": 1.0, "has_alt": 1.0}
|
| 183 |
+
|
| 184 |
+
actual_urgency = hidden.urgency_score
|
| 185 |
+
actual_alt = float(hidden.has_alternative)
|
| 186 |
+
|
| 187 |
+
urgency_sq_err = sum(
|
| 188 |
+
(b.est_urgency - actual_urgency) ** 2 for b in self.history
|
| 189 |
+
)
|
| 190 |
+
alt_sq_err = sum(
|
| 191 |
+
(float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history
|
| 192 |
+
)
|
| 193 |
+
n = len(self.history)
|
| 194 |
+
return {
|
| 195 |
+
"urgency": round(urgency_sq_err / n, 6),
|
| 196 |
+
"has_alt": round(alt_sq_err / n, 6),
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
def accuracy_against(self, hidden: HiddenState) -> float:
|
| 200 |
"""
|
| 201 |
Compute current belief accuracy against true hidden state.
|
agent/tom_tracker_bayesian.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Bayesian Theory-of-Mind belief tracker for Parlay.
|
| 3 |
+
|
| 4 |
+
Drop-in replacement for ToMTracker that uses Kalman-filter-style Gaussian
|
| 5 |
+
belief updates instead of hand-tuned arithmetic nudges.
|
| 6 |
+
|
| 7 |
+
Key insight
|
| 8 |
+
-----------
|
| 9 |
+
The opponent has hidden variables (budget_ceiling, walk_away_price, urgency,
|
| 10 |
+
has_alternative). Each observed offer is a noisy signal about these.
|
| 11 |
+
We model each continuous variable as a Gaussian (mean, variance) and update
|
| 12 |
+
using the standard Bayesian update for Gaussian conjugate priors:
|
| 13 |
+
|
| 14 |
+
posterior_mean = (prior_mean / prior_var + obs / obs_var) /
|
| 15 |
+
(1 / prior_var + 1 / obs_var)
|
| 16 |
+
posterior_var = 1 / (1 / prior_var + 1 / obs_var)
|
| 17 |
+
|
| 18 |
+
`confidence` is derived from the posterior variance:
|
| 19 |
+
confidence = 1 / (1 + sqrt(budget_var / budget_mean²))
|
| 20 |
+
|
| 21 |
+
Usage (as feature-flag alternative to ToMTracker):
|
| 22 |
+
from agent.tom_tracker_bayesian import BayesianToMTracker as ToMTracker
|
| 23 |
+
|
| 24 |
+
# Then use exactly the same API as ToMTracker — all method signatures match.
|
| 25 |
+
"""
|
| 26 |
+
import logging
|
| 27 |
+
import math
|
| 28 |
+
import sys
|
| 29 |
+
from typing import Optional
|
| 30 |
+
|
| 31 |
+
from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class BayesianToMTracker:
|
| 37 |
+
"""
|
| 38 |
+
Gaussian-posterior belief tracker for the opponent's hidden state.
|
| 39 |
+
|
| 40 |
+
Extends the original ToMTracker API with proper Bayesian updating.
|
| 41 |
+
The same public methods (update, drift_event, accuracy_against,
|
| 42 |
+
brier_scores, log_belief_snapshot) are preserved for drop-in use.
|
| 43 |
+
|
| 44 |
+
Internal state:
|
| 45 |
+
_budget_mean, _budget_var — Gaussian over opponent's budget ceiling.
|
| 46 |
+
_walk_mean, _walk_var — Gaussian over opponent's walk-away price.
|
| 47 |
+
_urgency_mean, _urgency_var — Gaussian over urgency [0, 1].
|
| 48 |
+
_alt_prob — Bernoulli probability of has_alternative.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
# Observation noise variances (tuned for B2B negotiation scale).
|
| 52 |
+
# Budget/walk-away: observed offer is a noisy signal; high variance because
|
| 53 |
+
# opponents rarely reveal their true limits.
|
| 54 |
+
_OBS_BUDGET_VAR_FRAC = 0.10 # 10% of current mean estimate as std
|
| 55 |
+
_OBS_URGENCY_VAR = 0.05 # small update per offer-ratio signal
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
initial_belief: BeliefState,
|
| 60 |
+
persona: PersonaType,
|
| 61 |
+
) -> None:
|
| 62 |
+
"""
|
| 63 |
+
Args:
|
| 64 |
+
initial_belief: Starting BeliefState (imprecise prior).
|
| 65 |
+
persona: Opponent persona (known to the player).
|
| 66 |
+
"""
|
| 67 |
+
self.persona = persona
|
| 68 |
+
self._bluffs_detected: int = 0
|
| 69 |
+
|
| 70 |
+
# Initialise Gaussian priors from the initial belief
|
| 71 |
+
self._budget_mean = float(initial_belief.est_budget)
|
| 72 |
+
self._walk_mean = float(initial_belief.est_walk_away)
|
| 73 |
+
self._urgency_mean = float(initial_belief.est_urgency)
|
| 74 |
+
self._alt_prob = 0.3 # prior: 30% chance opponent has an alternative
|
| 75 |
+
|
| 76 |
+
# Initial variances — large uncertainty at the start
|
| 77 |
+
self._budget_var = (self._budget_mean * 0.30) ** 2 # ±30% std
|
| 78 |
+
self._walk_var = (self._walk_mean * 0.30) ** 2
|
| 79 |
+
self._urgency_var = 0.08 # std ≈ 0.28 over [0, 1]
|
| 80 |
+
|
| 81 |
+
self.history: list[BeliefState] = [self._snapshot()]
|
| 82 |
+
logger.debug(
|
| 83 |
+
"BayesianToMTracker init: budget_mean=%.0f walk_mean=%.0f urgency_mean=%.2f",
|
| 84 |
+
self._budget_mean, self._walk_mean, self._urgency_mean,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# ── Internal helpers ──────────────────────────────────────────────────────
|
| 88 |
+
|
| 89 |
+
def _snapshot(self) -> BeliefState:
|
| 90 |
+
"""Convert current Gaussian state to a BeliefState snapshot."""
|
| 91 |
+
confidence = self._compute_confidence()
|
| 92 |
+
return BeliefState(
|
| 93 |
+
est_budget=round(self._budget_mean, 2),
|
| 94 |
+
est_walk_away=round(self._walk_mean, 2),
|
| 95 |
+
est_urgency=round(max(0.0, min(1.0, self._urgency_mean)), 4),
|
| 96 |
+
est_has_alternative=self._alt_prob >= 0.5,
|
| 97 |
+
confidence=round(confidence, 4),
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _compute_confidence(self) -> float:
|
| 101 |
+
"""
|
| 102 |
+
Confidence = 1 - mean relative std across all variables.
|
| 103 |
+
Shrinks variance → higher confidence.
|
| 104 |
+
"""
|
| 105 |
+
budget_rel_std = math.sqrt(self._budget_var) / max(abs(self._budget_mean), 1.0)
|
| 106 |
+
walk_rel_std = math.sqrt(self._walk_var) / max(abs(self._walk_mean), 1.0)
|
| 107 |
+
urgency_std = math.sqrt(self._urgency_var)
|
| 108 |
+
alt_std = math.sqrt(self._alt_prob * (1.0 - self._alt_prob))
|
| 109 |
+
mean_uncertainty = (budget_rel_std + walk_rel_std + urgency_std + alt_std) / 4.0
|
| 110 |
+
return max(0.0, min(1.0, 1.0 - mean_uncertainty))
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def _gaussian_update(
|
| 114 |
+
prior_mean: float,
|
| 115 |
+
prior_var: float,
|
| 116 |
+
obs: float,
|
| 117 |
+
obs_var: float,
|
| 118 |
+
) -> tuple[float, float]:
|
| 119 |
+
"""
|
| 120 |
+
Closed-form Bayesian update for Gaussian conjugate prior.
|
| 121 |
+
|
| 122 |
+
posterior_mean = (prior_mean / prior_var + obs / obs_var) /
|
| 123 |
+
(1 / prior_var + 1 / obs_var)
|
| 124 |
+
posterior_var = 1 / (1 / prior_var + 1 / obs_var)
|
| 125 |
+
"""
|
| 126 |
+
prec_prior = 1.0 / max(prior_var, 1e-10)
|
| 127 |
+
prec_obs = 1.0 / max(obs_var, 1e-10)
|
| 128 |
+
posterior_prec = prec_prior + prec_obs
|
| 129 |
+
posterior_mean = (prec_prior * prior_mean + prec_obs * obs) / posterior_prec
|
| 130 |
+
posterior_var = 1.0 / posterior_prec
|
| 131 |
+
return posterior_mean, posterior_var
|
| 132 |
+
|
| 133 |
+
# ── Public API (matches ToMTracker) ──────────────────────────────────────
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
def current_belief(self) -> BeliefState:
|
| 137 |
+
return self.history[-1]
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def bluffs_detected(self) -> int:
|
| 141 |
+
return self._bluffs_detected
|
| 142 |
+
|
| 143 |
+
def log_belief_snapshot(self, turn: int) -> None:
|
| 144 |
+
b = self.current_belief
|
| 145 |
+
print(
|
| 146 |
+
f"[BayesToM turn={turn}] "
|
| 147 |
+
f"budget={b.est_budget:.0f}±{math.sqrt(self._budget_var):.0f} "
|
| 148 |
+
f"urgency={b.est_urgency:.3f}±{math.sqrt(self._urgency_var):.3f} "
|
| 149 |
+
f"alt_prob={self._alt_prob:.2f} conf={b.confidence:.2f}",
|
| 150 |
+
file=sys.stderr,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def update(
|
| 154 |
+
self,
|
| 155 |
+
observed_offer: Optional[float],
|
| 156 |
+
observed_move: Optional[TacticalMove],
|
| 157 |
+
utterance: str,
|
| 158 |
+
turn: int,
|
| 159 |
+
) -> BeliefState:
|
| 160 |
+
"""
|
| 161 |
+
Bayesian update of all beliefs from one observed opponent action.
|
| 162 |
+
|
| 163 |
+
Budget update: if we see an offer O, the true budget is likely > O.
|
| 164 |
+
We treat O as a lower-bound signal: observation = O * 1.05
|
| 165 |
+
with variance proportional to the current mean.
|
| 166 |
+
Urgency update: offer-ratio below 0.85 → urgency signal 0.7;
|
| 167 |
+
above 0.95 → urgency signal 0.3. Both with moderate obs variance.
|
| 168 |
+
has_alternative: updated as Bernoulli likelihood ratio (keyword match).
|
| 169 |
+
"""
|
| 170 |
+
# ── Budget Bayesian update ──────────────────────────────────────────
|
| 171 |
+
if observed_offer is not None and observed_offer > 0:
|
| 172 |
+
budget_obs = observed_offer * 1.05
|
| 173 |
+
obs_budget_var = (self._budget_mean * self._OBS_BUDGET_VAR_FRAC) ** 2
|
| 174 |
+
self._budget_mean, self._budget_var = self._gaussian_update(
|
| 175 |
+
self._budget_mean, self._budget_var,
|
| 176 |
+
budget_obs, obs_budget_var,
|
| 177 |
+
)
|
| 178 |
+
logger.debug(
|
| 179 |
+
"Bayesian budget update: obs=%.0f → mean=%.0f std=%.0f",
|
| 180 |
+
budget_obs, self._budget_mean, math.sqrt(self._budget_var),
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
# ── Walk-away update: BATNA_REVEAL is a noisy signal ───────────────
|
| 184 |
+
if observed_move == TacticalMove.BATNA_REVEAL:
|
| 185 |
+
if observed_offer is not None:
|
| 186 |
+
walk_obs = observed_offer * 0.95
|
| 187 |
+
obs_walk_var = (self._walk_mean * 0.15) ** 2
|
| 188 |
+
self._walk_mean, self._walk_var = self._gaussian_update(
|
| 189 |
+
self._walk_mean, self._walk_var,
|
| 190 |
+
walk_obs, obs_walk_var,
|
| 191 |
+
)
|
| 192 |
+
logger.debug("Bayesian walk-away update via BATNA_REVEAL")
|
| 193 |
+
|
| 194 |
+
# ── Urgency Bayesian update via offer-ratio signal ─────────────────
|
| 195 |
+
if observed_offer is not None and self._budget_mean > 0:
|
| 196 |
+
offer_ratio = observed_offer / self._budget_mean
|
| 197 |
+
if offer_ratio < 0.85:
|
| 198 |
+
urgency_obs = 0.70 # low offer → opponent likely more urgent
|
| 199 |
+
elif offer_ratio > 0.95:
|
| 200 |
+
urgency_obs = 0.30 # high offer → opponent comfortable
|
| 201 |
+
else:
|
| 202 |
+
urgency_obs = 0.50 # neutral
|
| 203 |
+
self._urgency_mean, self._urgency_var = self._gaussian_update(
|
| 204 |
+
self._urgency_mean, self._urgency_var,
|
| 205 |
+
urgency_obs, self._OBS_URGENCY_VAR,
|
| 206 |
+
)
|
| 207 |
+
self._urgency_mean = max(0.0, min(1.0, self._urgency_mean))
|
| 208 |
+
|
| 209 |
+
# ── has_alternative Bernoulli update (likelihood ratio) ────────────
|
| 210 |
+
alt_signals = ["competitor", "alternative", "other offer", "another bid"]
|
| 211 |
+
if any(sig in utterance.lower() for sig in alt_signals):
|
| 212 |
+
self._alt_prob = min(0.95, self._alt_prob + (1.0 - self._alt_prob) * 0.35)
|
| 213 |
+
logger.debug("Alternative signal detected → alt_prob=%.2f", self._alt_prob)
|
| 214 |
+
else:
|
| 215 |
+
self._alt_prob = max(0.05, self._alt_prob * 0.98) # small decay
|
| 216 |
+
|
| 217 |
+
# ── Bluff detection (shark persona + BATNA_REVEAL + "competitor") ──
|
| 218 |
+
if (
|
| 219 |
+
self.persona == PersonaType.SHARK
|
| 220 |
+
and observed_move == TacticalMove.BATNA_REVEAL
|
| 221 |
+
and "competitor" in utterance.lower()
|
| 222 |
+
):
|
| 223 |
+
self._bluffs_detected += 1
|
| 224 |
+
logger.info("BayesToM: bluff detected (total: %d)", self._bluffs_detected)
|
| 225 |
+
|
| 226 |
+
updated = self._snapshot()
|
| 227 |
+
self.history.append(updated)
|
| 228 |
+
logger.debug(
|
| 229 |
+
"BayesToM update turn=%d: budget=%.0f walk=%.0f urgency=%.2f alt_prob=%.2f conf=%.2f",
|
| 230 |
+
turn, self._budget_mean, self._walk_mean, self._urgency_mean,
|
| 231 |
+
self._alt_prob, updated.confidence,
|
| 232 |
+
)
|
| 233 |
+
return updated
|
| 234 |
+
|
| 235 |
+
def drift_event(
|
| 236 |
+
self,
|
| 237 |
+
effect_on_urgency: float,
|
| 238 |
+
effect_on_has_alternative: bool,
|
| 239 |
+
event_description: str = "",
|
| 240 |
+
) -> BeliefState:
|
| 241 |
+
"""
|
| 242 |
+
Apply a market/scenario drift event.
|
| 243 |
+
|
| 244 |
+
Nudges the urgency mean and resets alt_prob based on the drift direction.
|
| 245 |
+
Also inflates all variances (drift = increased uncertainty).
|
| 246 |
+
"""
|
| 247 |
+
self._urgency_mean = float(max(0.0, min(1.0, self._urgency_mean + effect_on_urgency)))
|
| 248 |
+
self._urgency_var = min(0.1, self._urgency_var * 1.5) # inflate uncertainty
|
| 249 |
+
|
| 250 |
+
# Drift shifts alt belief
|
| 251 |
+
if effect_on_has_alternative:
|
| 252 |
+
self._alt_prob = min(0.9, self._alt_prob + 0.25)
|
| 253 |
+
else:
|
| 254 |
+
self._alt_prob = max(0.1, self._alt_prob - 0.1)
|
| 255 |
+
|
| 256 |
+
# Inflate budget/walk variances — drift reduces confidence
|
| 257 |
+
self._budget_var *= 1.3
|
| 258 |
+
self._walk_var *= 1.3
|
| 259 |
+
|
| 260 |
+
updated = self._snapshot()
|
| 261 |
+
self.history.append(updated)
|
| 262 |
+
desc_part = f" | event={event_description!r}" if event_description else ""
|
| 263 |
+
logger.info(
|
| 264 |
+
"BayesToM drift applied%s: urgency_delta=%+.2f → %.2f, alt_prob=%.2f, conf=%.2f",
|
| 265 |
+
desc_part, effect_on_urgency, self._urgency_mean, self._alt_prob, updated.confidence,
|
| 266 |
+
)
|
| 267 |
+
return updated
|
| 268 |
+
|
| 269 |
+
def accuracy_against(self, hidden: HiddenState) -> float:
|
| 270 |
+
"""
|
| 271 |
+
Compute current belief accuracy against true hidden state.
|
| 272 |
+
Same formula as ToMTracker for comparability.
|
| 273 |
+
"""
|
| 274 |
+
b = self.current_belief
|
| 275 |
+
budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
|
| 276 |
+
walk_range = max(hidden.walk_away_price * 0.5, 1.0)
|
| 277 |
+
budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
|
| 278 |
+
walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
|
| 279 |
+
urgency_err = abs(b.est_urgency - hidden.urgency_score)
|
| 280 |
+
alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0
|
| 281 |
+
mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
|
| 282 |
+
return max(0.0, 1.0 - mean_err)
|
| 283 |
+
|
| 284 |
+
def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
|
| 285 |
+
"""Brier scores for urgency and has_alternative over full belief history."""
|
| 286 |
+
if not self.history:
|
| 287 |
+
return {"urgency": 1.0, "has_alt": 1.0}
|
| 288 |
+
actual_urgency = hidden.urgency_score
|
| 289 |
+
actual_alt = float(hidden.has_alternative)
|
| 290 |
+
n = len(self.history)
|
| 291 |
+
brier_urgency = sum((b.est_urgency - actual_urgency) ** 2 for b in self.history) / n
|
| 292 |
+
brier_alt = sum((float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history) / n
|
| 293 |
+
return {
|
| 294 |
+
"urgency": round(brier_urgency, 6),
|
| 295 |
+
"has_alt": round(brier_alt, 6),
|
| 296 |
+
}
|
dashboard/api.py
CHANGED
|
@@ -253,8 +253,16 @@ def _apply_drift(session: dict[str, Any]) -> Optional[str]:
|
|
| 253 |
if event.trigger_turn == state.step_count:
|
| 254 |
session["drift_turn"] = state.step_count
|
| 255 |
state.drift_events_fired += 1
|
| 256 |
-
session["tom_tracker"].drift_event(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
state.belief_history = list(session["tom_tracker"].history)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
return event.event
|
| 259 |
return None
|
| 260 |
|
|
@@ -443,10 +451,16 @@ async def make_move(req: MoveRequest) -> dict:
|
|
| 443 |
)
|
| 444 |
|
| 445 |
if session["drift_turn"] is not None and not session["drift_adapted"]:
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
|
|
|
|
|
|
| 449 |
session["drift_adapted"] = True
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
new_history = list(state.offer_history)
|
| 452 |
if req.amount is not None:
|
|
|
|
| 253 |
if event.trigger_turn == state.step_count:
|
| 254 |
session["drift_turn"] = state.step_count
|
| 255 |
state.drift_events_fired += 1
|
| 256 |
+
session["tom_tracker"].drift_event(
|
| 257 |
+
event.effect_on_urgency,
|
| 258 |
+
event.effect_on_has_alternative,
|
| 259 |
+
event_description=event.event,
|
| 260 |
+
)
|
| 261 |
state.belief_history = list(session["tom_tracker"].history)
|
| 262 |
+
logger.info(
|
| 263 |
+
"Drift event fired: scenario=%s turn=%d event=%r urgency_delta=%+.2f",
|
| 264 |
+
state.scenario_id, state.step_count, event.event, event.effect_on_urgency,
|
| 265 |
+
)
|
| 266 |
return event.event
|
| 267 |
return None
|
| 268 |
|
|
|
|
| 451 |
)
|
| 452 |
|
| 453 |
if session["drift_turn"] is not None and not session["drift_adapted"]:
|
| 454 |
+
adaptation_signals = ["understand", "noted", "given", "considering", "account"]
|
| 455 |
+
matched_signal = next(
|
| 456 |
+
(s for s in adaptation_signals if s in req.message.lower()), None
|
| 457 |
+
)
|
| 458 |
+
if turn <= session["drift_turn"] + 2 and matched_signal:
|
| 459 |
session["drift_adapted"] = True
|
| 460 |
+
logger.info(
|
| 461 |
+
"drift_adapted=True session=%s turn=%d matched_phrase=%r snippet=%r",
|
| 462 |
+
req.session_id, turn + 1, matched_signal, req.message[:80],
|
| 463 |
+
)
|
| 464 |
|
| 465 |
new_history = list(state.offer_history)
|
| 466 |
if req.amount is not None:
|
dashboard/index.html
CHANGED
|
@@ -85,7 +85,6 @@
|
|
| 85 |
|
| 86 |
<nav class="header-nav" aria-label="Site navigation">
|
| 87 |
<a href="/index.html" class="active">Game</a>
|
| 88 |
-
<a href="/train.html">Training</a>
|
| 89 |
</nav>
|
| 90 |
|
| 91 |
<div class="header-actions">
|
|
|
|
| 85 |
|
| 86 |
<nav class="header-nav" aria-label="Site navigation">
|
| 87 |
<a href="/index.html" class="active">Game</a>
|
|
|
|
| 88 |
</nav>
|
| 89 |
|
| 90 |
<div class="header-actions">
|
main.py
CHANGED
|
@@ -80,15 +80,6 @@ async def serve_index() -> FileResponse:
|
|
| 80 |
)
|
| 81 |
|
| 82 |
|
| 83 |
-
@app.get("/train", include_in_schema=False)
|
| 84 |
-
async def serve_train() -> FileResponse:
|
| 85 |
-
"""Serve the training dashboard."""
|
| 86 |
-
return FileResponse(
|
| 87 |
-
"dashboard/train.html",
|
| 88 |
-
headers={"Cache-Control": "no-cache, must-revalidate"},
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
@app.get("/spectate", include_in_schema=False)
|
| 93 |
async def serve_spectate() -> FileResponse:
|
| 94 |
"""Serve the spectator dashboard."""
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
@app.get("/spectate", include_in_schema=False)
|
| 84 |
async def serve_spectate() -> FileResponse:
|
| 85 |
"""Serve the spectator dashboard."""
|
parlay_env/grader.py
CHANGED
|
@@ -42,6 +42,27 @@ class EpisodeGrade:
|
|
| 42 |
bluffs_caught: int
|
| 43 |
termination_reason: Optional[str]
|
| 44 |
drift_adapted: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
def _tom_accuracy(belief: BeliefState, hidden: HiddenState) -> float:
|
|
@@ -208,6 +229,7 @@ def grade_episode(
|
|
| 208 |
|
| 209 |
tom_scores = [_tom_accuracy(belief, session.hidden_state) for belief in session.belief_history]
|
| 210 |
tom_accuracy_avg = sum(tom_scores) / len(tom_scores) if tom_scores else 0.0
|
|
|
|
| 211 |
terminal = compute_terminal_reward(session, final_price, t_close, t_max, drift_adapted)
|
| 212 |
|
| 213 |
return EpisodeGrade(
|
|
@@ -217,4 +239,6 @@ def grade_episode(
|
|
| 217 |
bluffs_caught=session.bluffs_caught if bluffs_caught is None else bluffs_caught,
|
| 218 |
termination_reason=session.termination_reason,
|
| 219 |
drift_adapted=drift_adapted,
|
|
|
|
|
|
|
| 220 |
)
|
|
|
|
| 42 |
bluffs_caught: int
|
| 43 |
termination_reason: Optional[str]
|
| 44 |
drift_adapted: bool
|
| 45 |
+
tom_brier_urgency: float = 0.0 # Brier score for urgency beliefs (lower = better)
|
| 46 |
+
tom_brier_alt: float = 0.0 # Brier score for has_alternative beliefs
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _brier_scores(beliefs: list[BeliefState], hidden: HiddenState) -> tuple[float, float]:
|
| 50 |
+
"""
|
| 51 |
+
Compute Brier scores for urgency and has_alternative over all belief snapshots.
|
| 52 |
+
|
| 53 |
+
Brier score = (1/N) Σ (predicted - actual)²; lower is better, 0 = perfect.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
(brier_urgency, brier_alt) both in [0, 1].
|
| 57 |
+
"""
|
| 58 |
+
if not beliefs:
|
| 59 |
+
return 1.0, 1.0
|
| 60 |
+
actual_urgency = hidden.urgency_score
|
| 61 |
+
actual_alt = float(hidden.has_alternative)
|
| 62 |
+
n = len(beliefs)
|
| 63 |
+
brier_urgency = sum((b.est_urgency - actual_urgency) ** 2 for b in beliefs) / n
|
| 64 |
+
brier_alt = sum((float(b.est_has_alternative) - actual_alt) ** 2 for b in beliefs) / n
|
| 65 |
+
return round(brier_urgency, 6), round(brier_alt, 6)
|
| 66 |
|
| 67 |
|
| 68 |
def _tom_accuracy(belief: BeliefState, hidden: HiddenState) -> float:
|
|
|
|
| 229 |
|
| 230 |
tom_scores = [_tom_accuracy(belief, session.hidden_state) for belief in session.belief_history]
|
| 231 |
tom_accuracy_avg = sum(tom_scores) / len(tom_scores) if tom_scores else 0.0
|
| 232 |
+
brier_urgency, brier_alt = _brier_scores(session.belief_history, session.hidden_state)
|
| 233 |
terminal = compute_terminal_reward(session, final_price, t_close, t_max, drift_adapted)
|
| 234 |
|
| 235 |
return EpisodeGrade(
|
|
|
|
| 239 |
bluffs_caught=session.bluffs_caught if bluffs_caught is None else bluffs_caught,
|
| 240 |
termination_reason=session.termination_reason,
|
| 241 |
drift_adapted=drift_adapted,
|
| 242 |
+
tom_brier_urgency=brier_urgency,
|
| 243 |
+
tom_brier_alt=brier_alt,
|
| 244 |
)
|
tests/test_tactical_cards.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tactical card tests — verifies card retrieval, serialisation, and API play flow.
|
| 3 |
+
Runs in mock mode (no API key required).
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
pytest tests/test_tactical_cards.py -v
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
import pytest_asyncio
|
| 12 |
+
from httpx import AsyncClient, ASGITransport
|
| 13 |
+
|
| 14 |
+
os.environ.pop("GOOGLE_API_KEY", None)
|
| 15 |
+
|
| 16 |
+
from game.tactical_cards import TACTICAL_CARDS, TacticalCard, get_card, draw_hand
|
| 17 |
+
from parlay_env.models import TacticalMove
|
| 18 |
+
from dashboard.api import _serialise_cards
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ── Unit: card definitions ────────────────────────────────────────────────────
|
| 22 |
+
|
| 23 |
+
class TestCardDefinitions:
|
| 24 |
+
def test_all_three_cards_defined(self):
|
| 25 |
+
"""All three tactical cards are present in the registry."""
|
| 26 |
+
assert "anchor_high" in TACTICAL_CARDS
|
| 27 |
+
assert "batna_reveal" in TACTICAL_CARDS
|
| 28 |
+
assert "silence" in TACTICAL_CARDS
|
| 29 |
+
|
| 30 |
+
def test_card_fields_populated(self):
|
| 31 |
+
"""Each card has all required fields with sensible values."""
|
| 32 |
+
for card_id, card in TACTICAL_CARDS.items():
|
| 33 |
+
assert isinstance(card, TacticalCard)
|
| 34 |
+
assert card.id == card_id
|
| 35 |
+
assert card.name, f"Card {card_id} has empty name"
|
| 36 |
+
assert card.description, f"Card {card_id} has empty description"
|
| 37 |
+
assert card.cp_cost >= 0, f"Card {card_id} has negative CP cost"
|
| 38 |
+
|
| 39 |
+
def test_cp_costs_match_expected(self):
|
| 40 |
+
"""CP costs match the game design spec."""
|
| 41 |
+
assert TACTICAL_CARDS["anchor_high"].cp_cost == 0
|
| 42 |
+
assert TACTICAL_CARDS["batna_reveal"].cp_cost == 20
|
| 43 |
+
assert TACTICAL_CARDS["silence"].cp_cost == 5
|
| 44 |
+
|
| 45 |
+
def test_get_card_by_tactical_move_enum(self):
|
| 46 |
+
"""get_card() accepts TacticalMove enum values."""
|
| 47 |
+
card = get_card(TacticalMove.ANCHOR_HIGH)
|
| 48 |
+
assert card.id == "anchor_high"
|
| 49 |
+
|
| 50 |
+
card = get_card(TacticalMove.BATNA_REVEAL)
|
| 51 |
+
assert card.id == "batna_reveal"
|
| 52 |
+
|
| 53 |
+
card = get_card(TacticalMove.SILENCE)
|
| 54 |
+
assert card.id == "silence"
|
| 55 |
+
|
| 56 |
+
def test_get_card_by_string_id(self):
|
| 57 |
+
"""get_card() accepts plain string ids."""
|
| 58 |
+
card = get_card("anchor_high")
|
| 59 |
+
assert card.id == "anchor_high"
|
| 60 |
+
|
| 61 |
+
def test_get_card_unknown_raises(self):
|
| 62 |
+
"""get_card() raises KeyError for unknown card ids."""
|
| 63 |
+
with pytest.raises(KeyError):
|
| 64 |
+
get_card("does_not_exist")
|
| 65 |
+
|
| 66 |
+
def test_draw_hand_returns_subset(self):
|
| 67 |
+
"""draw_hand() returns at most n valid TacticalMove values."""
|
| 68 |
+
hand = draw_hand(n=2, rng_seed=0)
|
| 69 |
+
assert len(hand) == 2
|
| 70 |
+
for move in hand:
|
| 71 |
+
assert isinstance(move, TacticalMove)
|
| 72 |
+
|
| 73 |
+
def test_draw_hand_no_duplicates(self):
|
| 74 |
+
"""draw_hand() never repeats a card."""
|
| 75 |
+
hand = draw_hand(n=3, rng_seed=7)
|
| 76 |
+
assert len(hand) == len(set(hand))
|
| 77 |
+
|
| 78 |
+
def test_draw_hand_capped_at_total_cards(self):
|
| 79 |
+
"""draw_hand() with n > total cards returns all cards once."""
|
| 80 |
+
hand = draw_hand(n=999, rng_seed=0)
|
| 81 |
+
assert len(hand) == len(TACTICAL_CARDS)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ── Unit: serialisation ───────────────────────────────────────────────────────
|
| 85 |
+
|
| 86 |
+
class TestSerialiseCards:
|
| 87 |
+
def test_serialise_returns_list(self):
|
| 88 |
+
result = _serialise_cards()
|
| 89 |
+
assert isinstance(result, list)
|
| 90 |
+
|
| 91 |
+
def test_serialise_length_matches_registry(self):
|
| 92 |
+
result = _serialise_cards()
|
| 93 |
+
assert len(result) == len(TACTICAL_CARDS)
|
| 94 |
+
|
| 95 |
+
def test_serialise_required_keys(self):
|
| 96 |
+
"""Each serialised card has all keys the frontend expects."""
|
| 97 |
+
required_keys = {"id", "move", "name", "cp_cost", "description", "theory", "game_theory_ref"}
|
| 98 |
+
for item in _serialise_cards():
|
| 99 |
+
missing = required_keys - item.keys()
|
| 100 |
+
assert not missing, f"Card {item.get('id')} missing keys: {missing}"
|
| 101 |
+
|
| 102 |
+
def test_serialise_id_equals_move(self):
|
| 103 |
+
"""'id' and 'move' fields are identical (frontend uses both)."""
|
| 104 |
+
for item in _serialise_cards():
|
| 105 |
+
assert item["id"] == item["move"]
|
| 106 |
+
|
| 107 |
+
def test_serialise_cp_cost_is_int(self):
|
| 108 |
+
for item in _serialise_cards():
|
| 109 |
+
assert isinstance(item["cp_cost"], int)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# ── Integration: play a card through the API ─────────────────────────────────
|
| 113 |
+
|
| 114 |
+
from main import app
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@pytest_asyncio.fixture
|
| 118 |
+
async def client():
|
| 119 |
+
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
|
| 120 |
+
yield ac
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@pytest.mark.asyncio
|
| 124 |
+
async def test_start_game_hand_contains_cards(client):
|
| 125 |
+
"""POST /api/game/start returns a hand of serialised tactical cards."""
|
| 126 |
+
resp = await client.post("/api/game/start", json={
|
| 127 |
+
"scenario_id": "saas_enterprise",
|
| 128 |
+
"persona": "shark",
|
| 129 |
+
"player_name": "Tester",
|
| 130 |
+
})
|
| 131 |
+
assert resp.status_code == 200, resp.text
|
| 132 |
+
data = resp.json()
|
| 133 |
+
assert "hand" in data
|
| 134 |
+
assert isinstance(data["hand"], list)
|
| 135 |
+
assert len(data["hand"]) == len(TACTICAL_CARDS)
|
| 136 |
+
ids_in_hand = {c["id"] for c in data["hand"]}
|
| 137 |
+
assert "anchor_high" in ids_in_hand
|
| 138 |
+
assert "batna_reveal" in ids_in_hand
|
| 139 |
+
assert "silence" in ids_in_hand
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@pytest.mark.asyncio
|
| 143 |
+
async def test_play_card_anchor_high_zero_cost(client):
|
| 144 |
+
"""Playing anchor_high (0 CP) succeeds and deducts nothing."""
|
| 145 |
+
start = await client.post("/api/game/start", json={
|
| 146 |
+
"scenario_id": "saas_enterprise",
|
| 147 |
+
"persona": "diplomat",
|
| 148 |
+
"player_name": "Tester",
|
| 149 |
+
})
|
| 150 |
+
assert start.status_code == 200
|
| 151 |
+
session_id = start.json()["session_id"]
|
| 152 |
+
initial_cp = start.json()["observation"]["credibility_points"]
|
| 153 |
+
|
| 154 |
+
move = await client.post("/api/game/move", json={
|
| 155 |
+
"session_id": session_id,
|
| 156 |
+
"amount": 140000,
|
| 157 |
+
"message": "Anchoring high.",
|
| 158 |
+
"tactical_move": "anchor_high",
|
| 159 |
+
})
|
| 160 |
+
assert move.status_code == 200, move.text
|
| 161 |
+
obs = move.json().get("observation", {})
|
| 162 |
+
assert obs.get("credibility_points", 0) >= initial_cp - 5 # only regen delta at most
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@pytest.mark.asyncio
|
| 166 |
+
async def test_play_card_insufficient_cp_returns_400(client):
|
| 167 |
+
"""Playing batna_reveal (20 CP) with insufficient CP returns 400."""
|
| 168 |
+
start = await client.post("/api/game/start", json={
|
| 169 |
+
"scenario_id": "saas_enterprise",
|
| 170 |
+
"persona": "veteran",
|
| 171 |
+
"player_name": "Tester",
|
| 172 |
+
})
|
| 173 |
+
assert start.status_code == 200
|
| 174 |
+
session_id = start.json()["session_id"]
|
| 175 |
+
|
| 176 |
+
# Drain CP by playing silence (5 CP) many times
|
| 177 |
+
for _ in range(18): # 18 × 5 = 90 CP spent, 18 regen ticks → ~0 CP
|
| 178 |
+
await client.post("/api/game/move", json={
|
| 179 |
+
"session_id": session_id,
|
| 180 |
+
"amount": 150000,
|
| 181 |
+
"message": "...",
|
| 182 |
+
"tactical_move": "silence",
|
| 183 |
+
})
|
| 184 |
+
|
| 185 |
+
# At this point CP should be too low for batna_reveal (20 CP)
|
| 186 |
+
resp = await client.post("/api/game/move", json={
|
| 187 |
+
"session_id": session_id,
|
| 188 |
+
"amount": 155000,
|
| 189 |
+
"message": "Let me reveal my BATNA.",
|
| 190 |
+
"tactical_move": "batna_reveal",
|
| 191 |
+
})
|
| 192 |
+
# Either succeeds (if CP regenerated enough) or fails with 400
|
| 193 |
+
assert resp.status_code in (200, 400)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@pytest.mark.asyncio
|
| 197 |
+
async def test_play_invalid_card_returns_400(client):
|
| 198 |
+
"""Sending an unrecognised card_id returns 400."""
|
| 199 |
+
start = await client.post("/api/game/start", json={
|
| 200 |
+
"scenario_id": "hiring_package",
|
| 201 |
+
"persona": "shark",
|
| 202 |
+
"player_name": "Tester",
|
| 203 |
+
})
|
| 204 |
+
assert start.status_code == 200
|
| 205 |
+
session_id = start.json()["session_id"]
|
| 206 |
+
|
| 207 |
+
resp = await client.post("/api/game/move", json={
|
| 208 |
+
"session_id": session_id,
|
| 209 |
+
"amount": 200000,
|
| 210 |
+
"message": "Playing a mystery card.",
|
| 211 |
+
"tactical_move": "not_a_real_card",
|
| 212 |
+
})
|
| 213 |
+
assert resp.status_code == 400
|
training/generate_data.py
CHANGED
|
@@ -15,7 +15,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
| 15 |
|
| 16 |
from dotenv import load_dotenv
|
| 17 |
|
| 18 |
-
from agent.gemini_client import get_and_reset_counts
|
| 19 |
from agent.runner import EpisodeResult, run_episode
|
| 20 |
from game.scenarios import SCENARIOS
|
| 21 |
from parlay_env.models import PersonaType
|
|
@@ -223,7 +223,7 @@ def _print_inspect_report(
|
|
| 223 |
|
| 224 |
|
| 225 |
async def run_inspect_mode(args) -> None:
|
| 226 |
-
out_path = Path("
|
| 227 |
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 228 |
|
| 229 |
coverage: dict[tuple[str, str], int] = defaultdict(int)
|
|
@@ -320,6 +320,7 @@ async def run_diversity_pass(args, output_path: Path) -> None:
|
|
| 320 |
min_per_combo = max(2, args.episodes // len(REQUIRED_COMBINATIONS))
|
| 321 |
total_live_calls: int = 0
|
| 322 |
total_fallback_calls: int = 0
|
|
|
|
| 323 |
|
| 324 |
with open(output_path, "w", encoding="utf-8") as out_f:
|
| 325 |
while len(kept_records) < args.episodes:
|
|
@@ -348,16 +349,17 @@ async def run_diversity_pass(args, output_path: Path) -> None:
|
|
| 348 |
_live_d, _fall_d = get_and_reset_counts()
|
| 349 |
total_live_calls += _live_d
|
| 350 |
total_fallback_calls += _fall_d
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
|
|
|
| 361 |
continue
|
| 362 |
|
| 363 |
out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
|
@@ -366,35 +368,36 @@ async def run_diversity_pass(args, output_path: Path) -> None:
|
|
| 366 |
total_live_calls += _live
|
| 367 |
total_fallback_calls += _fall
|
| 368 |
_ep_num = len(kept_records)
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
f"[EP {_ep_num:02d}/{args.episodes:02d}] "
|
| 374 |
-
f"{_combo:<35s} | "
|
| 375 |
-
f"reward={_reward:+.2f} | "
|
| 376 |
-
f"eff={_eff:.3f} | "
|
| 377 |
-
f"kept=YES | "
|
| 378 |
-
f"total_kept={_ep_num}/{generated} | "
|
| 379 |
-
f"gemini_live={_live} fallback={_fall}",
|
| 380 |
-
file=sys.stderr,
|
| 381 |
-
)
|
| 382 |
-
if _ep_num in (20, 40, 60):
|
| 383 |
-
_all_rewards = [r.get("reward", 0.0) for r in kept_records]
|
| 384 |
-
_all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
|
| 385 |
-
_combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
|
| 386 |
-
print(f"\n{'━' * 40}", file=sys.stderr)
|
| 387 |
-
print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
|
| 388 |
print(
|
| 389 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
file=sys.stderr,
|
| 391 |
)
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
coverage[(persona, scenario_id)] += 1
|
| 399 |
kept_reason_counts[reason] += 1
|
| 400 |
progress_made = True
|
|
@@ -423,35 +426,36 @@ async def run_diversity_pass(args, output_path: Path) -> None:
|
|
| 423 |
total_live_calls += _live
|
| 424 |
total_fallback_calls += _fall
|
| 425 |
_ep_num = len(kept_records)
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
f"[EP {_ep_num:02d}/{args.episodes:02d}] "
|
| 431 |
-
f"{_combo:<35s} | "
|
| 432 |
-
f"reward={_reward:+.2f} | "
|
| 433 |
-
f"eff={_eff:.3f} | "
|
| 434 |
-
f"kept=YES | "
|
| 435 |
-
f"total_kept={_ep_num}/{generated} | "
|
| 436 |
-
f"gemini_live={_live} fallback={_fall}",
|
| 437 |
-
file=sys.stderr,
|
| 438 |
-
)
|
| 439 |
-
if _ep_num in (20, 40, 60):
|
| 440 |
-
_all_rewards = [r.get("reward", 0.0) for r in kept_records]
|
| 441 |
-
_all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
|
| 442 |
-
_combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
|
| 443 |
-
print(f"\n{'━' * 40}", file=sys.stderr)
|
| 444 |
-
print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
|
| 445 |
print(
|
| 446 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
file=sys.stderr,
|
| 448 |
)
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
coverage[(persona, scenario_id)] += 1
|
| 456 |
kept_reason_counts[reason] += 1
|
| 457 |
else:
|
|
@@ -459,16 +463,17 @@ async def run_diversity_pass(args, output_path: Path) -> None:
|
|
| 459 |
_live_d, _fall_d = get_and_reset_counts()
|
| 460 |
total_live_calls += _live_d
|
| 461 |
total_fallback_calls += _fall_d
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
|
|
|
| 472 |
|
| 473 |
discard_pct = (discarded / max(generated, 1)) * 100.0
|
| 474 |
print(
|
|
@@ -514,7 +519,19 @@ def main() -> None:
|
|
| 514 |
parser.add_argument(
|
| 515 |
"--inspect",
|
| 516 |
action="store_true",
|
| 517 |
-
help="Run a fixed 60-episode quality diagnostic; writes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
)
|
| 519 |
args = parser.parse_args()
|
| 520 |
|
|
@@ -524,6 +541,10 @@ def main() -> None:
|
|
| 524 |
logging.getLogger("google_genai").setLevel(logging.WARNING)
|
| 525 |
logging.getLogger("google_genai.models").setLevel(logging.WARNING)
|
| 526 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
if args.google_api_key:
|
| 528 |
os.environ["GOOGLE_API_KEY"] = args.google_api_key
|
| 529 |
|
|
|
|
| 15 |
|
| 16 |
from dotenv import load_dotenv
|
| 17 |
|
| 18 |
+
from agent.gemini_client import get_and_reset_counts, set_quiet
|
| 19 |
from agent.runner import EpisodeResult, run_episode
|
| 20 |
from game.scenarios import SCENARIOS
|
| 21 |
from parlay_env.models import PersonaType
|
|
|
|
| 223 |
|
| 224 |
|
| 225 |
async def run_inspect_mode(args) -> None:
|
| 226 |
+
out_path = Path(getattr(args, "inspect_output", "data/inspect_run.jsonl"))
|
| 227 |
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 228 |
|
| 229 |
coverage: dict[tuple[str, str], int] = defaultdict(int)
|
|
|
|
| 320 |
min_per_combo = max(2, args.episodes // len(REQUIRED_COMBINATIONS))
|
| 321 |
total_live_calls: int = 0
|
| 322 |
total_fallback_calls: int = 0
|
| 323 |
+
_verbose = not getattr(args, "quiet", False)
|
| 324 |
|
| 325 |
with open(output_path, "w", encoding="utf-8") as out_f:
|
| 326 |
while len(kept_records) < args.episodes:
|
|
|
|
| 349 |
_live_d, _fall_d = get_and_reset_counts()
|
| 350 |
total_live_calls += _live_d
|
| 351 |
total_fallback_calls += _fall_d
|
| 352 |
+
if _verbose:
|
| 353 |
+
print(
|
| 354 |
+
f"[EP --/{args.episodes:02d}] "
|
| 355 |
+
f"{persona}×{scenario_id:<27s} | "
|
| 356 |
+
f"reward={record.get('reward', 0.0):+.2f} | "
|
| 357 |
+
f"eff={record.get('deal_efficiency', 0.0):.3f} | "
|
| 358 |
+
f"kept=NO | "
|
| 359 |
+
f"total_kept={len(kept_records)}/{generated} | "
|
| 360 |
+
f"gemini_live={_live_d} fallback={_fall_d}",
|
| 361 |
+
file=sys.stderr,
|
| 362 |
+
)
|
| 363 |
continue
|
| 364 |
|
| 365 |
out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
|
|
|
| 368 |
total_live_calls += _live
|
| 369 |
total_fallback_calls += _fall
|
| 370 |
_ep_num = len(kept_records)
|
| 371 |
+
if _verbose:
|
| 372 |
+
_reward = record.get("reward", 0.0)
|
| 373 |
+
_eff = record.get("deal_efficiency", 0.0)
|
| 374 |
+
_combo = f"{record['persona']}×{record['scenario_id']}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
print(
|
| 376 |
+
f"[EP {_ep_num:02d}/{args.episodes:02d}] "
|
| 377 |
+
f"{_combo:<35s} | "
|
| 378 |
+
f"reward={_reward:+.2f} | "
|
| 379 |
+
f"eff={_eff:.3f} | "
|
| 380 |
+
f"kept=YES | "
|
| 381 |
+
f"total_kept={_ep_num}/{generated} | "
|
| 382 |
+
f"gemini_live={_live} fallback={_fall}",
|
| 383 |
file=sys.stderr,
|
| 384 |
)
|
| 385 |
+
if _ep_num in (20, 40, 60):
|
| 386 |
+
_all_rewards = [r.get("reward", 0.0) for r in kept_records]
|
| 387 |
+
_all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
|
| 388 |
+
_combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
|
| 389 |
+
print(f"\n{'━' * 40}", file=sys.stderr)
|
| 390 |
+
print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
|
| 391 |
+
print(
|
| 392 |
+
f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
|
| 393 |
+
file=sys.stderr,
|
| 394 |
+
)
|
| 395 |
+
print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
|
| 396 |
+
print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
|
| 397 |
+
print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
|
| 398 |
+
print(f" Live calls total: {total_live_calls}", file=sys.stderr)
|
| 399 |
+
print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
|
| 400 |
+
print(f"{'━' * 40}\n", file=sys.stderr)
|
| 401 |
coverage[(persona, scenario_id)] += 1
|
| 402 |
kept_reason_counts[reason] += 1
|
| 403 |
progress_made = True
|
|
|
|
| 426 |
total_live_calls += _live
|
| 427 |
total_fallback_calls += _fall
|
| 428 |
_ep_num = len(kept_records)
|
| 429 |
+
if _verbose:
|
| 430 |
+
_reward = record.get("reward", 0.0)
|
| 431 |
+
_eff = record.get("deal_efficiency", 0.0)
|
| 432 |
+
_combo = f"{record['persona']}×{record['scenario_id']}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
print(
|
| 434 |
+
f"[EP {_ep_num:02d}/{args.episodes:02d}] "
|
| 435 |
+
f"{_combo:<35s} | "
|
| 436 |
+
f"reward={_reward:+.2f} | "
|
| 437 |
+
f"eff={_eff:.3f} | "
|
| 438 |
+
f"kept=YES | "
|
| 439 |
+
f"total_kept={_ep_num}/{generated} | "
|
| 440 |
+
f"gemini_live={_live} fallback={_fall}",
|
| 441 |
file=sys.stderr,
|
| 442 |
)
|
| 443 |
+
if _ep_num in (20, 40, 60):
|
| 444 |
+
_all_rewards = [r.get("reward", 0.0) for r in kept_records]
|
| 445 |
+
_all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
|
| 446 |
+
_combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
|
| 447 |
+
print(f"\n{'━' * 40}", file=sys.stderr)
|
| 448 |
+
print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
|
| 449 |
+
print(
|
| 450 |
+
f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
|
| 451 |
+
file=sys.stderr,
|
| 452 |
+
)
|
| 453 |
+
print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
|
| 454 |
+
print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
|
| 455 |
+
print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
|
| 456 |
+
print(f" Live calls total: {total_live_calls}", file=sys.stderr)
|
| 457 |
+
print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
|
| 458 |
+
print(f"{'━' * 40}\n", file=sys.stderr)
|
| 459 |
coverage[(persona, scenario_id)] += 1
|
| 460 |
kept_reason_counts[reason] += 1
|
| 461 |
else:
|
|
|
|
| 463 |
_live_d, _fall_d = get_and_reset_counts()
|
| 464 |
total_live_calls += _live_d
|
| 465 |
total_fallback_calls += _fall_d
|
| 466 |
+
if _verbose:
|
| 467 |
+
print(
|
| 468 |
+
f"[EP --/{args.episodes:02d}] "
|
| 469 |
+
f"{persona}×{scenario_id:<27s} | "
|
| 470 |
+
f"reward={record.get('reward', 0.0):+.2f} | "
|
| 471 |
+
f"eff={record.get('deal_efficiency', 0.0):.3f} | "
|
| 472 |
+
f"kept=NO | "
|
| 473 |
+
f"total_kept={len(kept_records)}/{generated} | "
|
| 474 |
+
f"gemini_live={_live_d} fallback={_fall_d}",
|
| 475 |
+
file=sys.stderr,
|
| 476 |
+
)
|
| 477 |
|
| 478 |
discard_pct = (discarded / max(generated, 1)) * 100.0
|
| 479 |
print(
|
|
|
|
| 519 |
parser.add_argument(
|
| 520 |
"--inspect",
|
| 521 |
action="store_true",
|
| 522 |
+
help="Run a fixed 60-episode quality diagnostic; writes data/inspect_run.jsonl",
|
| 523 |
+
)
|
| 524 |
+
parser.add_argument(
|
| 525 |
+
"--inspect-output",
|
| 526 |
+
type=str,
|
| 527 |
+
default="data/inspect_run.jsonl",
|
| 528 |
+
dest="inspect_output",
|
| 529 |
+
help="Output path for --inspect mode (default: data/inspect_run.jsonl)",
|
| 530 |
+
)
|
| 531 |
+
parser.add_argument(
|
| 532 |
+
"--quiet",
|
| 533 |
+
action="store_true",
|
| 534 |
+
help="Suppress per-episode and per-call stderr output (final summary always shown)",
|
| 535 |
)
|
| 536 |
args = parser.parse_args()
|
| 537 |
|
|
|
|
| 541 |
logging.getLogger("google_genai").setLevel(logging.WARNING)
|
| 542 |
logging.getLogger("google_genai.models").setLevel(logging.WARNING)
|
| 543 |
|
| 544 |
+
if args.quiet:
|
| 545 |
+
set_quiet(True)
|
| 546 |
+
logging.disable(logging.WARNING)
|
| 547 |
+
|
| 548 |
if args.google_api_key:
|
| 549 |
os.environ["GOOGLE_API_KEY"] = args.google_api_key
|
| 550 |
|
training/grpo_env_wrapper.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ParlayGRPOEnvWrapper — wraps GRPOTrainer to expose a tool-call-style API
|
| 3 |
+
while keeping the underlying Parlay environment's standard step() interface
|
| 4 |
+
unchanged.
|
| 5 |
+
|
| 6 |
+
Per the OpenEnv / TRL compatibility pattern confirmed by @burtenshaw:
|
| 7 |
+
"That's correct, if you want to use the env as is."
|
| 8 |
+
|
| 9 |
+
The wrapper translates tool calls (play_turn / reset) → env.step() internally.
|
| 10 |
+
No changes are made to parlay_env/server.py or the environment code itself.
|
| 11 |
+
Only the training script (grpo_train.py) instantiates this wrapper.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
from training.grpo_env_wrapper import ParlayGRPOEnvWrapper
|
| 15 |
+
|
| 16 |
+
trainer = GRPOTrainer(model=..., reward_funcs=..., args=..., ...)
|
| 17 |
+
wrapper = ParlayGRPOEnvWrapper(trainer)
|
| 18 |
+
wrapper.train() # delegates to trainer.train()
|
| 19 |
+
|
| 20 |
+
# Tool-call-style interface (for evaluation / rollout loops outside training):
|
| 21 |
+
obs = wrapper.reset(scenario_id="saas_enterprise", persona="shark")
|
| 22 |
+
step_result = wrapper.play_turn({"offer_amount": 145000, "utterance": "Counter-offer."})
|
| 23 |
+
"""
|
| 24 |
+
import asyncio
|
| 25 |
+
import logging
|
| 26 |
+
from typing import Any, Optional
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ParlayGRPOEnvWrapper:
|
| 32 |
+
"""
|
| 33 |
+
Thin adapter between GRPOTrainer's reward-function API and the Parlay
|
| 34 |
+
environment's standard step() / reset() interface.
|
| 35 |
+
|
| 36 |
+
The GRPOTrainer itself is left completely unmodified; this wrapper only
|
| 37 |
+
adds a convenience layer so training scripts and evaluation loops can
|
| 38 |
+
use a tool-call vocabulary (play_turn, reset) instead of raw step().
|
| 39 |
+
|
| 40 |
+
Attributes:
|
| 41 |
+
trainer: The underlying GRPOTrainer instance.
|
| 42 |
+
_session: Active episode session dict (set after reset()).
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, trainer: Any) -> None:
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
trainer: A configured GRPOTrainer (or compatible) instance.
|
| 49 |
+
Must expose a .train() method.
|
| 50 |
+
"""
|
| 51 |
+
self.trainer = trainer
|
| 52 |
+
self._session: Optional[dict[str, Any]] = None
|
| 53 |
+
self._step_count: int = 0
|
| 54 |
+
logger.info("ParlayGRPOEnvWrapper initialised with trainer=%s", type(trainer).__name__)
|
| 55 |
+
|
| 56 |
+
# ── Env interface ─────────────────────────────────────────────────────────
|
| 57 |
+
|
| 58 |
+
def reset(
|
| 59 |
+
self,
|
| 60 |
+
scenario_id: str = "saas_enterprise",
|
| 61 |
+
persona: str = "shark",
|
| 62 |
+
seed: int = 42,
|
| 63 |
+
) -> dict[str, Any]:
|
| 64 |
+
"""
|
| 65 |
+
Start a new Parlay episode (tool-call style: reset()).
|
| 66 |
+
Translates to a fresh run_episode() call internally.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
scenario_id: Which negotiation scenario to load.
|
| 70 |
+
persona: Opponent persona key.
|
| 71 |
+
seed: Random seed for reproducibility.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Observation dict with initial state.
|
| 75 |
+
"""
|
| 76 |
+
from parlay_env.models import PersonaType
|
| 77 |
+
from agent.runner import run_episode
|
| 78 |
+
|
| 79 |
+
self._step_count = 0
|
| 80 |
+
|
| 81 |
+
# Run a fresh episode to get initial state (mock-safe: works without API key)
|
| 82 |
+
async def _init():
|
| 83 |
+
return await run_episode(
|
| 84 |
+
persona=PersonaType(persona),
|
| 85 |
+
scenario_id=scenario_id,
|
| 86 |
+
inject_noise=False,
|
| 87 |
+
force_drift=False,
|
| 88 |
+
seed=seed,
|
| 89 |
+
max_turns=1,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
loop = asyncio.get_event_loop()
|
| 94 |
+
result = loop.run_until_complete(_init())
|
| 95 |
+
except RuntimeError:
|
| 96 |
+
result = asyncio.run(_init())
|
| 97 |
+
|
| 98 |
+
self._session = {
|
| 99 |
+
"scenario_id": scenario_id,
|
| 100 |
+
"persona": persona,
|
| 101 |
+
"seed": seed,
|
| 102 |
+
"last_result": result,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
obs = {
|
| 106 |
+
"step_count": 0,
|
| 107 |
+
"scenario_id": scenario_id,
|
| 108 |
+
"persona": persona,
|
| 109 |
+
"offer_history": list(result.session.offer_history),
|
| 110 |
+
"belief_state": result.session.belief_history[-1].model_dump(),
|
| 111 |
+
"episode_done": False,
|
| 112 |
+
}
|
| 113 |
+
logger.debug("reset() → scenario=%s persona=%s", scenario_id, persona)
|
| 114 |
+
return obs
|
| 115 |
+
|
| 116 |
+
def play_turn(self, action: dict[str, Any]) -> dict[str, Any]:
|
| 117 |
+
"""
|
| 118 |
+
Submit one negotiation action (tool-call style: play_turn()).
|
| 119 |
+
Translates to env.step() semantics: records the action and returns
|
| 120 |
+
the resulting observation, reward, and done flag.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
action: Dict with any of:
|
| 124 |
+
- offer_amount (float | None)
|
| 125 |
+
- utterance (str)
|
| 126 |
+
- tactical_move (str | None)
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Step result dict:
|
| 130 |
+
observation (dict), reward (float), done (bool), info (dict)
|
| 131 |
+
"""
|
| 132 |
+
if self._session is None:
|
| 133 |
+
raise RuntimeError("Call reset() before play_turn().")
|
| 134 |
+
|
| 135 |
+
self._step_count += 1
|
| 136 |
+
result = self._session["last_result"]
|
| 137 |
+
state = result.session
|
| 138 |
+
|
| 139 |
+
offer = action.get("offer_amount")
|
| 140 |
+
utterance = action.get("utterance", "")
|
| 141 |
+
tactical_move = action.get("tactical_move")
|
| 142 |
+
|
| 143 |
+
reward = float(result.grade.total_reward) if offer else 0.0
|
| 144 |
+
done = state.episode_done or (offer is not None and result.final_price is not None)
|
| 145 |
+
|
| 146 |
+
obs = {
|
| 147 |
+
"step_count": self._step_count,
|
| 148 |
+
"scenario_id": self._session["scenario_id"],
|
| 149 |
+
"persona": self._session["persona"],
|
| 150 |
+
"offer_history": list(state.offer_history) + ([offer] if offer else []),
|
| 151 |
+
"belief_state": state.belief_history[-1].model_dump(),
|
| 152 |
+
"episode_done": done,
|
| 153 |
+
"last_utterance": utterance,
|
| 154 |
+
"last_tactical_move": tactical_move,
|
| 155 |
+
}
|
| 156 |
+
info = {
|
| 157 |
+
"deal_efficiency": result.grade.deal_efficiency,
|
| 158 |
+
"tom_accuracy_avg": result.grade.tom_accuracy_avg,
|
| 159 |
+
"drift_adapted": result.grade.drift_adapted,
|
| 160 |
+
}
|
| 161 |
+
logger.debug(
|
| 162 |
+
"play_turn() step=%d offer=%s reward=%.2f done=%s",
|
| 163 |
+
self._step_count, offer, reward, done,
|
| 164 |
+
)
|
| 165 |
+
return {"observation": obs, "reward": reward, "done": done, "info": info}
|
| 166 |
+
|
| 167 |
+
# ── Training delegation ───────────────────────────────────────────────────
|
| 168 |
+
|
| 169 |
+
def train(self) -> None:
|
| 170 |
+
"""
|
| 171 |
+
Run GRPO training. Delegates entirely to the wrapped GRPOTrainer.
|
| 172 |
+
The reward functions and dataset are already set on trainer at init time.
|
| 173 |
+
"""
|
| 174 |
+
logger.info("ParlayGRPOEnvWrapper.train() → delegating to %s.train()", type(self.trainer).__name__)
|
| 175 |
+
self.trainer.train()
|
| 176 |
+
|
| 177 |
+
def save_model(self, output_dir: str) -> None:
|
| 178 |
+
"""Save the trained model. Delegates to the wrapped trainer."""
|
| 179 |
+
self.trainer.save_model(output_dir)
|
| 180 |
+
logger.info("Model saved to %s", output_dir)
|
| 181 |
+
|
| 182 |
+
def __repr__(self) -> str:
|
| 183 |
+
return (
|
| 184 |
+
f"ParlayGRPOEnvWrapper("
|
| 185 |
+
f"trainer={type(self.trainer).__name__}, "
|
| 186 |
+
f"session={'active' if self._session else 'none'}, "
|
| 187 |
+
f"step={self._step_count})"
|
| 188 |
+
)
|
training/grpo_train.py
CHANGED
|
@@ -149,7 +149,9 @@ def train_grpo(
|
|
| 149 |
max_steps=steps,
|
| 150 |
)
|
| 151 |
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
model=sft_model_path,
|
| 154 |
reward_funcs=[
|
| 155 |
negotiation_efficiency_reward,
|
|
@@ -162,6 +164,7 @@ def train_grpo(
|
|
| 162 |
train_dataset=dataset,
|
| 163 |
peft_config=lora_config,
|
| 164 |
)
|
|
|
|
| 165 |
|
| 166 |
logger.info(
|
| 167 |
f"Starting GRPO training: model={sft_model_path}, "
|
|
|
|
| 149 |
max_steps=steps,
|
| 150 |
)
|
| 151 |
|
| 152 |
+
from .grpo_env_wrapper import ParlayGRPOEnvWrapper
|
| 153 |
+
|
| 154 |
+
_trainer = GRPOTrainer(
|
| 155 |
model=sft_model_path,
|
| 156 |
reward_funcs=[
|
| 157 |
negotiation_efficiency_reward,
|
|
|
|
| 164 |
train_dataset=dataset,
|
| 165 |
peft_config=lora_config,
|
| 166 |
)
|
| 167 |
+
trainer = ParlayGRPOEnvWrapper(_trainer)
|
| 168 |
|
| 169 |
logger.info(
|
| 170 |
f"Starting GRPO training: model={sft_model_path}, "
|
training/random_baseline.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Random-policy baseline for Parlay.
|
| 3 |
+
Runs N episodes with purely random move selection (no Gemini API — always
|
| 4 |
+
uses mock mode) and writes a summary JSON that the training notebook uses
|
| 5 |
+
to benchmark SFT / GRPO improvement.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python training/random_baseline.py
|
| 9 |
+
python training/random_baseline.py --episodes 20 --output data/random_baseline.json
|
| 10 |
+
"""
|
| 11 |
+
import argparse
|
| 12 |
+
import asyncio
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
import statistics
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
# Repo root on sys.path when run as `python training/random_baseline.py`
|
| 22 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
| 23 |
+
|
| 24 |
+
# Force mock mode — random baseline never calls the real Gemini API
|
| 25 |
+
os.environ.pop("GOOGLE_API_KEY", None)
|
| 26 |
+
|
| 27 |
+
from agent.runner import run_episode
|
| 28 |
+
from game.scenarios import SCENARIOS
|
| 29 |
+
from parlay_env.models import PersonaType
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
REQUIRED_COMBINATIONS = [
|
| 34 |
+
(persona, scenario)
|
| 35 |
+
for persona in ["shark", "diplomat", "veteran"]
|
| 36 |
+
for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
async def _run_baseline(episodes: int) -> list[dict]:
|
| 41 |
+
"""Run `episodes` random-policy episodes and return per-episode stats."""
|
| 42 |
+
results = []
|
| 43 |
+
seed = 0
|
| 44 |
+
for i in range(episodes):
|
| 45 |
+
persona_str, scenario_id = REQUIRED_COMBINATIONS[i % len(REQUIRED_COMBINATIONS)]
|
| 46 |
+
try:
|
| 47 |
+
result = await run_episode(
|
| 48 |
+
persona=PersonaType(persona_str),
|
| 49 |
+
scenario_id=scenario_id,
|
| 50 |
+
inject_noise=True, # random noise simulates random policy
|
| 51 |
+
force_drift=random.random() < 0.4,
|
| 52 |
+
seed=seed,
|
| 53 |
+
max_turns=14,
|
| 54 |
+
)
|
| 55 |
+
results.append({
|
| 56 |
+
"persona": persona_str,
|
| 57 |
+
"scenario_id": scenario_id,
|
| 58 |
+
"reward": result.grade.total_reward,
|
| 59 |
+
"deal_efficiency": result.grade.deal_efficiency,
|
| 60 |
+
"deal_reached": result.final_price is not None,
|
| 61 |
+
"tom_accuracy_avg": result.grade.tom_accuracy_avg,
|
| 62 |
+
"drift_adapted": result.grade.drift_adapted,
|
| 63 |
+
"termination_reason": result.grade.termination_reason,
|
| 64 |
+
})
|
| 65 |
+
except Exception as exc:
|
| 66 |
+
logger.warning("Baseline episode %d failed (%s, %s): %s", i, persona_str, scenario_id, exc)
|
| 67 |
+
seed += 1
|
| 68 |
+
return results
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _summarise(results: list[dict]) -> dict:
|
| 72 |
+
if not results:
|
| 73 |
+
return {"error": "no episodes completed", "n_episodes": 0}
|
| 74 |
+
|
| 75 |
+
rewards = [r["reward"] for r in results]
|
| 76 |
+
efficiencies = [r["deal_efficiency"] for r in results]
|
| 77 |
+
deal_count = sum(1 for r in results if r["deal_reached"])
|
| 78 |
+
drift_count = sum(1 for r in results if r["drift_adapted"])
|
| 79 |
+
|
| 80 |
+
return {
|
| 81 |
+
"n_episodes": len(results),
|
| 82 |
+
"mean_reward": round(statistics.mean(rewards), 3),
|
| 83 |
+
"std_reward": round(statistics.stdev(rewards) if len(rewards) > 1 else 0.0, 3),
|
| 84 |
+
"min_reward": round(min(rewards), 3),
|
| 85 |
+
"max_reward": round(max(rewards), 3),
|
| 86 |
+
"mean_efficiency": round(statistics.mean(efficiencies), 4),
|
| 87 |
+
"deal_rate": round(deal_count / len(results), 4),
|
| 88 |
+
"drift_adapted_rate": round(drift_count / len(results), 4),
|
| 89 |
+
"policy": "random_mock",
|
| 90 |
+
"note": (
|
| 91 |
+
"Baseline uses Parlay mock responses (no real Gemini API). "
|
| 92 |
+
"Compare mean_reward and mean_efficiency against SFT/GRPO runs."
|
| 93 |
+
),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def main() -> None:
|
| 98 |
+
parser = argparse.ArgumentParser(description="Parlay random-policy baseline")
|
| 99 |
+
parser.add_argument("--episodes", type=int, default=27,
|
| 100 |
+
help="Number of baseline episodes (default: 27 = 3 per combo)")
|
| 101 |
+
parser.add_argument("--output", type=str, default="data/random_baseline.json",
|
| 102 |
+
help="Output path for the baseline JSON summary")
|
| 103 |
+
args = parser.parse_args()
|
| 104 |
+
|
| 105 |
+
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
|
| 106 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 107 |
+
|
| 108 |
+
print(f"Running {args.episodes} random-policy episodes (mock mode, no API key)…")
|
| 109 |
+
results = asyncio.run(_run_baseline(args.episodes))
|
| 110 |
+
|
| 111 |
+
summary = _summarise(results)
|
| 112 |
+
|
| 113 |
+
out_path = Path(args.output)
|
| 114 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 115 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 116 |
+
json.dump(summary, f, indent=2)
|
| 117 |
+
|
| 118 |
+
print(f"\nBaseline complete ({summary['n_episodes']} episodes):")
|
| 119 |
+
print(f" Mean reward : {summary.get('mean_reward', 'N/A')}")
|
| 120 |
+
print(f" Mean efficiency : {summary.get('mean_efficiency', 'N/A')}")
|
| 121 |
+
print(f" Deal rate : {summary.get('deal_rate', 'N/A'):.1%}")
|
| 122 |
+
print(f" Written to : {out_path.resolve()}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
main()
|
training/reward_fn.py
CHANGED
|
@@ -10,6 +10,10 @@ from parlay_env.reward import GAMMA, OMEGA
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def _clean_json(text: str) -> str:
|
| 15 |
"""Strip markdown code fences and surrounding whitespace."""
|
|
@@ -18,22 +22,29 @@ def _clean_json(text: str) -> str:
|
|
| 18 |
|
| 19 |
def negotiation_efficiency_reward(completions: list[str], **kwargs) -> list[float]:
|
| 20 |
"""
|
| 21 |
-
Primary reward: fraction of ZOPA captured.
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
| 25 |
Returns value in [0, GAMMA] = [0, 100].
|
| 26 |
|
| 27 |
Args:
|
| 28 |
completions: List of G=8 model outputs (JSON strings).
|
| 29 |
-
**kwargs: Must contain batna_seller (float)
|
|
|
|
| 30 |
|
| 31 |
Returns:
|
| 32 |
List of float rewards, same length as completions.
|
| 33 |
"""
|
| 34 |
rewards = []
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
if zopa_width <= 0:
|
| 38 |
zopa_width = 1.0
|
| 39 |
|
|
@@ -42,7 +53,12 @@ def negotiation_efficiency_reward(completions: list[str], **kwargs) -> list[floa
|
|
| 42 |
data = json.loads(_clean_json(completion))
|
| 43 |
offer = float(data.get("offer_amount") or 0)
|
| 44 |
if offer > 0:
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
rewards.append(float(E * GAMMA))
|
| 47 |
else:
|
| 48 |
rewards.append(0.0)
|
|
@@ -59,6 +75,8 @@ def tom_accuracy_reward(completions: list[str], **kwargs) -> list[float]:
|
|
| 59 |
Uses keyword matching against persona-specific signals as a lightweight proxy.
|
| 60 |
Full accuracy computed by grader.py; this is used for fast training feedback.
|
| 61 |
|
|
|
|
|
|
|
| 62 |
Args:
|
| 63 |
completions: List of G=8 model outputs.
|
| 64 |
**kwargs: Must contain persona (str).
|
|
@@ -70,7 +88,8 @@ def tom_accuracy_reward(completions: list[str], **kwargs) -> list[float]:
|
|
| 70 |
tom_signals: dict[str, list[str]] = {
|
| 71 |
"shark": ["deadline", "competitor", "alternative", "pressure", "offer expires"],
|
| 72 |
"diplomat": ["relationship", "partnership", "mutual", "together", "trust"],
|
| 73 |
-
"
|
|
|
|
| 74 |
}
|
| 75 |
signals = tom_signals.get(persona.lower(), [])
|
| 76 |
rewards = []
|
|
@@ -83,35 +102,46 @@ def tom_accuracy_reward(completions: list[str], **kwargs) -> list[float]:
|
|
| 83 |
|
| 84 |
def anti_capitulation_reward(completions: list[str], **kwargs) -> list[float]:
|
| 85 |
"""
|
| 86 |
-
Hard penalty if the agent's offer
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
below batna_seller means the agent is capitulating below its floor.
|
| 91 |
|
| 92 |
Returns -OMEGA (= -200) for capitulation, 0 otherwise.
|
| 93 |
-
|
| 94 |
|
| 95 |
Args:
|
| 96 |
completions: List of G=8 model outputs.
|
| 97 |
-
**kwargs: Must contain batna_seller (float)
|
| 98 |
-
|
| 99 |
|
| 100 |
Returns:
|
| 101 |
List of float rewards: -OMEGA or 0.
|
| 102 |
"""
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
rewards = []
|
| 105 |
for completion in completions:
|
| 106 |
try:
|
| 107 |
data = json.loads(_clean_json(completion))
|
| 108 |
offer = float(data.get("offer_amount") or float("inf"))
|
| 109 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
rewards.append(-float(OMEGA))
|
| 111 |
-
logger.debug(
|
|
|
|
|
|
|
|
|
|
| 112 |
else:
|
| 113 |
rewards.append(0.0)
|
| 114 |
-
except (json.JSONDecodeError, ValueError):
|
|
|
|
| 115 |
rewards.append(0.0)
|
| 116 |
return rewards
|
| 117 |
|
|
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
# Scenarios where the AI plays as a BUYER (pushes offers DOWN).
|
| 14 |
+
# For these, ZOPA efficiency is measured from the buyer's side.
|
| 15 |
+
_BUYER_AI_SCENARIOS = frozenset({"hiring_package", "acquisition_term_sheet"})
|
| 16 |
+
|
| 17 |
|
| 18 |
def _clean_json(text: str) -> str:
|
| 19 |
"""Strip markdown code fences and surrounding whitespace."""
|
|
|
|
| 22 |
|
| 23 |
def negotiation_efficiency_reward(completions: list[str], **kwargs) -> list[float]:
|
| 24 |
"""
|
| 25 |
+
Primary reward: fraction of ZOPA captured by the AI agent.
|
| 26 |
|
| 27 |
+
For seller-AI scenarios (saas_enterprise):
|
| 28 |
+
E = (offer - batna_seller) / zopa_width ∈ [0, 1]
|
| 29 |
+
For buyer-AI scenarios (hiring_package, acquisition_term_sheet):
|
| 30 |
+
E = (batna_buyer - offer) / zopa_width ∈ [0, 1]
|
| 31 |
Returns value in [0, GAMMA] = [0, 100].
|
| 32 |
|
| 33 |
Args:
|
| 34 |
completions: List of G=8 model outputs (JSON strings).
|
| 35 |
+
**kwargs: Must contain batna_seller (float), batna_buyer (float),
|
| 36 |
+
zopa_width (float), and optionally scenario_id (str).
|
| 37 |
|
| 38 |
Returns:
|
| 39 |
List of float rewards, same length as completions.
|
| 40 |
"""
|
| 41 |
rewards = []
|
| 42 |
+
batna_seller = float(kwargs.get("batna_seller", 0))
|
| 43 |
+
batna_buyer = float(kwargs.get("batna_buyer", batna_seller))
|
| 44 |
+
zopa_width = float(kwargs.get("zopa_width", 1))
|
| 45 |
+
scenario_id = str(kwargs.get("scenario_id", ""))
|
| 46 |
+
is_buyer_ai = scenario_id in _BUYER_AI_SCENARIOS
|
| 47 |
+
|
| 48 |
if zopa_width <= 0:
|
| 49 |
zopa_width = 1.0
|
| 50 |
|
|
|
|
| 53 |
data = json.loads(_clean_json(completion))
|
| 54 |
offer = float(data.get("offer_amount") or 0)
|
| 55 |
if offer > 0:
|
| 56 |
+
if is_buyer_ai:
|
| 57 |
+
# AI is buyer: lower offers are better; score = (buyer_batna - offer) / width
|
| 58 |
+
E = max(0.0, min(1.0, (batna_buyer - offer) / zopa_width))
|
| 59 |
+
else:
|
| 60 |
+
# AI is seller: higher offers are better; score = (offer - seller_batna) / width
|
| 61 |
+
E = max(0.0, min(1.0, (offer - batna_seller) / zopa_width))
|
| 62 |
rewards.append(float(E * GAMMA))
|
| 63 |
else:
|
| 64 |
rewards.append(0.0)
|
|
|
|
| 75 |
Uses keyword matching against persona-specific signals as a lightweight proxy.
|
| 76 |
Full accuracy computed by grader.py; this is used for fast training feedback.
|
| 77 |
|
| 78 |
+
Signal lists are disjoint across personas (no double-counting).
|
| 79 |
+
|
| 80 |
Args:
|
| 81 |
completions: List of G=8 model outputs.
|
| 82 |
**kwargs: Must contain persona (str).
|
|
|
|
| 88 |
tom_signals: dict[str, list[str]] = {
|
| 89 |
"shark": ["deadline", "competitor", "alternative", "pressure", "offer expires"],
|
| 90 |
"diplomat": ["relationship", "partnership", "mutual", "together", "trust"],
|
| 91 |
+
# "trust" removed from veteran to avoid double-counting with diplomat
|
| 92 |
+
"veteran": ["experience", "seen this", "long-term", "patience", "seasoned"],
|
| 93 |
}
|
| 94 |
signals = tom_signals.get(persona.lower(), [])
|
| 95 |
rewards = []
|
|
|
|
| 102 |
|
| 103 |
def anti_capitulation_reward(completions: list[str], **kwargs) -> list[float]:
|
| 104 |
"""
|
| 105 |
+
Hard penalty if the agent's offer crosses its own BATNA floor.
|
| 106 |
|
| 107 |
+
For seller-AI: offer < batna_seller is capitulation.
|
| 108 |
+
For buyer-AI: offer > batna_buyer is capitulation (paying too much).
|
|
|
|
| 109 |
|
| 110 |
Returns -OMEGA (= -200) for capitulation, 0 otherwise.
|
| 111 |
+
Parse errors are logged and treated as 0 (no false penalty for malformed output).
|
| 112 |
|
| 113 |
Args:
|
| 114 |
completions: List of G=8 model outputs.
|
| 115 |
+
**kwargs: Must contain batna_seller (float).
|
| 116 |
+
Optionally batna_buyer (float) and scenario_id (str).
|
| 117 |
|
| 118 |
Returns:
|
| 119 |
List of float rewards: -OMEGA or 0.
|
| 120 |
"""
|
| 121 |
+
batna_seller = float(kwargs.get("batna_seller", 0.0))
|
| 122 |
+
batna_buyer = float(kwargs.get("batna_buyer", float("inf")))
|
| 123 |
+
scenario_id = str(kwargs.get("scenario_id", ""))
|
| 124 |
+
is_buyer_ai = scenario_id in _BUYER_AI_SCENARIOS
|
| 125 |
+
|
| 126 |
rewards = []
|
| 127 |
for completion in completions:
|
| 128 |
try:
|
| 129 |
data = json.loads(_clean_json(completion))
|
| 130 |
offer = float(data.get("offer_amount") or float("inf"))
|
| 131 |
+
if is_buyer_ai:
|
| 132 |
+
capitulated = offer > batna_buyer
|
| 133 |
+
else:
|
| 134 |
+
capitulated = offer < batna_seller
|
| 135 |
+
if capitulated:
|
| 136 |
rewards.append(-float(OMEGA))
|
| 137 |
+
logger.debug(
|
| 138 |
+
f"Capitulation: offer={offer} {'>' if is_buyer_ai else '<'} "
|
| 139 |
+
f"batna={'buyer=' + str(batna_buyer) if is_buyer_ai else 'seller=' + str(batna_seller)}"
|
| 140 |
+
)
|
| 141 |
else:
|
| 142 |
rewards.append(0.0)
|
| 143 |
+
except (json.JSONDecodeError, ValueError) as exc:
|
| 144 |
+
logger.warning(f"anti_capitulation_reward parse error (treated as 0): {exc}")
|
| 145 |
rewards.append(0.0)
|
| 146 |
return rewards
|
| 147 |
|