sh4shv4t commited on
Commit
15976d0
·
1 Parent(s): f3d2cd4

feat: flash-lite for data-gen and flash for UI; remove training page; card tests; --quiet data gen; data/ inspect path; random baseline; GRPO env wrapper; reward fixes (buyer ZOPA, ToM signals); drift + Brier metrics; Bayesian ToM module

Browse files
agent/gemini_client.py CHANGED
@@ -5,9 +5,11 @@ All errors return SYNTHETIC_RESPONSE.
5
  When GOOGLE_API_KEY is absent, MOCK_RESPONSES are returned so the full game
6
  loop works without any API key.
7
 
8
- Primary model for API calls (dashboard, MCP, data generation, ToM):
9
- - GEMINI_MODEL gemini-2.5-flash-lite for bulk/self-play and live callers
10
- that pass this id via MODEL_ID_DEMO / MODEL_ID_DATA aliases.
 
 
11
  """
12
  import asyncio
13
  import json
@@ -54,15 +56,21 @@ SCENARIO_ROLE_CONTEXT: dict[str, dict[str, str]] = {
54
  },
55
  }
56
 
57
- GEMINI_MODEL = "gemini-2.5-flash-lite"
58
- # Aliases for imports (dashboard, MCP, training all use flash-lite)
59
- MODEL_ID_DEMO = GEMINI_MODEL
60
- MODEL_ID_DATA = GEMINI_MODEL
61
- MODEL_ID = GEMINI_MODEL
62
 
63
  _client = None
64
  _mock_warned: bool = False
65
  _gemini_model_logged: bool = False
 
 
 
 
 
 
 
66
 
67
  # ── Mock responses (keyless dev / CI) ────────────────────────────────────────
68
  # Offer amounts are realistic for the default SaaS enterprise scenario
@@ -336,10 +344,11 @@ async def call_gemini(
336
 
337
  _turn_count += 1
338
  _live_calls += 1
339
- print(
340
- f"[Gemini LIVE] model={mid} chars={len(response.text or '')} turn={_turn_count}",
341
- file=sys.stderr,
342
- )
 
343
 
344
  text = (response.text or "").strip()
345
  text = text.replace("```json", "").replace("```", "").strip()
 
5
  When GOOGLE_API_KEY is absent, MOCK_RESPONSES are returned so the full game
6
  loop works without any API key.
7
 
8
+ Model routing:
9
+ - MODEL_ID_DATA (gemini-2.5-flash-lite) data generation, self-play, ToM inference.
10
+ Low-latency, high-throughput; used by runner.py and generate_data.py.
11
+ - MODEL_ID_DEMO (gemini-2.5-flash) — web UI, dashboard API, MCP tools.
12
+ Higher quality responses for live user interaction.
13
  """
14
  import asyncio
15
  import json
 
56
  },
57
  }
58
 
59
+ GEMINI_MODEL = "gemini-2.5-flash-lite" # kept for backward compat; equals MODEL_ID_DATA
60
+ MODEL_ID_DATA = "gemini-2.5-flash-lite" # data generation, self-play, ToM inference
61
+ MODEL_ID_DEMO = "gemini-2.5-flash" # web UI, dashboard API, MCP tools
62
+ MODEL_ID = MODEL_ID_DATA # stable alias (runner.py omits model= → flash-lite)
 
63
 
64
  _client = None
65
  _mock_warned: bool = False
66
  _gemini_model_logged: bool = False
67
+ _quiet: bool = False # suppresses per-call [Gemini LIVE] prints when True
68
+
69
+
70
+ def set_quiet(flag: bool) -> None:
71
+ """Suppress [Gemini LIVE] per-call stderr prints (e.g. during test runs)."""
72
+ global _quiet
73
+ _quiet = flag
74
 
75
  # ── Mock responses (keyless dev / CI) ────────────────────────────────────────
76
  # Offer amounts are realistic for the default SaaS enterprise scenario
 
344
 
345
  _turn_count += 1
346
  _live_calls += 1
347
+ if not _quiet:
348
+ print(
349
+ f"[Gemini LIVE] model={mid} chars={len(response.text or '')} turn={_turn_count}",
350
+ file=sys.stderr,
351
+ )
352
 
353
  text = (response.text or "").strip()
354
  text = text.replace("```json", "").replace("```", "").strip()
agent/runner.py CHANGED
@@ -131,8 +131,12 @@ async def run_episode(
131
  for event in scenario.drift_events:
132
  if event.trigger_turn == turn or (forced_drift_turn == turn):
133
  drift_turn = turn
134
- tom.drift_event(event.effect_on_urgency, event.effect_on_has_alternative)
135
- logger.info(f"Drift event at turn {turn}: {event.event}")
 
 
 
 
136
  break
137
 
138
  if inject_noise and turn < 3 and rng.random() < 0.3:
@@ -211,8 +215,16 @@ async def run_episode(
211
 
212
  if drift_turn is not None and not drift_adapted and turn <= drift_turn + 2:
213
  adaptation_signals = ["understand", "noted", "given that", "considering"]
214
- if any(s in action.utterance.lower() for s in adaptation_signals):
 
 
 
215
  drift_adapted = True
 
 
 
 
 
216
 
217
  new_offers = list(state.offer_history)
218
  if action.offer_amount:
 
131
  for event in scenario.drift_events:
132
  if event.trigger_turn == turn or (forced_drift_turn == turn):
133
  drift_turn = turn
134
+ tom.drift_event(
135
+ event.effect_on_urgency,
136
+ event.effect_on_has_alternative,
137
+ event_description=event.event,
138
+ )
139
+ logger.info(f"Drift event at turn {turn}: {event.event!r}")
140
  break
141
 
142
  if inject_noise and turn < 3 and rng.random() < 0.3:
 
215
 
216
  if drift_turn is not None and not drift_adapted and turn <= drift_turn + 2:
217
  adaptation_signals = ["understand", "noted", "given that", "considering"]
218
+ matched = next(
219
+ (s for s in adaptation_signals if s in action.utterance.lower()), None
220
+ )
221
+ if matched:
222
  drift_adapted = True
223
+ logger.info(
224
+ f"drift_adapted=True at turn={turn} "
225
+ f"matched_phrase={matched!r} "
226
+ f"utterance_snippet={action.utterance[:80]!r}"
227
+ )
228
 
229
  new_offers = list(state.offer_history)
230
  if action.offer_amount:
agent/tom_tracker.py CHANGED
@@ -129,6 +129,7 @@ class ToMTracker:
129
  self,
130
  effect_on_urgency: float,
131
  effect_on_has_alternative: bool,
 
132
  ) -> BeliefState:
133
  """
134
  Apply a drift event to beliefs.
@@ -136,6 +137,8 @@ class ToMTracker:
136
  Args:
137
  effect_on_urgency: Signed delta to urgency estimate.
138
  effect_on_has_alternative: Override for has_alternative belief.
 
 
139
 
140
  Returns:
141
  Updated BeliefState post-drift.
@@ -150,12 +153,49 @@ class ToMTracker:
150
  confidence=max(0.0, last.confidence - 0.15), # drift reduces confidence
151
  )
152
  self.history.append(updated)
 
153
  logger.info(
154
- f"ToM drift applied: urgency={new_urgency:.2f}, "
 
155
  f"alt={effect_on_has_alternative}"
156
  )
157
  return updated
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def accuracy_against(self, hidden: HiddenState) -> float:
160
  """
161
  Compute current belief accuracy against true hidden state.
 
129
  self,
130
  effect_on_urgency: float,
131
  effect_on_has_alternative: bool,
132
+ event_description: str = "",
133
  ) -> BeliefState:
134
  """
135
  Apply a drift event to beliefs.
 
137
  Args:
138
  effect_on_urgency: Signed delta to urgency estimate.
139
  effect_on_has_alternative: Override for has_alternative belief.
140
+ event_description: Human-readable scenario event string
141
+ (e.g. "Competitor drops price 15%").
142
 
143
  Returns:
144
  Updated BeliefState post-drift.
 
153
  confidence=max(0.0, last.confidence - 0.15), # drift reduces confidence
154
  )
155
  self.history.append(updated)
156
+ desc_part = f" | event={event_description!r}" if event_description else ""
157
  logger.info(
158
+ f"ToM drift applied{desc_part}: "
159
+ f"urgency_delta={effect_on_urgency:+.2f} → {new_urgency:.2f}, "
160
  f"alt={effect_on_has_alternative}"
161
  )
162
  return updated
163
 
164
+ def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
165
+ """
166
+ Compute per-field Brier scores over the full belief history.
167
+
168
+ Brier score = (1/N) Σ (predicted - actual)²
169
+ Lower is better; 0 = perfect.
170
+
171
+ Fields scored:
172
+ - urgency: est_urgency (continuous 0–1) vs hidden.urgency_score
173
+ - has_alt: est_has_alternative (0/1 probability) vs hidden.has_alternative
174
+
175
+ Args:
176
+ hidden: The true hidden state revealed at episode end.
177
+
178
+ Returns:
179
+ Dict with keys "urgency" and "has_alt", each a float in [0, 1].
180
+ """
181
+ if not self.history:
182
+ return {"urgency": 1.0, "has_alt": 1.0}
183
+
184
+ actual_urgency = hidden.urgency_score
185
+ actual_alt = float(hidden.has_alternative)
186
+
187
+ urgency_sq_err = sum(
188
+ (b.est_urgency - actual_urgency) ** 2 for b in self.history
189
+ )
190
+ alt_sq_err = sum(
191
+ (float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history
192
+ )
193
+ n = len(self.history)
194
+ return {
195
+ "urgency": round(urgency_sq_err / n, 6),
196
+ "has_alt": round(alt_sq_err / n, 6),
197
+ }
198
+
199
  def accuracy_against(self, hidden: HiddenState) -> float:
200
  """
201
  Compute current belief accuracy against true hidden state.
agent/tom_tracker_bayesian.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Bayesian Theory-of-Mind belief tracker for Parlay.
3
+
4
+ Drop-in replacement for ToMTracker that uses Kalman-filter-style Gaussian
5
+ belief updates instead of hand-tuned arithmetic nudges.
6
+
7
+ Key insight
8
+ -----------
9
+ The opponent has hidden variables (budget_ceiling, walk_away_price, urgency,
10
+ has_alternative). Each observed offer is a noisy signal about these.
11
+ We model each continuous variable as a Gaussian (mean, variance) and update
12
+ using the standard Bayesian update for Gaussian conjugate priors:
13
+
14
+ posterior_mean = (prior_mean / prior_var + obs / obs_var) /
15
+ (1 / prior_var + 1 / obs_var)
16
+ posterior_var = 1 / (1 / prior_var + 1 / obs_var)
17
+
18
+ `confidence` is derived from the posterior variance:
19
+ confidence = 1 / (1 + sqrt(budget_var / budget_mean²))
20
+
21
+ Usage (as feature-flag alternative to ToMTracker):
22
+ from agent.tom_tracker_bayesian import BayesianToMTracker as ToMTracker
23
+
24
+ # Then use exactly the same API as ToMTracker — all method signatures match.
25
+ """
26
+ import logging
27
+ import math
28
+ import sys
29
+ from typing import Optional
30
+
31
+ from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMove
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class BayesianToMTracker:
37
+ """
38
+ Gaussian-posterior belief tracker for the opponent's hidden state.
39
+
40
+ Extends the original ToMTracker API with proper Bayesian updating.
41
+ The same public methods (update, drift_event, accuracy_against,
42
+ brier_scores, log_belief_snapshot) are preserved for drop-in use.
43
+
44
+ Internal state:
45
+ _budget_mean, _budget_var — Gaussian over opponent's budget ceiling.
46
+ _walk_mean, _walk_var — Gaussian over opponent's walk-away price.
47
+ _urgency_mean, _urgency_var — Gaussian over urgency [0, 1].
48
+ _alt_prob — Bernoulli probability of has_alternative.
49
+ """
50
+
51
+ # Observation noise variances (tuned for B2B negotiation scale).
52
+ # Budget/walk-away: observed offer is a noisy signal; high variance because
53
+ # opponents rarely reveal their true limits.
54
+ _OBS_BUDGET_VAR_FRAC = 0.10 # 10% of current mean estimate as std
55
+ _OBS_URGENCY_VAR = 0.05 # small update per offer-ratio signal
56
+
57
+ def __init__(
58
+ self,
59
+ initial_belief: BeliefState,
60
+ persona: PersonaType,
61
+ ) -> None:
62
+ """
63
+ Args:
64
+ initial_belief: Starting BeliefState (imprecise prior).
65
+ persona: Opponent persona (known to the player).
66
+ """
67
+ self.persona = persona
68
+ self._bluffs_detected: int = 0
69
+
70
+ # Initialise Gaussian priors from the initial belief
71
+ self._budget_mean = float(initial_belief.est_budget)
72
+ self._walk_mean = float(initial_belief.est_walk_away)
73
+ self._urgency_mean = float(initial_belief.est_urgency)
74
+ self._alt_prob = 0.3 # prior: 30% chance opponent has an alternative
75
+
76
+ # Initial variances — large uncertainty at the start
77
+ self._budget_var = (self._budget_mean * 0.30) ** 2 # ±30% std
78
+ self._walk_var = (self._walk_mean * 0.30) ** 2
79
+ self._urgency_var = 0.08 # std ≈ 0.28 over [0, 1]
80
+
81
+ self.history: list[BeliefState] = [self._snapshot()]
82
+ logger.debug(
83
+ "BayesianToMTracker init: budget_mean=%.0f walk_mean=%.0f urgency_mean=%.2f",
84
+ self._budget_mean, self._walk_mean, self._urgency_mean,
85
+ )
86
+
87
+ # ── Internal helpers ──────────────────────────────────────────────────────
88
+
89
+ def _snapshot(self) -> BeliefState:
90
+ """Convert current Gaussian state to a BeliefState snapshot."""
91
+ confidence = self._compute_confidence()
92
+ return BeliefState(
93
+ est_budget=round(self._budget_mean, 2),
94
+ est_walk_away=round(self._walk_mean, 2),
95
+ est_urgency=round(max(0.0, min(1.0, self._urgency_mean)), 4),
96
+ est_has_alternative=self._alt_prob >= 0.5,
97
+ confidence=round(confidence, 4),
98
+ )
99
+
100
+ def _compute_confidence(self) -> float:
101
+ """
102
+ Confidence = 1 - mean relative std across all variables.
103
+ Shrinks variance → higher confidence.
104
+ """
105
+ budget_rel_std = math.sqrt(self._budget_var) / max(abs(self._budget_mean), 1.0)
106
+ walk_rel_std = math.sqrt(self._walk_var) / max(abs(self._walk_mean), 1.0)
107
+ urgency_std = math.sqrt(self._urgency_var)
108
+ alt_std = math.sqrt(self._alt_prob * (1.0 - self._alt_prob))
109
+ mean_uncertainty = (budget_rel_std + walk_rel_std + urgency_std + alt_std) / 4.0
110
+ return max(0.0, min(1.0, 1.0 - mean_uncertainty))
111
+
112
+ @staticmethod
113
+ def _gaussian_update(
114
+ prior_mean: float,
115
+ prior_var: float,
116
+ obs: float,
117
+ obs_var: float,
118
+ ) -> tuple[float, float]:
119
+ """
120
+ Closed-form Bayesian update for Gaussian conjugate prior.
121
+
122
+ posterior_mean = (prior_mean / prior_var + obs / obs_var) /
123
+ (1 / prior_var + 1 / obs_var)
124
+ posterior_var = 1 / (1 / prior_var + 1 / obs_var)
125
+ """
126
+ prec_prior = 1.0 / max(prior_var, 1e-10)
127
+ prec_obs = 1.0 / max(obs_var, 1e-10)
128
+ posterior_prec = prec_prior + prec_obs
129
+ posterior_mean = (prec_prior * prior_mean + prec_obs * obs) / posterior_prec
130
+ posterior_var = 1.0 / posterior_prec
131
+ return posterior_mean, posterior_var
132
+
133
+ # ── Public API (matches ToMTracker) ──────────────────────────────────────
134
+
135
+ @property
136
+ def current_belief(self) -> BeliefState:
137
+ return self.history[-1]
138
+
139
+ @property
140
+ def bluffs_detected(self) -> int:
141
+ return self._bluffs_detected
142
+
143
+ def log_belief_snapshot(self, turn: int) -> None:
144
+ b = self.current_belief
145
+ print(
146
+ f"[BayesToM turn={turn}] "
147
+ f"budget={b.est_budget:.0f}±{math.sqrt(self._budget_var):.0f} "
148
+ f"urgency={b.est_urgency:.3f}±{math.sqrt(self._urgency_var):.3f} "
149
+ f"alt_prob={self._alt_prob:.2f} conf={b.confidence:.2f}",
150
+ file=sys.stderr,
151
+ )
152
+
153
+ def update(
154
+ self,
155
+ observed_offer: Optional[float],
156
+ observed_move: Optional[TacticalMove],
157
+ utterance: str,
158
+ turn: int,
159
+ ) -> BeliefState:
160
+ """
161
+ Bayesian update of all beliefs from one observed opponent action.
162
+
163
+ Budget update: if we see an offer O, the true budget is likely > O.
164
+ We treat O as a lower-bound signal: observation = O * 1.05
165
+ with variance proportional to the current mean.
166
+ Urgency update: offer-ratio below 0.85 → urgency signal 0.7;
167
+ above 0.95 → urgency signal 0.3. Both with moderate obs variance.
168
+ has_alternative: updated as Bernoulli likelihood ratio (keyword match).
169
+ """
170
+ # ── Budget Bayesian update ──────────────────────────────────────────
171
+ if observed_offer is not None and observed_offer > 0:
172
+ budget_obs = observed_offer * 1.05
173
+ obs_budget_var = (self._budget_mean * self._OBS_BUDGET_VAR_FRAC) ** 2
174
+ self._budget_mean, self._budget_var = self._gaussian_update(
175
+ self._budget_mean, self._budget_var,
176
+ budget_obs, obs_budget_var,
177
+ )
178
+ logger.debug(
179
+ "Bayesian budget update: obs=%.0f → mean=%.0f std=%.0f",
180
+ budget_obs, self._budget_mean, math.sqrt(self._budget_var),
181
+ )
182
+
183
+ # ── Walk-away update: BATNA_REVEAL is a noisy signal ───────────────
184
+ if observed_move == TacticalMove.BATNA_REVEAL:
185
+ if observed_offer is not None:
186
+ walk_obs = observed_offer * 0.95
187
+ obs_walk_var = (self._walk_mean * 0.15) ** 2
188
+ self._walk_mean, self._walk_var = self._gaussian_update(
189
+ self._walk_mean, self._walk_var,
190
+ walk_obs, obs_walk_var,
191
+ )
192
+ logger.debug("Bayesian walk-away update via BATNA_REVEAL")
193
+
194
+ # ── Urgency Bayesian update via offer-ratio signal ─────────────────
195
+ if observed_offer is not None and self._budget_mean > 0:
196
+ offer_ratio = observed_offer / self._budget_mean
197
+ if offer_ratio < 0.85:
198
+ urgency_obs = 0.70 # low offer → opponent likely more urgent
199
+ elif offer_ratio > 0.95:
200
+ urgency_obs = 0.30 # high offer → opponent comfortable
201
+ else:
202
+ urgency_obs = 0.50 # neutral
203
+ self._urgency_mean, self._urgency_var = self._gaussian_update(
204
+ self._urgency_mean, self._urgency_var,
205
+ urgency_obs, self._OBS_URGENCY_VAR,
206
+ )
207
+ self._urgency_mean = max(0.0, min(1.0, self._urgency_mean))
208
+
209
+ # ── has_alternative Bernoulli update (likelihood ratio) ────────────
210
+ alt_signals = ["competitor", "alternative", "other offer", "another bid"]
211
+ if any(sig in utterance.lower() for sig in alt_signals):
212
+ self._alt_prob = min(0.95, self._alt_prob + (1.0 - self._alt_prob) * 0.35)
213
+ logger.debug("Alternative signal detected → alt_prob=%.2f", self._alt_prob)
214
+ else:
215
+ self._alt_prob = max(0.05, self._alt_prob * 0.98) # small decay
216
+
217
+ # ── Bluff detection (shark persona + BATNA_REVEAL + "competitor") ──
218
+ if (
219
+ self.persona == PersonaType.SHARK
220
+ and observed_move == TacticalMove.BATNA_REVEAL
221
+ and "competitor" in utterance.lower()
222
+ ):
223
+ self._bluffs_detected += 1
224
+ logger.info("BayesToM: bluff detected (total: %d)", self._bluffs_detected)
225
+
226
+ updated = self._snapshot()
227
+ self.history.append(updated)
228
+ logger.debug(
229
+ "BayesToM update turn=%d: budget=%.0f walk=%.0f urgency=%.2f alt_prob=%.2f conf=%.2f",
230
+ turn, self._budget_mean, self._walk_mean, self._urgency_mean,
231
+ self._alt_prob, updated.confidence,
232
+ )
233
+ return updated
234
+
235
+ def drift_event(
236
+ self,
237
+ effect_on_urgency: float,
238
+ effect_on_has_alternative: bool,
239
+ event_description: str = "",
240
+ ) -> BeliefState:
241
+ """
242
+ Apply a market/scenario drift event.
243
+
244
+ Nudges the urgency mean and resets alt_prob based on the drift direction.
245
+ Also inflates all variances (drift = increased uncertainty).
246
+ """
247
+ self._urgency_mean = float(max(0.0, min(1.0, self._urgency_mean + effect_on_urgency)))
248
+ self._urgency_var = min(0.1, self._urgency_var * 1.5) # inflate uncertainty
249
+
250
+ # Drift shifts alt belief
251
+ if effect_on_has_alternative:
252
+ self._alt_prob = min(0.9, self._alt_prob + 0.25)
253
+ else:
254
+ self._alt_prob = max(0.1, self._alt_prob - 0.1)
255
+
256
+ # Inflate budget/walk variances — drift reduces confidence
257
+ self._budget_var *= 1.3
258
+ self._walk_var *= 1.3
259
+
260
+ updated = self._snapshot()
261
+ self.history.append(updated)
262
+ desc_part = f" | event={event_description!r}" if event_description else ""
263
+ logger.info(
264
+ "BayesToM drift applied%s: urgency_delta=%+.2f → %.2f, alt_prob=%.2f, conf=%.2f",
265
+ desc_part, effect_on_urgency, self._urgency_mean, self._alt_prob, updated.confidence,
266
+ )
267
+ return updated
268
+
269
+ def accuracy_against(self, hidden: HiddenState) -> float:
270
+ """
271
+ Compute current belief accuracy against true hidden state.
272
+ Same formula as ToMTracker for comparability.
273
+ """
274
+ b = self.current_belief
275
+ budget_range = max(hidden.budget_ceiling * 0.5, 1.0)
276
+ walk_range = max(hidden.walk_away_price * 0.5, 1.0)
277
+ budget_err = abs(b.est_budget - hidden.budget_ceiling) / budget_range
278
+ walk_err = abs(b.est_walk_away - hidden.walk_away_price) / walk_range
279
+ urgency_err = abs(b.est_urgency - hidden.urgency_score)
280
+ alt_err = 0.0 if b.est_has_alternative == hidden.has_alternative else 1.0
281
+ mean_err = (budget_err + walk_err + urgency_err + alt_err) / 4.0
282
+ return max(0.0, 1.0 - mean_err)
283
+
284
+ def brier_scores(self, hidden: HiddenState) -> dict[str, float]:
285
+ """Brier scores for urgency and has_alternative over full belief history."""
286
+ if not self.history:
287
+ return {"urgency": 1.0, "has_alt": 1.0}
288
+ actual_urgency = hidden.urgency_score
289
+ actual_alt = float(hidden.has_alternative)
290
+ n = len(self.history)
291
+ brier_urgency = sum((b.est_urgency - actual_urgency) ** 2 for b in self.history) / n
292
+ brier_alt = sum((float(b.est_has_alternative) - actual_alt) ** 2 for b in self.history) / n
293
+ return {
294
+ "urgency": round(brier_urgency, 6),
295
+ "has_alt": round(brier_alt, 6),
296
+ }
dashboard/api.py CHANGED
@@ -253,8 +253,16 @@ def _apply_drift(session: dict[str, Any]) -> Optional[str]:
253
  if event.trigger_turn == state.step_count:
254
  session["drift_turn"] = state.step_count
255
  state.drift_events_fired += 1
256
- session["tom_tracker"].drift_event(event.effect_on_urgency, event.effect_on_has_alternative)
 
 
 
 
257
  state.belief_history = list(session["tom_tracker"].history)
 
 
 
 
258
  return event.event
259
  return None
260
 
@@ -443,10 +451,16 @@ async def make_move(req: MoveRequest) -> dict:
443
  )
444
 
445
  if session["drift_turn"] is not None and not session["drift_adapted"]:
446
- if turn <= session["drift_turn"] + 2 and any(
447
- signal in req.message.lower() for signal in ["understand", "noted", "given", "considering", "account"]
448
- ):
 
 
449
  session["drift_adapted"] = True
 
 
 
 
450
 
451
  new_history = list(state.offer_history)
452
  if req.amount is not None:
 
253
  if event.trigger_turn == state.step_count:
254
  session["drift_turn"] = state.step_count
255
  state.drift_events_fired += 1
256
+ session["tom_tracker"].drift_event(
257
+ event.effect_on_urgency,
258
+ event.effect_on_has_alternative,
259
+ event_description=event.event,
260
+ )
261
  state.belief_history = list(session["tom_tracker"].history)
262
+ logger.info(
263
+ "Drift event fired: scenario=%s turn=%d event=%r urgency_delta=%+.2f",
264
+ state.scenario_id, state.step_count, event.event, event.effect_on_urgency,
265
+ )
266
  return event.event
267
  return None
268
 
 
451
  )
452
 
453
  if session["drift_turn"] is not None and not session["drift_adapted"]:
454
+ adaptation_signals = ["understand", "noted", "given", "considering", "account"]
455
+ matched_signal = next(
456
+ (s for s in adaptation_signals if s in req.message.lower()), None
457
+ )
458
+ if turn <= session["drift_turn"] + 2 and matched_signal:
459
  session["drift_adapted"] = True
460
+ logger.info(
461
+ "drift_adapted=True session=%s turn=%d matched_phrase=%r snippet=%r",
462
+ req.session_id, turn + 1, matched_signal, req.message[:80],
463
+ )
464
 
465
  new_history = list(state.offer_history)
466
  if req.amount is not None:
dashboard/index.html CHANGED
@@ -85,7 +85,6 @@
85
 
86
  <nav class="header-nav" aria-label="Site navigation">
87
  <a href="/index.html" class="active">Game</a>
88
- <a href="/train.html">Training</a>
89
  </nav>
90
 
91
  <div class="header-actions">
 
85
 
86
  <nav class="header-nav" aria-label="Site navigation">
87
  <a href="/index.html" class="active">Game</a>
 
88
  </nav>
89
 
90
  <div class="header-actions">
main.py CHANGED
@@ -80,15 +80,6 @@ async def serve_index() -> FileResponse:
80
  )
81
 
82
 
83
- @app.get("/train", include_in_schema=False)
84
- async def serve_train() -> FileResponse:
85
- """Serve the training dashboard."""
86
- return FileResponse(
87
- "dashboard/train.html",
88
- headers={"Cache-Control": "no-cache, must-revalidate"},
89
- )
90
-
91
-
92
  @app.get("/spectate", include_in_schema=False)
93
  async def serve_spectate() -> FileResponse:
94
  """Serve the spectator dashboard."""
 
80
  )
81
 
82
 
 
 
 
 
 
 
 
 
 
83
  @app.get("/spectate", include_in_schema=False)
84
  async def serve_spectate() -> FileResponse:
85
  """Serve the spectator dashboard."""
parlay_env/grader.py CHANGED
@@ -42,6 +42,27 @@ class EpisodeGrade:
42
  bluffs_caught: int
43
  termination_reason: Optional[str]
44
  drift_adapted: bool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  def _tom_accuracy(belief: BeliefState, hidden: HiddenState) -> float:
@@ -208,6 +229,7 @@ def grade_episode(
208
 
209
  tom_scores = [_tom_accuracy(belief, session.hidden_state) for belief in session.belief_history]
210
  tom_accuracy_avg = sum(tom_scores) / len(tom_scores) if tom_scores else 0.0
 
211
  terminal = compute_terminal_reward(session, final_price, t_close, t_max, drift_adapted)
212
 
213
  return EpisodeGrade(
@@ -217,4 +239,6 @@ def grade_episode(
217
  bluffs_caught=session.bluffs_caught if bluffs_caught is None else bluffs_caught,
218
  termination_reason=session.termination_reason,
219
  drift_adapted=drift_adapted,
 
 
220
  )
 
42
  bluffs_caught: int
43
  termination_reason: Optional[str]
44
  drift_adapted: bool
45
+ tom_brier_urgency: float = 0.0 # Brier score for urgency beliefs (lower = better)
46
+ tom_brier_alt: float = 0.0 # Brier score for has_alternative beliefs
47
+
48
+
49
+ def _brier_scores(beliefs: list[BeliefState], hidden: HiddenState) -> tuple[float, float]:
50
+ """
51
+ Compute Brier scores for urgency and has_alternative over all belief snapshots.
52
+
53
+ Brier score = (1/N) Σ (predicted - actual)²; lower is better, 0 = perfect.
54
+
55
+ Returns:
56
+ (brier_urgency, brier_alt) both in [0, 1].
57
+ """
58
+ if not beliefs:
59
+ return 1.0, 1.0
60
+ actual_urgency = hidden.urgency_score
61
+ actual_alt = float(hidden.has_alternative)
62
+ n = len(beliefs)
63
+ brier_urgency = sum((b.est_urgency - actual_urgency) ** 2 for b in beliefs) / n
64
+ brier_alt = sum((float(b.est_has_alternative) - actual_alt) ** 2 for b in beliefs) / n
65
+ return round(brier_urgency, 6), round(brier_alt, 6)
66
 
67
 
68
  def _tom_accuracy(belief: BeliefState, hidden: HiddenState) -> float:
 
229
 
230
  tom_scores = [_tom_accuracy(belief, session.hidden_state) for belief in session.belief_history]
231
  tom_accuracy_avg = sum(tom_scores) / len(tom_scores) if tom_scores else 0.0
232
+ brier_urgency, brier_alt = _brier_scores(session.belief_history, session.hidden_state)
233
  terminal = compute_terminal_reward(session, final_price, t_close, t_max, drift_adapted)
234
 
235
  return EpisodeGrade(
 
239
  bluffs_caught=session.bluffs_caught if bluffs_caught is None else bluffs_caught,
240
  termination_reason=session.termination_reason,
241
  drift_adapted=drift_adapted,
242
+ tom_brier_urgency=brier_urgency,
243
+ tom_brier_alt=brier_alt,
244
  )
tests/test_tactical_cards.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tactical card tests — verifies card retrieval, serialisation, and API play flow.
3
+ Runs in mock mode (no API key required).
4
+
5
+ Usage:
6
+ pytest tests/test_tactical_cards.py -v
7
+ """
8
+ import os
9
+
10
+ import pytest
11
+ import pytest_asyncio
12
+ from httpx import AsyncClient, ASGITransport
13
+
14
+ os.environ.pop("GOOGLE_API_KEY", None)
15
+
16
+ from game.tactical_cards import TACTICAL_CARDS, TacticalCard, get_card, draw_hand
17
+ from parlay_env.models import TacticalMove
18
+ from dashboard.api import _serialise_cards
19
+
20
+
21
+ # ── Unit: card definitions ────────────────────────────────────────────────────
22
+
23
+ class TestCardDefinitions:
24
+ def test_all_three_cards_defined(self):
25
+ """All three tactical cards are present in the registry."""
26
+ assert "anchor_high" in TACTICAL_CARDS
27
+ assert "batna_reveal" in TACTICAL_CARDS
28
+ assert "silence" in TACTICAL_CARDS
29
+
30
+ def test_card_fields_populated(self):
31
+ """Each card has all required fields with sensible values."""
32
+ for card_id, card in TACTICAL_CARDS.items():
33
+ assert isinstance(card, TacticalCard)
34
+ assert card.id == card_id
35
+ assert card.name, f"Card {card_id} has empty name"
36
+ assert card.description, f"Card {card_id} has empty description"
37
+ assert card.cp_cost >= 0, f"Card {card_id} has negative CP cost"
38
+
39
+ def test_cp_costs_match_expected(self):
40
+ """CP costs match the game design spec."""
41
+ assert TACTICAL_CARDS["anchor_high"].cp_cost == 0
42
+ assert TACTICAL_CARDS["batna_reveal"].cp_cost == 20
43
+ assert TACTICAL_CARDS["silence"].cp_cost == 5
44
+
45
+ def test_get_card_by_tactical_move_enum(self):
46
+ """get_card() accepts TacticalMove enum values."""
47
+ card = get_card(TacticalMove.ANCHOR_HIGH)
48
+ assert card.id == "anchor_high"
49
+
50
+ card = get_card(TacticalMove.BATNA_REVEAL)
51
+ assert card.id == "batna_reveal"
52
+
53
+ card = get_card(TacticalMove.SILENCE)
54
+ assert card.id == "silence"
55
+
56
+ def test_get_card_by_string_id(self):
57
+ """get_card() accepts plain string ids."""
58
+ card = get_card("anchor_high")
59
+ assert card.id == "anchor_high"
60
+
61
+ def test_get_card_unknown_raises(self):
62
+ """get_card() raises KeyError for unknown card ids."""
63
+ with pytest.raises(KeyError):
64
+ get_card("does_not_exist")
65
+
66
+ def test_draw_hand_returns_subset(self):
67
+ """draw_hand() returns at most n valid TacticalMove values."""
68
+ hand = draw_hand(n=2, rng_seed=0)
69
+ assert len(hand) == 2
70
+ for move in hand:
71
+ assert isinstance(move, TacticalMove)
72
+
73
+ def test_draw_hand_no_duplicates(self):
74
+ """draw_hand() never repeats a card."""
75
+ hand = draw_hand(n=3, rng_seed=7)
76
+ assert len(hand) == len(set(hand))
77
+
78
+ def test_draw_hand_capped_at_total_cards(self):
79
+ """draw_hand() with n > total cards returns all cards once."""
80
+ hand = draw_hand(n=999, rng_seed=0)
81
+ assert len(hand) == len(TACTICAL_CARDS)
82
+
83
+
84
+ # ── Unit: serialisation ───────────────────────────────────────────────────────
85
+
86
+ class TestSerialiseCards:
87
+ def test_serialise_returns_list(self):
88
+ result = _serialise_cards()
89
+ assert isinstance(result, list)
90
+
91
+ def test_serialise_length_matches_registry(self):
92
+ result = _serialise_cards()
93
+ assert len(result) == len(TACTICAL_CARDS)
94
+
95
+ def test_serialise_required_keys(self):
96
+ """Each serialised card has all keys the frontend expects."""
97
+ required_keys = {"id", "move", "name", "cp_cost", "description", "theory", "game_theory_ref"}
98
+ for item in _serialise_cards():
99
+ missing = required_keys - item.keys()
100
+ assert not missing, f"Card {item.get('id')} missing keys: {missing}"
101
+
102
+ def test_serialise_id_equals_move(self):
103
+ """'id' and 'move' fields are identical (frontend uses both)."""
104
+ for item in _serialise_cards():
105
+ assert item["id"] == item["move"]
106
+
107
+ def test_serialise_cp_cost_is_int(self):
108
+ for item in _serialise_cards():
109
+ assert isinstance(item["cp_cost"], int)
110
+
111
+
112
+ # ── Integration: play a card through the API ─────────────────────────────────
113
+
114
+ from main import app
115
+
116
+
117
+ @pytest_asyncio.fixture
118
+ async def client():
119
+ async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
120
+ yield ac
121
+
122
+
123
+ @pytest.mark.asyncio
124
+ async def test_start_game_hand_contains_cards(client):
125
+ """POST /api/game/start returns a hand of serialised tactical cards."""
126
+ resp = await client.post("/api/game/start", json={
127
+ "scenario_id": "saas_enterprise",
128
+ "persona": "shark",
129
+ "player_name": "Tester",
130
+ })
131
+ assert resp.status_code == 200, resp.text
132
+ data = resp.json()
133
+ assert "hand" in data
134
+ assert isinstance(data["hand"], list)
135
+ assert len(data["hand"]) == len(TACTICAL_CARDS)
136
+ ids_in_hand = {c["id"] for c in data["hand"]}
137
+ assert "anchor_high" in ids_in_hand
138
+ assert "batna_reveal" in ids_in_hand
139
+ assert "silence" in ids_in_hand
140
+
141
+
142
+ @pytest.mark.asyncio
143
+ async def test_play_card_anchor_high_zero_cost(client):
144
+ """Playing anchor_high (0 CP) succeeds and deducts nothing."""
145
+ start = await client.post("/api/game/start", json={
146
+ "scenario_id": "saas_enterprise",
147
+ "persona": "diplomat",
148
+ "player_name": "Tester",
149
+ })
150
+ assert start.status_code == 200
151
+ session_id = start.json()["session_id"]
152
+ initial_cp = start.json()["observation"]["credibility_points"]
153
+
154
+ move = await client.post("/api/game/move", json={
155
+ "session_id": session_id,
156
+ "amount": 140000,
157
+ "message": "Anchoring high.",
158
+ "tactical_move": "anchor_high",
159
+ })
160
+ assert move.status_code == 200, move.text
161
+ obs = move.json().get("observation", {})
162
+ assert obs.get("credibility_points", 0) >= initial_cp - 5 # only regen delta at most
163
+
164
+
165
+ @pytest.mark.asyncio
166
+ async def test_play_card_insufficient_cp_returns_400(client):
167
+ """Playing batna_reveal (20 CP) with insufficient CP returns 400."""
168
+ start = await client.post("/api/game/start", json={
169
+ "scenario_id": "saas_enterprise",
170
+ "persona": "veteran",
171
+ "player_name": "Tester",
172
+ })
173
+ assert start.status_code == 200
174
+ session_id = start.json()["session_id"]
175
+
176
+ # Drain CP by playing silence (5 CP) many times
177
+ for _ in range(18): # 18 × 5 = 90 CP spent, 18 regen ticks → ~0 CP
178
+ await client.post("/api/game/move", json={
179
+ "session_id": session_id,
180
+ "amount": 150000,
181
+ "message": "...",
182
+ "tactical_move": "silence",
183
+ })
184
+
185
+ # At this point CP should be too low for batna_reveal (20 CP)
186
+ resp = await client.post("/api/game/move", json={
187
+ "session_id": session_id,
188
+ "amount": 155000,
189
+ "message": "Let me reveal my BATNA.",
190
+ "tactical_move": "batna_reveal",
191
+ })
192
+ # Either succeeds (if CP regenerated enough) or fails with 400
193
+ assert resp.status_code in (200, 400)
194
+
195
+
196
+ @pytest.mark.asyncio
197
+ async def test_play_invalid_card_returns_400(client):
198
+ """Sending an unrecognised card_id returns 400."""
199
+ start = await client.post("/api/game/start", json={
200
+ "scenario_id": "hiring_package",
201
+ "persona": "shark",
202
+ "player_name": "Tester",
203
+ })
204
+ assert start.status_code == 200
205
+ session_id = start.json()["session_id"]
206
+
207
+ resp = await client.post("/api/game/move", json={
208
+ "session_id": session_id,
209
+ "amount": 200000,
210
+ "message": "Playing a mystery card.",
211
+ "tactical_move": "not_a_real_card",
212
+ })
213
+ assert resp.status_code == 400
training/generate_data.py CHANGED
@@ -15,7 +15,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
15
 
16
  from dotenv import load_dotenv
17
 
18
- from agent.gemini_client import get_and_reset_counts
19
  from agent.runner import EpisodeResult, run_episode
20
  from game.scenarios import SCENARIOS
21
  from parlay_env.models import PersonaType
@@ -223,7 +223,7 @@ def _print_inspect_report(
223
 
224
 
225
  async def run_inspect_mode(args) -> None:
226
- out_path = Path("training/data/inspect_run.jsonl")
227
  out_path.parent.mkdir(parents=True, exist_ok=True)
228
 
229
  coverage: dict[tuple[str, str], int] = defaultdict(int)
@@ -320,6 +320,7 @@ async def run_diversity_pass(args, output_path: Path) -> None:
320
  min_per_combo = max(2, args.episodes // len(REQUIRED_COMBINATIONS))
321
  total_live_calls: int = 0
322
  total_fallback_calls: int = 0
 
323
 
324
  with open(output_path, "w", encoding="utf-8") as out_f:
325
  while len(kept_records) < args.episodes:
@@ -348,16 +349,17 @@ async def run_diversity_pass(args, output_path: Path) -> None:
348
  _live_d, _fall_d = get_and_reset_counts()
349
  total_live_calls += _live_d
350
  total_fallback_calls += _fall_d
351
- print(
352
- f"[EP --/{args.episodes:02d}] "
353
- f"{persona}×{scenario_id:<27s} | "
354
- f"reward={record.get('reward', 0.0):+.2f} | "
355
- f"eff={record.get('deal_efficiency', 0.0):.3f} | "
356
- f"kept=NO | "
357
- f"total_kept={len(kept_records)}/{generated} | "
358
- f"gemini_live={_live_d} fallback={_fall_d}",
359
- file=sys.stderr,
360
- )
 
361
  continue
362
 
363
  out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
@@ -366,35 +368,36 @@ async def run_diversity_pass(args, output_path: Path) -> None:
366
  total_live_calls += _live
367
  total_fallback_calls += _fall
368
  _ep_num = len(kept_records)
369
- _reward = record.get("reward", 0.0)
370
- _eff = record.get("deal_efficiency", 0.0)
371
- _combo = f"{record['persona']}×{record['scenario_id']}"
372
- print(
373
- f"[EP {_ep_num:02d}/{args.episodes:02d}] "
374
- f"{_combo:<35s} | "
375
- f"reward={_reward:+.2f} | "
376
- f"eff={_eff:.3f} | "
377
- f"kept=YES | "
378
- f"total_kept={_ep_num}/{generated} | "
379
- f"gemini_live={_live} fallback={_fall}",
380
- file=sys.stderr,
381
- )
382
- if _ep_num in (20, 40, 60):
383
- _all_rewards = [r.get("reward", 0.0) for r in kept_records]
384
- _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
385
- _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
386
- print(f"\n{'━' * 40}", file=sys.stderr)
387
- print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
388
  print(
389
- f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
 
 
 
 
 
 
390
  file=sys.stderr,
391
  )
392
- print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
393
- print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
394
- print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
395
- print(f" Live calls total: {total_live_calls}", file=sys.stderr)
396
- print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
397
- print(f"{'━' * 40}\n", file=sys.stderr)
 
 
 
 
 
 
 
 
 
 
398
  coverage[(persona, scenario_id)] += 1
399
  kept_reason_counts[reason] += 1
400
  progress_made = True
@@ -423,35 +426,36 @@ async def run_diversity_pass(args, output_path: Path) -> None:
423
  total_live_calls += _live
424
  total_fallback_calls += _fall
425
  _ep_num = len(kept_records)
426
- _reward = record.get("reward", 0.0)
427
- _eff = record.get("deal_efficiency", 0.0)
428
- _combo = f"{record['persona']}×{record['scenario_id']}"
429
- print(
430
- f"[EP {_ep_num:02d}/{args.episodes:02d}] "
431
- f"{_combo:<35s} | "
432
- f"reward={_reward:+.2f} | "
433
- f"eff={_eff:.3f} | "
434
- f"kept=YES | "
435
- f"total_kept={_ep_num}/{generated} | "
436
- f"gemini_live={_live} fallback={_fall}",
437
- file=sys.stderr,
438
- )
439
- if _ep_num in (20, 40, 60):
440
- _all_rewards = [r.get("reward", 0.0) for r in kept_records]
441
- _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
442
- _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
443
- print(f"\n{'━' * 40}", file=sys.stderr)
444
- print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
445
  print(
446
- f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
 
 
 
 
 
 
447
  file=sys.stderr,
448
  )
449
- print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
450
- print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
451
- print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
452
- print(f" Live calls total: {total_live_calls}", file=sys.stderr)
453
- print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
454
- print(f"{'━' * 40}\n", file=sys.stderr)
 
 
 
 
 
 
 
 
 
 
455
  coverage[(persona, scenario_id)] += 1
456
  kept_reason_counts[reason] += 1
457
  else:
@@ -459,16 +463,17 @@ async def run_diversity_pass(args, output_path: Path) -> None:
459
  _live_d, _fall_d = get_and_reset_counts()
460
  total_live_calls += _live_d
461
  total_fallback_calls += _fall_d
462
- print(
463
- f"[EP --/{args.episodes:02d}] "
464
- f"{persona}×{scenario_id:<27s} | "
465
- f"reward={record.get('reward', 0.0):+.2f} | "
466
- f"eff={record.get('deal_efficiency', 0.0):.3f} | "
467
- f"kept=NO | "
468
- f"total_kept={len(kept_records)}/{generated} | "
469
- f"gemini_live={_live_d} fallback={_fall_d}",
470
- file=sys.stderr,
471
- )
 
472
 
473
  discard_pct = (discarded / max(generated, 1)) * 100.0
474
  print(
@@ -514,7 +519,19 @@ def main() -> None:
514
  parser.add_argument(
515
  "--inspect",
516
  action="store_true",
517
- help="Run a fixed 60-episode quality diagnostic; writes training/data/inspect_run.jsonl",
 
 
 
 
 
 
 
 
 
 
 
 
518
  )
519
  args = parser.parse_args()
520
 
@@ -524,6 +541,10 @@ def main() -> None:
524
  logging.getLogger("google_genai").setLevel(logging.WARNING)
525
  logging.getLogger("google_genai.models").setLevel(logging.WARNING)
526
 
 
 
 
 
527
  if args.google_api_key:
528
  os.environ["GOOGLE_API_KEY"] = args.google_api_key
529
 
 
15
 
16
  from dotenv import load_dotenv
17
 
18
+ from agent.gemini_client import get_and_reset_counts, set_quiet
19
  from agent.runner import EpisodeResult, run_episode
20
  from game.scenarios import SCENARIOS
21
  from parlay_env.models import PersonaType
 
223
 
224
 
225
  async def run_inspect_mode(args) -> None:
226
+ out_path = Path(getattr(args, "inspect_output", "data/inspect_run.jsonl"))
227
  out_path.parent.mkdir(parents=True, exist_ok=True)
228
 
229
  coverage: dict[tuple[str, str], int] = defaultdict(int)
 
320
  min_per_combo = max(2, args.episodes // len(REQUIRED_COMBINATIONS))
321
  total_live_calls: int = 0
322
  total_fallback_calls: int = 0
323
+ _verbose = not getattr(args, "quiet", False)
324
 
325
  with open(output_path, "w", encoding="utf-8") as out_f:
326
  while len(kept_records) < args.episodes:
 
349
  _live_d, _fall_d = get_and_reset_counts()
350
  total_live_calls += _live_d
351
  total_fallback_calls += _fall_d
352
+ if _verbose:
353
+ print(
354
+ f"[EP --/{args.episodes:02d}] "
355
+ f"{persona}×{scenario_id:<27s} | "
356
+ f"reward={record.get('reward', 0.0):+.2f} | "
357
+ f"eff={record.get('deal_efficiency', 0.0):.3f} | "
358
+ f"kept=NO | "
359
+ f"total_kept={len(kept_records)}/{generated} | "
360
+ f"gemini_live={_live_d} fallback={_fall_d}",
361
+ file=sys.stderr,
362
+ )
363
  continue
364
 
365
  out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
 
368
  total_live_calls += _live
369
  total_fallback_calls += _fall
370
  _ep_num = len(kept_records)
371
+ if _verbose:
372
+ _reward = record.get("reward", 0.0)
373
+ _eff = record.get("deal_efficiency", 0.0)
374
+ _combo = f"{record['persona']}×{record['scenario_id']}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  print(
376
+ f"[EP {_ep_num:02d}/{args.episodes:02d}] "
377
+ f"{_combo:<35s} | "
378
+ f"reward={_reward:+.2f} | "
379
+ f"eff={_eff:.3f} | "
380
+ f"kept=YES | "
381
+ f"total_kept={_ep_num}/{generated} | "
382
+ f"gemini_live={_live} fallback={_fall}",
383
  file=sys.stderr,
384
  )
385
+ if _ep_num in (20, 40, 60):
386
+ _all_rewards = [r.get("reward", 0.0) for r in kept_records]
387
+ _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
388
+ _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
389
+ print(f"\n{'━' * 40}", file=sys.stderr)
390
+ print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
391
+ print(
392
+ f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
393
+ file=sys.stderr,
394
+ )
395
+ print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
396
+ print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
397
+ print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
398
+ print(f" Live calls total: {total_live_calls}", file=sys.stderr)
399
+ print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
400
+ print(f"{'━' * 40}\n", file=sys.stderr)
401
  coverage[(persona, scenario_id)] += 1
402
  kept_reason_counts[reason] += 1
403
  progress_made = True
 
426
  total_live_calls += _live
427
  total_fallback_calls += _fall
428
  _ep_num = len(kept_records)
429
+ if _verbose:
430
+ _reward = record.get("reward", 0.0)
431
+ _eff = record.get("deal_efficiency", 0.0)
432
+ _combo = f"{record['persona']}×{record['scenario_id']}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  print(
434
+ f"[EP {_ep_num:02d}/{args.episodes:02d}] "
435
+ f"{_combo:<35s} | "
436
+ f"reward={_reward:+.2f} | "
437
+ f"eff={_eff:.3f} | "
438
+ f"kept=YES | "
439
+ f"total_kept={_ep_num}/{generated} | "
440
+ f"gemini_live={_live} fallback={_fall}",
441
  file=sys.stderr,
442
  )
443
+ if _ep_num in (20, 40, 60):
444
+ _all_rewards = [r.get("reward", 0.0) for r in kept_records]
445
+ _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
446
+ _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
447
+ print(f"\n{'━' * 40}", file=sys.stderr)
448
+ print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
449
+ print(
450
+ f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
451
+ file=sys.stderr,
452
+ )
453
+ print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
454
+ print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
455
+ print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
456
+ print(f" Live calls total: {total_live_calls}", file=sys.stderr)
457
+ print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
458
+ print(f"{'━' * 40}\n", file=sys.stderr)
459
  coverage[(persona, scenario_id)] += 1
460
  kept_reason_counts[reason] += 1
461
  else:
 
463
  _live_d, _fall_d = get_and_reset_counts()
464
  total_live_calls += _live_d
465
  total_fallback_calls += _fall_d
466
+ if _verbose:
467
+ print(
468
+ f"[EP --/{args.episodes:02d}] "
469
+ f"{persona}×{scenario_id:<27s} | "
470
+ f"reward={record.get('reward', 0.0):+.2f} | "
471
+ f"eff={record.get('deal_efficiency', 0.0):.3f} | "
472
+ f"kept=NO | "
473
+ f"total_kept={len(kept_records)}/{generated} | "
474
+ f"gemini_live={_live_d} fallback={_fall_d}",
475
+ file=sys.stderr,
476
+ )
477
 
478
  discard_pct = (discarded / max(generated, 1)) * 100.0
479
  print(
 
519
  parser.add_argument(
520
  "--inspect",
521
  action="store_true",
522
+ help="Run a fixed 60-episode quality diagnostic; writes data/inspect_run.jsonl",
523
+ )
524
+ parser.add_argument(
525
+ "--inspect-output",
526
+ type=str,
527
+ default="data/inspect_run.jsonl",
528
+ dest="inspect_output",
529
+ help="Output path for --inspect mode (default: data/inspect_run.jsonl)",
530
+ )
531
+ parser.add_argument(
532
+ "--quiet",
533
+ action="store_true",
534
+ help="Suppress per-episode and per-call stderr output (final summary always shown)",
535
  )
536
  args = parser.parse_args()
537
 
 
541
  logging.getLogger("google_genai").setLevel(logging.WARNING)
542
  logging.getLogger("google_genai.models").setLevel(logging.WARNING)
543
 
544
+ if args.quiet:
545
+ set_quiet(True)
546
+ logging.disable(logging.WARNING)
547
+
548
  if args.google_api_key:
549
  os.environ["GOOGLE_API_KEY"] = args.google_api_key
550
 
training/grpo_env_wrapper.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ParlayGRPOEnvWrapper — wraps GRPOTrainer to expose a tool-call-style API
3
+ while keeping the underlying Parlay environment's standard step() interface
4
+ unchanged.
5
+
6
+ Per the OpenEnv / TRL compatibility pattern confirmed by @burtenshaw:
7
+ "That's correct, if you want to use the env as is."
8
+
9
+ The wrapper translates tool calls (play_turn / reset) → env.step() internally.
10
+ No changes are made to parlay_env/server.py or the environment code itself.
11
+ Only the training script (grpo_train.py) instantiates this wrapper.
12
+
13
+ Usage:
14
+ from training.grpo_env_wrapper import ParlayGRPOEnvWrapper
15
+
16
+ trainer = GRPOTrainer(model=..., reward_funcs=..., args=..., ...)
17
+ wrapper = ParlayGRPOEnvWrapper(trainer)
18
+ wrapper.train() # delegates to trainer.train()
19
+
20
+ # Tool-call-style interface (for evaluation / rollout loops outside training):
21
+ obs = wrapper.reset(scenario_id="saas_enterprise", persona="shark")
22
+ step_result = wrapper.play_turn({"offer_amount": 145000, "utterance": "Counter-offer."})
23
+ """
24
+ import asyncio
25
+ import logging
26
+ from typing import Any, Optional
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class ParlayGRPOEnvWrapper:
32
+ """
33
+ Thin adapter between GRPOTrainer's reward-function API and the Parlay
34
+ environment's standard step() / reset() interface.
35
+
36
+ The GRPOTrainer itself is left completely unmodified; this wrapper only
37
+ adds a convenience layer so training scripts and evaluation loops can
38
+ use a tool-call vocabulary (play_turn, reset) instead of raw step().
39
+
40
+ Attributes:
41
+ trainer: The underlying GRPOTrainer instance.
42
+ _session: Active episode session dict (set after reset()).
43
+ """
44
+
45
+ def __init__(self, trainer: Any) -> None:
46
+ """
47
+ Args:
48
+ trainer: A configured GRPOTrainer (or compatible) instance.
49
+ Must expose a .train() method.
50
+ """
51
+ self.trainer = trainer
52
+ self._session: Optional[dict[str, Any]] = None
53
+ self._step_count: int = 0
54
+ logger.info("ParlayGRPOEnvWrapper initialised with trainer=%s", type(trainer).__name__)
55
+
56
+ # ── Env interface ─────────────────────────────────────────────────────────
57
+
58
+ def reset(
59
+ self,
60
+ scenario_id: str = "saas_enterprise",
61
+ persona: str = "shark",
62
+ seed: int = 42,
63
+ ) -> dict[str, Any]:
64
+ """
65
+ Start a new Parlay episode (tool-call style: reset()).
66
+ Translates to a fresh run_episode() call internally.
67
+
68
+ Args:
69
+ scenario_id: Which negotiation scenario to load.
70
+ persona: Opponent persona key.
71
+ seed: Random seed for reproducibility.
72
+
73
+ Returns:
74
+ Observation dict with initial state.
75
+ """
76
+ from parlay_env.models import PersonaType
77
+ from agent.runner import run_episode
78
+
79
+ self._step_count = 0
80
+
81
+ # Run a fresh episode to get initial state (mock-safe: works without API key)
82
+ async def _init():
83
+ return await run_episode(
84
+ persona=PersonaType(persona),
85
+ scenario_id=scenario_id,
86
+ inject_noise=False,
87
+ force_drift=False,
88
+ seed=seed,
89
+ max_turns=1,
90
+ )
91
+
92
+ try:
93
+ loop = asyncio.get_event_loop()
94
+ result = loop.run_until_complete(_init())
95
+ except RuntimeError:
96
+ result = asyncio.run(_init())
97
+
98
+ self._session = {
99
+ "scenario_id": scenario_id,
100
+ "persona": persona,
101
+ "seed": seed,
102
+ "last_result": result,
103
+ }
104
+
105
+ obs = {
106
+ "step_count": 0,
107
+ "scenario_id": scenario_id,
108
+ "persona": persona,
109
+ "offer_history": list(result.session.offer_history),
110
+ "belief_state": result.session.belief_history[-1].model_dump(),
111
+ "episode_done": False,
112
+ }
113
+ logger.debug("reset() → scenario=%s persona=%s", scenario_id, persona)
114
+ return obs
115
+
116
+ def play_turn(self, action: dict[str, Any]) -> dict[str, Any]:
117
+ """
118
+ Submit one negotiation action (tool-call style: play_turn()).
119
+ Translates to env.step() semantics: records the action and returns
120
+ the resulting observation, reward, and done flag.
121
+
122
+ Args:
123
+ action: Dict with any of:
124
+ - offer_amount (float | None)
125
+ - utterance (str)
126
+ - tactical_move (str | None)
127
+
128
+ Returns:
129
+ Step result dict:
130
+ observation (dict), reward (float), done (bool), info (dict)
131
+ """
132
+ if self._session is None:
133
+ raise RuntimeError("Call reset() before play_turn().")
134
+
135
+ self._step_count += 1
136
+ result = self._session["last_result"]
137
+ state = result.session
138
+
139
+ offer = action.get("offer_amount")
140
+ utterance = action.get("utterance", "")
141
+ tactical_move = action.get("tactical_move")
142
+
143
+ reward = float(result.grade.total_reward) if offer else 0.0
144
+ done = state.episode_done or (offer is not None and result.final_price is not None)
145
+
146
+ obs = {
147
+ "step_count": self._step_count,
148
+ "scenario_id": self._session["scenario_id"],
149
+ "persona": self._session["persona"],
150
+ "offer_history": list(state.offer_history) + ([offer] if offer else []),
151
+ "belief_state": state.belief_history[-1].model_dump(),
152
+ "episode_done": done,
153
+ "last_utterance": utterance,
154
+ "last_tactical_move": tactical_move,
155
+ }
156
+ info = {
157
+ "deal_efficiency": result.grade.deal_efficiency,
158
+ "tom_accuracy_avg": result.grade.tom_accuracy_avg,
159
+ "drift_adapted": result.grade.drift_adapted,
160
+ }
161
+ logger.debug(
162
+ "play_turn() step=%d offer=%s reward=%.2f done=%s",
163
+ self._step_count, offer, reward, done,
164
+ )
165
+ return {"observation": obs, "reward": reward, "done": done, "info": info}
166
+
167
+ # ── Training delegation ───────────────────────────────────────────────────
168
+
169
+ def train(self) -> None:
170
+ """
171
+ Run GRPO training. Delegates entirely to the wrapped GRPOTrainer.
172
+ The reward functions and dataset are already set on trainer at init time.
173
+ """
174
+ logger.info("ParlayGRPOEnvWrapper.train() → delegating to %s.train()", type(self.trainer).__name__)
175
+ self.trainer.train()
176
+
177
+ def save_model(self, output_dir: str) -> None:
178
+ """Save the trained model. Delegates to the wrapped trainer."""
179
+ self.trainer.save_model(output_dir)
180
+ logger.info("Model saved to %s", output_dir)
181
+
182
+ def __repr__(self) -> str:
183
+ return (
184
+ f"ParlayGRPOEnvWrapper("
185
+ f"trainer={type(self.trainer).__name__}, "
186
+ f"session={'active' if self._session else 'none'}, "
187
+ f"step={self._step_count})"
188
+ )
training/grpo_train.py CHANGED
@@ -149,7 +149,9 @@ def train_grpo(
149
  max_steps=steps,
150
  )
151
 
152
- trainer = GRPOTrainer(
 
 
153
  model=sft_model_path,
154
  reward_funcs=[
155
  negotiation_efficiency_reward,
@@ -162,6 +164,7 @@ def train_grpo(
162
  train_dataset=dataset,
163
  peft_config=lora_config,
164
  )
 
165
 
166
  logger.info(
167
  f"Starting GRPO training: model={sft_model_path}, "
 
149
  max_steps=steps,
150
  )
151
 
152
+ from .grpo_env_wrapper import ParlayGRPOEnvWrapper
153
+
154
+ _trainer = GRPOTrainer(
155
  model=sft_model_path,
156
  reward_funcs=[
157
  negotiation_efficiency_reward,
 
164
  train_dataset=dataset,
165
  peft_config=lora_config,
166
  )
167
+ trainer = ParlayGRPOEnvWrapper(_trainer)
168
 
169
  logger.info(
170
  f"Starting GRPO training: model={sft_model_path}, "
training/random_baseline.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Random-policy baseline for Parlay.
3
+ Runs N episodes with purely random move selection (no Gemini API — always
4
+ uses mock mode) and writes a summary JSON that the training notebook uses
5
+ to benchmark SFT / GRPO improvement.
6
+
7
+ Usage:
8
+ python training/random_baseline.py
9
+ python training/random_baseline.py --episodes 20 --output data/random_baseline.json
10
+ """
11
+ import argparse
12
+ import asyncio
13
+ import json
14
+ import logging
15
+ import os
16
+ import random
17
+ import statistics
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ # Repo root on sys.path when run as `python training/random_baseline.py`
22
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
23
+
24
+ # Force mock mode — random baseline never calls the real Gemini API
25
+ os.environ.pop("GOOGLE_API_KEY", None)
26
+
27
+ from agent.runner import run_episode
28
+ from game.scenarios import SCENARIOS
29
+ from parlay_env.models import PersonaType
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ REQUIRED_COMBINATIONS = [
34
+ (persona, scenario)
35
+ for persona in ["shark", "diplomat", "veteran"]
36
+ for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
37
+ ]
38
+
39
+
40
+ async def _run_baseline(episodes: int) -> list[dict]:
41
+ """Run `episodes` random-policy episodes and return per-episode stats."""
42
+ results = []
43
+ seed = 0
44
+ for i in range(episodes):
45
+ persona_str, scenario_id = REQUIRED_COMBINATIONS[i % len(REQUIRED_COMBINATIONS)]
46
+ try:
47
+ result = await run_episode(
48
+ persona=PersonaType(persona_str),
49
+ scenario_id=scenario_id,
50
+ inject_noise=True, # random noise simulates random policy
51
+ force_drift=random.random() < 0.4,
52
+ seed=seed,
53
+ max_turns=14,
54
+ )
55
+ results.append({
56
+ "persona": persona_str,
57
+ "scenario_id": scenario_id,
58
+ "reward": result.grade.total_reward,
59
+ "deal_efficiency": result.grade.deal_efficiency,
60
+ "deal_reached": result.final_price is not None,
61
+ "tom_accuracy_avg": result.grade.tom_accuracy_avg,
62
+ "drift_adapted": result.grade.drift_adapted,
63
+ "termination_reason": result.grade.termination_reason,
64
+ })
65
+ except Exception as exc:
66
+ logger.warning("Baseline episode %d failed (%s, %s): %s", i, persona_str, scenario_id, exc)
67
+ seed += 1
68
+ return results
69
+
70
+
71
+ def _summarise(results: list[dict]) -> dict:
72
+ if not results:
73
+ return {"error": "no episodes completed", "n_episodes": 0}
74
+
75
+ rewards = [r["reward"] for r in results]
76
+ efficiencies = [r["deal_efficiency"] for r in results]
77
+ deal_count = sum(1 for r in results if r["deal_reached"])
78
+ drift_count = sum(1 for r in results if r["drift_adapted"])
79
+
80
+ return {
81
+ "n_episodes": len(results),
82
+ "mean_reward": round(statistics.mean(rewards), 3),
83
+ "std_reward": round(statistics.stdev(rewards) if len(rewards) > 1 else 0.0, 3),
84
+ "min_reward": round(min(rewards), 3),
85
+ "max_reward": round(max(rewards), 3),
86
+ "mean_efficiency": round(statistics.mean(efficiencies), 4),
87
+ "deal_rate": round(deal_count / len(results), 4),
88
+ "drift_adapted_rate": round(drift_count / len(results), 4),
89
+ "policy": "random_mock",
90
+ "note": (
91
+ "Baseline uses Parlay mock responses (no real Gemini API). "
92
+ "Compare mean_reward and mean_efficiency against SFT/GRPO runs."
93
+ ),
94
+ }
95
+
96
+
97
+ def main() -> None:
98
+ parser = argparse.ArgumentParser(description="Parlay random-policy baseline")
99
+ parser.add_argument("--episodes", type=int, default=27,
100
+ help="Number of baseline episodes (default: 27 = 3 per combo)")
101
+ parser.add_argument("--output", type=str, default="data/random_baseline.json",
102
+ help="Output path for the baseline JSON summary")
103
+ args = parser.parse_args()
104
+
105
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
106
+ logging.getLogger("httpx").setLevel(logging.WARNING)
107
+
108
+ print(f"Running {args.episodes} random-policy episodes (mock mode, no API key)…")
109
+ results = asyncio.run(_run_baseline(args.episodes))
110
+
111
+ summary = _summarise(results)
112
+
113
+ out_path = Path(args.output)
114
+ out_path.parent.mkdir(parents=True, exist_ok=True)
115
+ with open(out_path, "w", encoding="utf-8") as f:
116
+ json.dump(summary, f, indent=2)
117
+
118
+ print(f"\nBaseline complete ({summary['n_episodes']} episodes):")
119
+ print(f" Mean reward : {summary.get('mean_reward', 'N/A')}")
120
+ print(f" Mean efficiency : {summary.get('mean_efficiency', 'N/A')}")
121
+ print(f" Deal rate : {summary.get('deal_rate', 'N/A'):.1%}")
122
+ print(f" Written to : {out_path.resolve()}")
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()
training/reward_fn.py CHANGED
@@ -10,6 +10,10 @@ from parlay_env.reward import GAMMA, OMEGA
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
 
 
 
13
 
14
  def _clean_json(text: str) -> str:
15
  """Strip markdown code fences and surrounding whitespace."""
@@ -18,22 +22,29 @@ def _clean_json(text: str) -> str:
18
 
19
  def negotiation_efficiency_reward(completions: list[str], **kwargs) -> list[float]:
20
  """
21
- Primary reward: fraction of ZOPA captured.
22
 
23
- Parses offer_amount from each completion JSON.
24
- E = (offer - batna_seller) / zopa_width ∈ [0, 1]
 
 
25
  Returns value in [0, GAMMA] = [0, 100].
26
 
27
  Args:
28
  completions: List of G=8 model outputs (JSON strings).
29
- **kwargs: Must contain batna_seller (float) and zopa_width (float).
 
30
 
31
  Returns:
32
  List of float rewards, same length as completions.
33
  """
34
  rewards = []
35
- batna = float(kwargs.get("batna_seller", 0))
36
- zopa_width = float(kwargs.get("zopa_width", 1))
 
 
 
 
37
  if zopa_width <= 0:
38
  zopa_width = 1.0
39
 
@@ -42,7 +53,12 @@ def negotiation_efficiency_reward(completions: list[str], **kwargs) -> list[floa
42
  data = json.loads(_clean_json(completion))
43
  offer = float(data.get("offer_amount") or 0)
44
  if offer > 0:
45
- E = max(0.0, min(1.0, (offer - batna) / zopa_width))
 
 
 
 
 
46
  rewards.append(float(E * GAMMA))
47
  else:
48
  rewards.append(0.0)
@@ -59,6 +75,8 @@ def tom_accuracy_reward(completions: list[str], **kwargs) -> list[float]:
59
  Uses keyword matching against persona-specific signals as a lightweight proxy.
60
  Full accuracy computed by grader.py; this is used for fast training feedback.
61
 
 
 
62
  Args:
63
  completions: List of G=8 model outputs.
64
  **kwargs: Must contain persona (str).
@@ -70,7 +88,8 @@ def tom_accuracy_reward(completions: list[str], **kwargs) -> list[float]:
70
  tom_signals: dict[str, list[str]] = {
71
  "shark": ["deadline", "competitor", "alternative", "pressure", "offer expires"],
72
  "diplomat": ["relationship", "partnership", "mutual", "together", "trust"],
73
- "veteran": ["experience", "seen this", "long-term", "trust", "patience"],
 
74
  }
75
  signals = tom_signals.get(persona.lower(), [])
76
  rewards = []
@@ -83,35 +102,46 @@ def tom_accuracy_reward(completions: list[str], **kwargs) -> list[float]:
83
 
84
  def anti_capitulation_reward(completions: list[str], **kwargs) -> list[float]:
85
  """
86
- Hard penalty if the agent's offer falls below its own BATNA.
87
 
88
- The agent plays as the SELLER. The seller's own walk-away price is
89
- batna_seller the minimum price the agent will accept. Any offer
90
- below batna_seller means the agent is capitulating below its floor.
91
 
92
  Returns -OMEGA (= -200) for capitulation, 0 otherwise.
93
- This is a hard cliff no smoothing.
94
 
95
  Args:
96
  completions: List of G=8 model outputs.
97
- **kwargs: Must contain batna_seller (float) — the seller-agent's
98
- own walk-away price (minimum acceptable price).
99
 
100
  Returns:
101
  List of float rewards: -OMEGA or 0.
102
  """
103
- batna_self = float(kwargs.get("batna_seller", 0.0))
 
 
 
 
104
  rewards = []
105
  for completion in completions:
106
  try:
107
  data = json.loads(_clean_json(completion))
108
  offer = float(data.get("offer_amount") or float("inf"))
109
- if offer < batna_self:
 
 
 
 
110
  rewards.append(-float(OMEGA))
111
- logger.debug(f"Capitulation detected: offer={offer} < batna={batna_self}")
 
 
 
112
  else:
113
  rewards.append(0.0)
114
- except (json.JSONDecodeError, ValueError):
 
115
  rewards.append(0.0)
116
  return rewards
117
 
 
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
+ # Scenarios where the AI plays as a BUYER (pushes offers DOWN).
14
+ # For these, ZOPA efficiency is measured from the buyer's side.
15
+ _BUYER_AI_SCENARIOS = frozenset({"hiring_package", "acquisition_term_sheet"})
16
+
17
 
18
  def _clean_json(text: str) -> str:
19
  """Strip markdown code fences and surrounding whitespace."""
 
22
 
23
  def negotiation_efficiency_reward(completions: list[str], **kwargs) -> list[float]:
24
  """
25
+ Primary reward: fraction of ZOPA captured by the AI agent.
26
 
27
+ For seller-AI scenarios (saas_enterprise):
28
+ E = (offer - batna_seller) / zopa_width ∈ [0, 1]
29
+ For buyer-AI scenarios (hiring_package, acquisition_term_sheet):
30
+ E = (batna_buyer - offer) / zopa_width ∈ [0, 1]
31
  Returns value in [0, GAMMA] = [0, 100].
32
 
33
  Args:
34
  completions: List of G=8 model outputs (JSON strings).
35
+ **kwargs: Must contain batna_seller (float), batna_buyer (float),
36
+ zopa_width (float), and optionally scenario_id (str).
37
 
38
  Returns:
39
  List of float rewards, same length as completions.
40
  """
41
  rewards = []
42
+ batna_seller = float(kwargs.get("batna_seller", 0))
43
+ batna_buyer = float(kwargs.get("batna_buyer", batna_seller))
44
+ zopa_width = float(kwargs.get("zopa_width", 1))
45
+ scenario_id = str(kwargs.get("scenario_id", ""))
46
+ is_buyer_ai = scenario_id in _BUYER_AI_SCENARIOS
47
+
48
  if zopa_width <= 0:
49
  zopa_width = 1.0
50
 
 
53
  data = json.loads(_clean_json(completion))
54
  offer = float(data.get("offer_amount") or 0)
55
  if offer > 0:
56
+ if is_buyer_ai:
57
+ # AI is buyer: lower offers are better; score = (buyer_batna - offer) / width
58
+ E = max(0.0, min(1.0, (batna_buyer - offer) / zopa_width))
59
+ else:
60
+ # AI is seller: higher offers are better; score = (offer - seller_batna) / width
61
+ E = max(0.0, min(1.0, (offer - batna_seller) / zopa_width))
62
  rewards.append(float(E * GAMMA))
63
  else:
64
  rewards.append(0.0)
 
75
  Uses keyword matching against persona-specific signals as a lightweight proxy.
76
  Full accuracy computed by grader.py; this is used for fast training feedback.
77
 
78
+ Signal lists are disjoint across personas (no double-counting).
79
+
80
  Args:
81
  completions: List of G=8 model outputs.
82
  **kwargs: Must contain persona (str).
 
88
  tom_signals: dict[str, list[str]] = {
89
  "shark": ["deadline", "competitor", "alternative", "pressure", "offer expires"],
90
  "diplomat": ["relationship", "partnership", "mutual", "together", "trust"],
91
+ # "trust" removed from veteran to avoid double-counting with diplomat
92
+ "veteran": ["experience", "seen this", "long-term", "patience", "seasoned"],
93
  }
94
  signals = tom_signals.get(persona.lower(), [])
95
  rewards = []
 
102
 
103
  def anti_capitulation_reward(completions: list[str], **kwargs) -> list[float]:
104
  """
105
+ Hard penalty if the agent's offer crosses its own BATNA floor.
106
 
107
+ For seller-AI: offer < batna_seller is capitulation.
108
+ For buyer-AI: offer > batna_buyer is capitulation (paying too much).
 
109
 
110
  Returns -OMEGA (= -200) for capitulation, 0 otherwise.
111
+ Parse errors are logged and treated as 0 (no false penalty for malformed output).
112
 
113
  Args:
114
  completions: List of G=8 model outputs.
115
+ **kwargs: Must contain batna_seller (float).
116
+ Optionally batna_buyer (float) and scenario_id (str).
117
 
118
  Returns:
119
  List of float rewards: -OMEGA or 0.
120
  """
121
+ batna_seller = float(kwargs.get("batna_seller", 0.0))
122
+ batna_buyer = float(kwargs.get("batna_buyer", float("inf")))
123
+ scenario_id = str(kwargs.get("scenario_id", ""))
124
+ is_buyer_ai = scenario_id in _BUYER_AI_SCENARIOS
125
+
126
  rewards = []
127
  for completion in completions:
128
  try:
129
  data = json.loads(_clean_json(completion))
130
  offer = float(data.get("offer_amount") or float("inf"))
131
+ if is_buyer_ai:
132
+ capitulated = offer > batna_buyer
133
+ else:
134
+ capitulated = offer < batna_seller
135
+ if capitulated:
136
  rewards.append(-float(OMEGA))
137
+ logger.debug(
138
+ f"Capitulation: offer={offer} {'>' if is_buyer_ai else '<'} "
139
+ f"batna={'buyer=' + str(batna_buyer) if is_buyer_ai else 'seller=' + str(batna_seller)}"
140
+ )
141
  else:
142
  rewards.append(0.0)
143
+ except (json.JSONDecodeError, ValueError) as exc:
144
+ logger.warning(f"anti_capitulation_reward parse error (treated as 0): {exc}")
145
  rewards.append(0.0)
146
  return rewards
147