sh4shv4t committed
Commit df724f2 · Parent: cf5b410

Add pre-training audit scripts, OpenEnv manifest, and tune Parlay training/env (GRPO 1.5B default, min-reward filters, weighted data gen, hiring ZOPA+drift, veteran/opponent prompts, Docker/docs)

.dockerignore ADDED
@@ -0,0 +1,13 @@
+ .git
+ __pycache__/
+ parlay_env/__pycache__/
+ *.pyc
+ venv/
+ venv-train/
+ node_modules/
+ data/*.jsonl
+ checkpoints/
+
+ # Keep Python source, drop generated artifacts under parlay_env/
+ parlay_env/
+ !parlay_env/**/*.py
.env.example CHANGED
@@ -1,26 +1,3 @@
- # API Keys — Required
- GOOGLE_API_KEY=AIza...
- HF_TOKEN=hf_...
-
- # Server ports
- ENV_PORT=8001
- DASHBOARD_PORT=8000
- MCP_SSE_PORT=8002
-
- # Game config
- MAX_TURNS_PER_EPISODE=20
- MIN_REWARD_THRESHOLD=-100
- TOP_PLAYER_THRESHOLD=0.30
- CREDIBILITY_POINTS_START=100
- CREDIBILITY_REGEN_PER_TURN=5
-
- # Training
- BASE_MODEL=Qwen/Qwen2.5-7B-Instruct
- GRPO_GENERATIONS=8
- GRPO_STEPS=500
- DATA_PATH=data/episodes.jsonl
- SFT_OUTPUT=models/parlay-sft
- GRPO_OUTPUT=models/parlay-grpo
-
- # HF Hub
- HF_REPO_ID=your-username/parlay-negotiator

+ GEMINI_API_KEY=your_gemini_key_here
+ HF_TOKEN=your_huggingface_token_here
+ BASE_MODEL=checkpoints/sft_1.5b/
Dockerfile CHANGED
@@ -1,6 +1,10 @@
  FROM python:3.11-slim

  WORKDIR /app

  # Install deps first (layer-cached)
  COPY requirements.txt .
@@ -11,10 +15,7 @@ COPY . .
  # Initialise the database at build time
  RUN python -m scripts.init_db

- # startup script
- RUN chmod +x scripts/start.sh
-
  # HF Spaces exposes port 7860
  EXPOSE 7860

- CMD ["bash", "scripts/start.sh"]

  FROM python:3.11-slim

  WORKDIR /app
+ ENV PYTHONUNBUFFERED=1
+ ARG GEMINI_API_KEY=""
+ ENV GEMINI_API_KEY=${GEMINI_API_KEY}
+ ENV GOOGLE_API_KEY=${GEMINI_API_KEY}

  # Install deps first (layer-cached)
  COPY requirements.txt .

  # Initialise the database at build time
  RUN python -m scripts.init_db

  # HF Spaces exposes port 7860
  EXPOSE 7860

+ CMD ["bash", "-lc", "if [ -z \"$GOOGLE_API_KEY\" ] && [ -n \"$GEMINI_API_KEY\" ]; then export GOOGLE_API_KEY=\"$GEMINI_API_KEY\"; fi; uvicorn main:app --host 0.0.0.0 --port 7860"]
Makefile CHANGED
@@ -34,7 +34,8 @@ test:
  venv\Scripts\pytest tests/ -v

  train-data:
- venv-train\Scripts\python -m training.generate_data --episodes 2000 --output data/episodes.jsonl

  train-sft:
  venv-train\Scripts\python -m training.sft_train --model Qwen/Qwen2.5-7B-Instruct --data data/episodes.jsonl --output models/parlay-sft --threshold 0.30

  venv\Scripts\pytest tests/ -v

  train-data:
+ # hackathon budget default; override with EPISODES=N
+ venv-train\Scripts\python -m training.generate_data --episodes 80 --output data/episodes.jsonl

  train-sft:
  venv-train\Scripts\python -m training.sft_train --model Qwen/Qwen2.5-7B-Instruct --data data/episodes.jsonl --output models/parlay-sft --threshold 0.30
README.md CHANGED
@@ -11,670 +11,158 @@ pinned: false

  > **The arena where AIs learn to close.**

- `Python 3.11` | `FastAPI` | `Gemini 2.0 Flash` | `GRPO` | `OpenEnv`
-
- ---

  ## Overview

- Parlay is a high-fidelity **reinforcement learning negotiation environment** that ships three things at once:
-
- | Audience | What they get |
- |---|---|
- | **Hackathon Judges** | A fully playable browser game, an OpenEnv-compliant WebSocket server, an MCP integration layer, and a complete GRPO training pipeline — all in one repo |
- | **Players** | A real-time negotiation game with five scenarios, five AI personas (Gemini-powered), Theory of Mind tracking, tactical cards, drift events, and a global leaderboard |
- | **B2B / Researchers** | A clean OpenEnv protocol implementation for training negotiation agents; plug in your own model, collect episodes, run GRPO fine-tuning, push to HF Hub |
-
- Parlay is built on:
- - **Google Gemini 2.0 Flash** — the AI counterpart, generating persona-consistent responses in real time
- - **FastAPI + aiosqlite** — async backend, zero ORM overhead, SQLite for portability
- - **OpenEnv protocol** — standard `reset/step/state` WebSocket commands for agent interoperability
- - **FastMCP** — universal MCP server supporting both `stdio` and SSE transports
- - **HF TRL GRPOTrainer** — two-stage SFT → GRPO pipeline fine-tuning Qwen2.5-7B-Instruct
- - **Vanilla JS + Three.js r128** — zero npm, zero build step, runs in any browser
-
- ---
-
- ## Quick Start
-
- ### Prerequisites
-
- - Python **3.11** (required; training and game stacks expect 3.11)
- - A Google AI Studio API key ([get one free](https://aistudio.google.com/app/apikey))
- - (Optional) A Hugging Face token for training and model pushing
-
- ### Windows (recommended): PowerShell from repo root
-
- All `scripts\*.ps1` files assume the **current directory is the project root**. If execution policy blocks scripts, run:

- `Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser`

- ```powershell
- git clone https://github.com/your-username/parlay.git
- cd parlay
- .\scripts\setup.ps1
- # Edit .env and set GOOGLE_API_KEY=your_key_here
- .\scripts\run.ps1
- ```
-
- Optional: `.\venv\Scripts\python -m scripts.seed_scenarios` for demo leaderboard rows. Training venv: `.\scripts\setup_train.ps1`, then `.\scripts\train_data.ps1` (and related `train_sft` / `train_grpo` / `evaluate` scripts).
-
- You can also use **GNU Make** on Windows (e.g. Git Bash or Chocolatey): `make setup`, `make run`, `make test`, `make clean`. The Makefile uses `venv\Scripts\` paths.

- ### Cross-platform: Docker

- Use Docker when you want the same flow on every OS (no local venv):
-
- ```bash
- cp .env.example .env
- # Set GOOGLE_API_KEY in .env
- docker compose up --build
- ```
-
- See the [Docker](#docker) section for service URLs and training images.
-
- ### Open in browser (local server)
-
- ```
- http://localhost:8000        # Game dashboard
- http://localhost:8000/train  # Training dashboard
- http://localhost:8000/docs   # Interactive API docs (Swagger)
- ```

- ### macOS / Linux (manual venv)

  ```bash
- git clone https://github.com/your-username/parlay.git
- cd parlay
- python3.11 -m venv venv
- source venv/bin/activate
  pip install -r requirements.txt
- cp .env.example .env
- python -m scripts.init_db
- uvicorn main:app --host 0.0.0.0 --port 8000 --reload
  ```

- ---
-
- ## API Keys
-
- | Key | Required | Where to get it | Used for |
- |---|---|---|---|
- | `GOOGLE_API_KEY` | **Yes** | [Google AI Studio](https://aistudio.google.com/app/apikey) | Gemini 2.0 Flash (game AI + data gen) |
- | `HF_TOKEN` | Only for training | [Hugging Face Settings](https://huggingface.co/settings/tokens) | Model push to HF Hub |
-
- Set these in your `.env` file (never commit `.env`):
-
- ```bash
- GOOGLE_API_KEY=AIzaSy...
- HF_TOKEN=hf_...
- ```
-
- ---
-
- ## Project Structure
-
- ```
- parlay/
- ├── main.py                  # FastAPI app entry point
-
- ├── parlay_env/              # Core RL environment (OpenEnv-compliant)
- │   ├── __init__.py
- │   ├── server.py            # WebSocket router (reset/step/state endpoints)
- │   ├── env.py               # ParlayEnv class — OpenEnv implementation
- │   ├── models.py            # Pydantic models: ParlayState, ParlayAction, BeliefState…
- │   ├── reward.py            # Reward coefficients (ALPHA, BETA, GAMMA, OMEGA…)
- │   ├── grader.py            # Pure reward functions: step reward, terminal reward
- │   ├── game_theory.py       # ZOPA, Nash, Pareto, Shapley, Rubinstein computations
- │   └── exceptions.py        # Custom exceptions (InvalidScenarioError, CapitulationError…)
-
- ├── game/                    # Game logic layer
- │   ├── __init__.py
- │   ├── scenarios.py         # 5 negotiation scenarios with drift events
- │   ├── personas.py          # 5 AI personas with Gemini prompt templates
- │   └── session.py           # Session management: active games, turn routing
-
- ├── agent/                   # AI agent components
- │   ├── __init__.py
- │   ├── gemini_client.py     # Gemini 2.0 Flash async wrapper
- │   ├── tom_tracker.py       # Theory of Mind belief tracker
- │   └── tactical.py          # Tactical card execution logic
-
- ├── dashboard/               # Frontend (zero npm, zero build)
- │   ├── index.html           # Main game UI
- │   ├── train.html           # Training monitor UI
- │   ├── api.py               # FastAPI router for dashboard REST endpoints
- │   └── static/
- │       ├── app.js           # Game WebSocket client + UI logic
- │       ├── character.js     # Three.js r128 animated persona character
- │       ├── chart_utils.js   # Chart.js reward visualization helpers
- │       └── style.css        # CSS with --parlay-* custom properties
-
- ├── mcp_server/              # MCP integration (stdio + SSE)
- │   ├── __init__.py
- │   └── server.py            # FastMCP tools: negotiate, get_state, get_leaderboard…
-
- ├── training/                # Isolated training pipeline (never imported by game)
- │   ├── __init__.py
- │   ├── generate_data.py     # Gemini self-play episode generation
- │   ├── sft_train.py         # SFTTrainer fine-tuning on top episodes
- │   ├── grpo_train.py        # GRPOTrainer RL fine-tuning
- │   ├── reward_fn.py         # GRPO reward functions (wraps grader.py)
- │   ├── evaluate.py          # Three-bar comparison chart: base vs SFT vs GRPO
- │   └── push_to_hub.py       # Upload model to HF Hub
-
- ├── scripts/
- │   ├── __init__.py
- │   ├── init_db.py           # Create SQLite schema (idempotent)
- │   ├── seed_scenarios.py    # Insert demo leaderboard entries
- │   ├── setup.ps1            # Windows: game venv + .env + DB init
- │   ├── setup_train.ps1      # Windows: training venv (PyTorch, TRL, …)
- │   ├── run.ps1 / run_env.ps1 / run_mcp.ps1
- │   ├── train_*.ps1 / evaluate.ps1 / test.ps1
-
- ├── tests/
- │   ├── __init__.py
- │   ├── test_grader.py       # Reward computation tests
- │   ├── test_game_theory.py  # ZOPA/Nash/Pareto/Shapley tests
- │   ├── test_tom.py          # Theory of Mind tracker tests
- │   ├── test_reward.py       # Reward constants tests
- │   └── test_scenarios.py    # Scenario definition tests
-
- ├── data/                    # Generated episode JSONL files (gitignored)
- ├── models/                  # Fine-tuned model checkpoints (gitignored)
- ├── results/                 # Evaluation charts and metrics (gitignored)
-
- ├── requirements.txt         # Core dependencies
- ├── requirements-train.txt   # Training-only dependencies (torch, trl, peft…)
- ├── Makefile                 # Windows-oriented targets (venv\\Scripts\\ paths)
- ├── .gitattributes           # LF for source; CRLF for .ps1
- ├── .env.example             # Environment variable template
- ├── .gitignore
- ├── docker-compose.yml       # Multi-service Docker deployment
- ├── Dockerfile.game          # Game + dashboard service
- ├── Dockerfile.env           # OpenEnv WebSocket service
- └── Dockerfile.train         # GRPO training service (CUDA)
- ```
-
- ---
-
- ## Game Guide
-
- ### How to Play

- 1. **Choose a scenario** — five high-stakes deal types, each with unique ZOPA ranges and drift events
- 2. **Choose your persona style** — affects how aggressively the AI counterpart responds
- 3. **Negotiate in natural language** — type your offers and arguments in the chat
- 4. **Use tactical cards** — spend Credibility Points to play power moves (anchor, BATNA reveal, deadline pressure)
- 5. **Watch for drift events** — the AI's hidden priorities shift mid-negotiation; adapt or lose ground
- 6. **Close within 20 turns** — speed bonuses reward efficient closers

- ### Key Concepts

- | Term | Definition |
- |---|---|
- | **ZOPA** | Zone of Possible Agreement — the range between both parties' walk-away prices where a deal is mutually beneficial |
- | **BATNA** | Best Alternative to a Negotiated Agreement — your outside option; the floor below which you'd rather walk away |
- | **Nash Bargaining Solution** | The game-theoretically "fair" split of the ZOPA surplus — the midpoint of both BATNAs |
- | **Anchor** | Your opening offer. The higher you anchor (as seller), the more the counterpart adjusts from that reference point |
- | **Rubinstein Deadline** | The advantage of having more time — patient negotiators extract better deals |
- | **Capitulation Cliff** | Accepting below your BATNA triggers a hard -150 penalty (OMEGA). Never capitulate |
- | **Theory of Mind** | Parlay tracks the AI's inferred beliefs about you — high ToM accuracy gives a step reward bonus |
- | **Drift Event** | A mid-game shock (budget cut, competitor offer, urgency spike) that changes the AI's hidden priorities |

- ### Tactical Cards

- | Card | CP Cost | Effect |
- |---|---|---|
- | **Anchor High** | 10 CP | Lock in a high reference price — reduces AI's willingness to counter aggressively |
- | **BATNA Reveal** | 15 CP | Signal your outside option — increases AI urgency if credible |
- | **Deadline Pressure** | 20 CP | Introduce artificial urgency — accelerates AI concessions by 15% |
- | **Bundle Offer** | 12 CP | Add non-monetary value — expands the ZOPA by shifting AI utility |
- | **Silent Close** | 25 CP | Make a final offer with no further negotiation signal — high risk, high reward |
- | **Coalition Play** | 30 CP | Invoke Act 3 coalition mechanics — brings in a third party for multi-issue negotiation |

- ### Scoring

- Your final score is computed by the Parlay Grader:

- ```
- Final Score = Terminal Reward + Cumulative Step Rewards
- ```
-
- Deal Efficiency is displayed as a percentage:
- ```
- Deal Efficiency = (Final Price - Seller BATNA) / (Buyer BATNA - Seller BATNA)
- ```
-
- A deal efficiency of 1.0 means you captured the full ZOPA surplus. 0.5 is the Nash fair split.
-
- ---
-
- ## OpenEnv Protocol
-
- Parlay implements the OpenEnv standard for RL environments over WebSocket.
-
- ### Connection
-
- ```
- ws://localhost:8000/env/ws/{session_id}
- ```
-
- ### Commands
-
- #### `reset` — Start a new episode
-
- ```json
- {
-   "command": "reset",
-   "scenario_id": "saas_enterprise",
-   "persona": "shark",
-   "player_name": "MyAgent"
- }
- ```
-
- Response:
- ```json
- {
-   "type": "observation",
-   "session_id": "abc-123",
-   "scenario_id": "saas_enterprise",
-   "persona": "shark",
-   "act": 1,
-   "step_count": 0,
-   "belief": {
-     "est_budget": 140000,
-     "est_walk_away": 125000,
-     "est_urgency": 0.4,
-     "est_has_alternative": false,
-     "confidence": 0.3
-   },
-   "credibility_points": 100,
-   "offer_history": [],
-   "episode_done": false
- }
- ```
-
- #### `step` — Take a negotiation turn
-
- ```json
- {
-   "command": "step",
-   "utterance": "I propose an annual contract at $155,000 with a 90-day payment term.",
-   "offer_amount": 155000,
-   "tactical_move": null
- }
- ```
-
- Response:
- ```json
- {
-   "type": "step_result",
-   "step_reward": 12.4,
-   "cumulative_reward": 12.4,
-   "ai_response": "That's ambitious. Our budget ceiling won't stretch that far...",
-   "belief": { "...updated belief state..." },
-   "tension_score": 0.6,
-   "drift_fired": false,
-   "episode_done": false
- }
- ```
-
- #### `state` — Query current state without acting
-
- ```json
- { "command": "state" }
- ```
-
- #### `close` — Accept final price and end episode
-
- ```json
- {
-   "command": "close",
-   "final_price": 148000
- }
- ```
-
- Response:
- ```json
- {
-   "type": "episode_done",
-   "total_reward": 287.4,
-   "deal_efficiency": 0.82,
-   "acts_completed": 2,
-   "bluffs_caught": 1,
-   "drift_adapted": true,
-   "deal_closed": true,
-   "leaderboard_rank": 3
- }
- ```
-
- ### HTTP REST Endpoints
-
- | Method | Path | Description |
- |---|---|---|
- | `GET` | `/health` | Health check |
- | `GET` | `/env/scenarios` | List all scenarios |
- | `GET` | `/env/personas` | List all personas |
- | `GET` | `/dashboard/leaderboard` | Global leaderboard |
- | `GET` | `/dashboard/leaderboard/{scenario_id}` | Per-scenario leaderboard |
- | `POST` | `/dashboard/submit` | Submit episode result |
- | `GET` | `/docs` | Swagger UI |
-
- ---
-
- ## MCP Setup
-
- Parlay ships a universal MCP server supporting **both** `stdio` and SSE transports.
-
- ### Available MCP Tools
-
- | Tool | Description |
- |---|---|
- | `negotiate` | Send a negotiation message and get the AI's response |
- | `get_state` | Retrieve current session state and belief model |
- | `reset_session` | Start a new negotiation session |
- | `close_deal` | Accept a final price and get episode grade |
- | `get_leaderboard` | Fetch top performers globally or by scenario |
- | `list_scenarios` | Get all available scenarios with ZOPA ranges |
- | `list_personas` | Get all personas with strategy profiles |
- | `get_game_theory` | Compute ZOPA, Nash point, Rubinstein advantage for any deal |
-
- ### Client 1: Claude Desktop / Claude Code
-
- Add to your `claude_desktop_config.json` (usually at `~/Library/Application Support/Claude/claude_desktop_config.json` on macOS):
-
- ```json
- {
-   "mcpServers": {
-     "parlay": {
-       "command": "python",
-       "args": ["-m", "mcp_server.server", "stdio"],
-       "cwd": "/path/to/parlay",
-       "env": {
-         "GOOGLE_API_KEY": "your_key_here"
-       }
-     }
-   }
- }
- ```
-
- ### Client 2: Continue.dev / Zed / Any SSE Client
-
- First start the SSE server:
-
- ```bash
- python -m mcp_server.server sse
- # Listening on http://localhost:8002/sse
- ```
-
- Then configure your client to point at:
-
- ```
- http://localhost:8002/sse
- ```
-
- In `Continue.dev` (`~/.continue/config.json`):
-
- ```json
- {
-   "experimental": {
-     "modelContextProtocolServers": [
-       {
-         "transport": {
-           "type": "sse",
-           "url": "http://localhost:8002/sse"
-         }
-       }
-     ]
-   }
- }
- ```
-
- ### Client 3: Generic stdio (any MCP-compatible agent)
-
- ```bash
- python -m mcp_server.server stdio
- ```
-
- Pipe JSON-RPC messages to stdin; responses arrive on stdout. Compatible with any MCP client library.
-
- ---

  ## Training Pipeline

- Parlay uses a two-stage pipeline: **SFT warmup → GRPO fine-tuning**. Never skip the SFT stage — GRPO reward curves are noisy without a warm-started model.
-
- ### Stage 1: Generate Self-Play Episodes
-
- Uses Gemini 2.0 Flash to simulate full negotiation episodes across all persona × scenario combinations.
-
- ```bash
- python -m training.generate_data --episodes 2000 --output data/episodes.jsonl
- ```
-
- Diversity guarantees enforced:
- - Minimum 20 episodes per (persona × scenario) pair = 500 baseline
- - 30% noise injection for exploration
- - 40% forced drift event rate
- - 25% Act 3 coalition scenarios
-
- Each episode record:
- ```json
- {
-   "prompt": "You are negotiating a SaaS enterprise deal...",
-   "conversation": [...],
-   "reward": 247.3,
-   "deal_efficiency": 0.79,
-   "persona": "shark",
-   "scenario_id": "saas_enterprise",
-   "acts_completed": 2,
-   "tom_accuracy": 0.81,
-   "drift_adapted": true,
-   "split": "train"
- }
  ```

- ### Stage 2: SFT Fine-Tuning
-
- Train Qwen2.5-7B-Instruct on the top 60% of episodes by reward:

  ```bash
- python -m training.sft_train \
-   --model Qwen/Qwen2.5-7B-Instruct \
-   --data data/episodes.jsonl \
-   --output models/parlay-sft \
-   --threshold 0.60
  ```

- Uses LoRA (r=16, alpha=32) on `q_proj` and `v_proj`. Full fine-tuning is never used.
-
- ### Stage 3: GRPO Fine-Tuning
-
- Apply Group Relative Policy Optimization with G=8 generations per prompt:

  ```bash
- python -m training.grpo_train \
-   --model models/parlay-sft \
-   --data data/episodes.jsonl \
-   --output models/parlay-grpo \
-   --steps 500
  ```

- GRPO hyperparameters:
- - `num_generations=8` (G=8 per prompt)
- - `beta=0.001` (low KL coefficient — allows exploration)
- - `epsilon=0.2` (clipping range)
- - `scale_rewards="batch"` (batch-level reward standardization)
- - `learning_rate=5e-7`
-
- ### Stage 4: Evaluate

  ```bash
- python -m training.evaluate \
-   --base Qwen/Qwen2.5-7B-Instruct \
-   --sft models/parlay-sft \
-   --grpo models/parlay-grpo \
-   --data data/episodes.jsonl \
-   --output results/eval_results.json
  ```

- Produces a three-bar comparison chart: **Base vs SFT vs GRPO** across mean reward, deal efficiency, and bluff detection rate.

- ### Stage 5: Push to Hub

- ```bash
- python -m training.push_to_hub \
-   --model models/parlay-grpo \
-   --repo your-username/parlay-negotiator
- ```
-
- Requires `HF_TOKEN` and `HF_REPO_ID` in `.env`.
-
- ---

- ## Personas

- Five AI negotiation personas, each powered by a distinct Gemini 2.0 Flash system prompt:

- | Persona | Aggression | Patience | Bluff Rate | Strategy |
- |---|---|---|---|---|
- | **Shark** | 0.90 | 0.20 | 0.45 | Opens high, concedes slowly, uses deadline pressure, willing to walk away |
- | **Diplomat** | 0.30 | 0.80 | 0.10 | Relationship-focused, seeks mutual gain, rarely bluffs, prefers bundle deals |
- | **Analyst** | 0.50 | 0.70 | 0.15 | Data-driven, requests justification for every number, ZOPA-aware, systematic |
- | **Veteran** | 0.65 | 0.85 | 0.30 | Pattern-recognizes anchors, absorbs pressure, uses silence as a tool |
- | **Wildcard** | 0.75 | 0.35 | 0.55 | Unpredictable, drift-prone, high bluff rate, can pivot strategy mid-negotiation |

- Persona drift events can cause a **Wildcard** to briefly adopt **Shark** tactics, or a **Diplomat** to reveal an unexpected BATNA. Adapt or get caught off guard.

- ---

- ## Scenarios

- Five negotiation scenarios spanning B2B deal archetypes:

- | Scenario ID | Title | ZOPA Range | Complexity | Drift Events |
- |---|---|---|---|---|
- | `saas_enterprise` | Enterprise SaaS Annual License | $125K – $165K | Medium | Budget cut at turn 7 |
- | `consulting_retainer` | Consulting Retainer Contract | $8K – $15K/mo | Medium | Competitor reveal at turn 5 |
- | `hiring_package` | Senior Engineering Hire Package | $180K – $240K | Low | Competing offer at turn 6 |
- | `vendor_hardware` | Hardware Vendor Bulk Purchase | $2.1M – $3.4M | High | Supply chain shock at turn 8 |
- | `acquisition_term_sheet` | Startup Acquisition Term Sheet | $8.5M – $16M | Very High | Board veto threat at turn 10, valuation dispute at turn 14 |

- Each scenario defines:
- - `batna_buyer`: Buyer's walk-away ceiling
- - `batna_seller`: Seller's walk-away floor
- - `anchor_buyer`: Typical buyer opening offer
- - `anchor_seller`: Typical seller opening ask
- - `drift_events`: List of mid-game shocks with trigger turns and effects
- - `currency`: Always USD
- - `difficulty`: `low | medium | high | very_high`

- ---

- ## Reward Function
-
- The Parlay grader computes rewards in two phases:
-
- ### Step Reward (per turn)
-
- ```
- r_step = α · ΔZOPA_position
-        + β · ToM_accuracy_improvement
-        - δ · concession_magnitude
-        - θ · noise_penalty
-        + ε · tactical_card_bonus
  ```

- Where:
- - **α (ALPHA = 2.0)** — reward for improving your ZOPA position
- - **β (BETA = 5.0)** — reward for improving ToM belief accuracy
- - **δ (DELTA = 1.5)** — penalty per unit of concession from previous offer
- - **θ (THETA = 3.0)** — penalty for low-grounding utterances (noise)
- - **ε (EPSILON = 8.0)** — bonus for successful tactical card execution

- ### Terminal Reward (episode end)
-
- ```
- r_terminal =
-   if final_price < batna_seller: -Ω   (capitulation cliff: -150)
-   elif deal_closed:
-     Γ                                 (base close bonus: +100)
-     + ζ · deal_efficiency             (ZOPA capture: up to +50)
-     + η · acts_completed              (multi-act bonus: +10/act)
-     + Γ · (1 - t_close/t_max)         (speed bonus: up to +100)
-     + ETA · drift_adapted             (drift adaptation: +10)
-   else (no deal):
-     -Γ/2 + β · avg_tom_accuracy       (partial credit)
  ```

- Where:
- - **Γ (GAMMA = 100.0)** — primary close bonus
- - **ζ (ZETA = 50.0)** — ZOPA efficiency multiplier
- - **η (ETA = 10.0)** — per-act completion bonus (max 3 acts = +30)
- - **Ω (OMEGA = 150.0)** — capitulation cliff penalty
- - **t_close** — turn at which deal was closed
- - **t_max** — maximum turns (default: 20)
-
- All coefficients live exclusively in `parlay_env/reward.py`. Never hardcode them elsewhere.
-
- ---
-
- ## Docker
-
- ### Run all services

  ```bash
- cp .env.example .env
- # Set GOOGLE_API_KEY in .env
-
- docker compose up --build
  ```

- Services:
- - `game` → `http://localhost:8000` — game dashboard + API
- - `env` → `http://localhost:8001` — OpenEnv WebSocket server
- - `mcp` → `http://localhost:8002` — MCP SSE server
-
- ### Run training (requires GPU)

  ```bash
- docker build -f Dockerfile.train -t parlay-train .
- docker run --gpus all -v $(pwd)/data:/app/data -v $(pwd)/models:/app/models \
-   -e GOOGLE_API_KEY=$GOOGLE_API_KEY \
-   -e HF_TOKEN=$HF_TOKEN \
-   parlay-train python -m training.grpo_train --steps 500
  ```

- ### Individual services

  ```bash
- # Game only
- docker build -f Dockerfile.game -t parlay-game .
- docker run -p 8000:8000 -e GOOGLE_API_KEY=$GOOGLE_API_KEY parlay-game
-
- # OpenEnv only
- docker build -f Dockerfile.env -t parlay-env .
- docker run -p 8001:8001 -e GOOGLE_API_KEY=$GOOGLE_API_KEY parlay-env
  ```

- ---
-
  ## Testing

- ### Run the full test suite

  ```bash
  pytest tests/ -v
  ```

- ### Run with coverage
-
- ```bash
- pytest tests/ -v --tb=short --cov=parlay_env --cov=game --cov=agent --cov-report=term-missing
- ```
-
- ### Run a specific test module

  ```bash
  pytest tests/test_grader.py -v
@@ -684,122 +172,28 @@ pytest tests/test_reward.py -v
  pytest tests/test_scenarios.py -v
  ```

- ### Test descriptions
-
- | File | What it tests |
- |---|---|
- | `test_grader.py` | Step reward, terminal reward, episode grade computation |
- | `test_game_theory.py` | ZOPA, Nash bargaining, Pareto frontier, Shapley value, anchoring, Rubinstein |
- | `test_tom.py` | Theory of Mind tracker: belief updates, bluff detection, drift events, accuracy |
- | `test_reward.py` | Reward coefficient constants and their mathematical constraints |
- | `test_scenarios.py` | Scenario definitions: ZOPA validity, drift events, currency, IDs |
-
- All tests follow the pattern: `Test{Module}` class → `test_{scenario}` methods → `assert ... f"Expected {expected}, got {result}"`.
-
- ---
-
- ## Architecture Decisions
-
- ### Why SQLite over Postgres?
-
- Parlay is designed to be a **zero-infrastructure hackathon demo**. SQLite with `aiosqlite` provides full async support, requires no Docker service for the database, and the `parlay.db` file can be committed for demo snapshots. Migrating to Postgres requires only changing the connection string.
-
- ### Why Vanilla JS over React/Vue?
-
- The `.cursorrules` mandate: zero npm, zero build step. Three.js r128 from cdnjs gives us 3D animated personas. Chart.js 4.4 gives us reward curves. `fetch()` + `WebSocket` gives us real-time game state. The entire frontend loads from a single HTML file with `<script>` tags. This means anyone can open the dashboard without `node_modules`.
-
- ### Why GRPO over PPO?

- GRPO (Group Relative Policy Optimization) eliminates the need for a separate critic/value model. With G=8 generations per prompt, GRPO uses within-group reward standardization as its baseline — simpler, more stable, and better suited to the sparse reward structure of negotiation episodes.

- ### Why Gemini 2.0 Flash?

- - Free tier available via Google AI Studio (critical for hackathon accessibility)
- - Sub-500ms latency for negotiation turns with `max_output_tokens=500`
- - Strong instruction-following for persona-consistent responses
- - Async-compatible via `run_in_executor` pattern
-
- ---
-
- ## Environment Variables Reference
-
- | Variable | Default | Description |
- |---|---|---|
- | `GOOGLE_API_KEY` | — | **Required.** Google AI Studio API key |
- | `HF_TOKEN` | — | Hugging Face token (training only) |
- | `ENV_PORT` | `8001` | OpenEnv WebSocket server port |
- | `DASHBOARD_PORT` | `8000` | Dashboard + game server port |
- | `MCP_SSE_PORT` | `8002` | MCP SSE server port |
- | `MAX_TURNS_PER_EPISODE` | `20` | Maximum turns before episode ends |
- | `MIN_REWARD_THRESHOLD` | `-100` | Minimum reward for SFT data inclusion |
- | `TOP_PLAYER_THRESHOLD` | `0.60` | Percentile cutoff for SFT training data |
- | `CREDIBILITY_POINTS_START` | `100` | Starting CP for tactical cards |
- | `CREDIBILITY_REGEN_PER_TURN` | `5` | CP regenerated each turn |
- | `BASE_MODEL` | `Qwen/Qwen2.5-7B-Instruct` | HF model ID for training base |
- | `GRPO_GENERATIONS` | `8` | G value for GRPO (generations per prompt) |
- | `GRPO_STEPS` | `500` | GRPO training steps |
- | `DATA_PATH` | `data/episodes.jsonl` | Episode data for training |
- | `SFT_OUTPUT` | `models/parlay-sft` | SFT checkpoint output path |
- | `GRPO_OUTPUT` | `models/parlay-grpo` | GRPO checkpoint output path |
- | `HF_REPO_ID` | — | HF Hub repo for model push |
-
- ---
-
- ## Testing Without API Keys
-
- Everything in Parlay runs in **mock mode** when `GOOGLE_API_KEY` is not set.
- Mock mode returns canned persona-consistent responses so you can play and test
- the full game loop without any external account.

  ```bash
- # 1. Set up the game environment
- make setup
-
- # 2. Run the keyless test suite (zero API calls)
- make test-keyless
-
- # 3. Start the server in mock mode
- make run
-
- # 4. Open the game in your browser
- #    → http://localhost:8000
- #    A "Demo mode" banner confirms mock mode is active.
  ```

- To switch to real AI: add `GOOGLE_API_KEY=your_key` to `.env` and restart.
-
- To test the exact HF Spaces container locally before pushing:

  ```bash
- make docker-test
- # → http://localhost:7860
  ```

- ---
-
- ## Contributing
-
- 1. Fork the repo and create a feature branch
- 2. Follow the module dependency graph: `training/ → parlay_env/ → game/ → agent/`
- 3. Add type hints and docstrings to all public functions
- 4. Write at least 2 tests per new function (happy path + edge case)
- 5. Run `pytest tests/` — all tests must pass
- 6. Verify `docker compose up --build` completes without errors
-
- ---
-
  ## License

- MIT License
-
- Copyright (c) 2026 Parlay Contributors
-
- Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
- ---
-
- *Built for the Meta Hackathon 2026. Powered by Gemini 2.0 Flash + Qwen2.5-7B.*


  > **The arena where AIs learn to close.**

+ `Python 3.11` | `FastAPI` | `Gemini 2.5 Flash` | `GRPO` | `OpenEnv-style WS`

  ## Overview

+ Parlay is a negotiation RL environment, a playable browser game, and a training stack in one repo:

+ - Three negotiation scenarios and three personas.
+ - OpenEnv-style WebSocket interface (`reset` / `step` / `state`) on `/env/ws`.
+ - Theory-of-Mind belief tracking with dense reward shaping.
+ - Dynamic ZOPA erosion under sustained tension.
+ - Training pipeline from Gemini self-play data to SFT and GRPO.

+ Gemini model routing:

+ - `gemini-2.5-flash-lite` for data generation and self-play.
+ - `gemini-2.5-flash` for demo gameplay and MCP tools.

+ ## Quickstart

+ Run the following, in order:

  ```bash
  pip install -r requirements.txt
+ export GEMINI_API_KEY=your_key
+ uvicorn main:app --port 8000
+ open http://localhost:8000
  ```

+ ## Reward Design

+ Per-step reward:

+ `R_t = α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV`

+ Terminal reward:

+ `R_T = γ·E + ε·S + ζ·D`

+ Capitulation floor:

+ `R_T = -ω` when the final deal breaches the BATNA.

+ Constants (from `parlay_env/reward.py`):

+ - `ALPHA=2`, `BETA=5`, `DELTA=3`, `THETA=10`
+ - `PSI=12`, `MU=8`
+ - `GAMMA=100`, `EPSILON=20`, `ZETA=15`, `OMEGA=200`
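
For concreteness, a minimal sketch of both formulas using the constants above (plain Python; the variable and function names are hypothetical — the canonical implementation is `compute_step_reward` and the terminal grader in `parlay_env/grader.py`):

```python
# Minimal sketch of the reward shapes above; illustrative only.
# Constants mirror parlay_env/reward.py; function names are hypothetical.
ALPHA, BETA, DELTA, THETA, PSI, MU = 2.0, 5.0, 3.0, 10.0, 12.0, 8.0
GAMMA, EPSILON, ZETA, OMEGA = 100.0, 20.0, 15.0, 200.0

def step_reward(dv: float, tom: float, concession: float, noise: float,
                bluff_caught: bool, drift_recognized: bool) -> float:
    """R_t = α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV"""
    return (ALPHA * dv + BETA * tom - DELTA * concession - THETA * noise
            + (PSI if bluff_caught else 0.0)
            + (MU if drift_recognized else 0.0))

def terminal_reward(efficiency: float, speed: float, drift: float,
                    capitulated: bool) -> float:
    """R_T = γ·E + ε·S + ζ·D, or -ω when the deal breaches BATNA."""
    if capitulated:
        return -OMEGA
    return GAMMA * efficiency + EPSILON * speed + ZETA * drift
```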

  ## Training Pipeline

+ ```text
+ Gemini self-play (training/generate_data.py)
+         |
+         v
+ SFT warm start (training/sft_train.py)
+         |
+         v
+ GRPO fine-tune (training/grpo_train.py)
+         |
+         v
+ Evaluation + comparison (training/evaluate.py, scripts/eval_comparison.py)
  ```

+ ### Data generation

  ```bash
+ python -m training.generate_data --episodes 80 --output data/episodes.jsonl
  ```
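
Episodes land in `data/episodes.jsonl`, one JSON object per line. A rough sketch of the kind of min-reward / top-fraction filtering the SFT stage applies (the `-100` floor and `0.30` cutoff come from the previous `.env.example` and the Makefile; the real logic lives in `training/sft_train.py`):

```python
# Sketch of episode filtering for SFT; field names follow the episode
# records documented in the previous README ("reward", "split", ...).
import json

MIN_REWARD = -100.0   # MIN_REWARD_THRESHOLD default
TOP_FRACTION = 0.30   # --threshold in the Makefile's train-sft target

def load_top_episodes(path: str = "data/episodes.jsonl") -> list[dict]:
    with open(path, encoding="utf-8") as fh:
        episodes = [json.loads(line) for line in fh if line.strip()]
    episodes = [e for e in episodes if e["reward"] >= MIN_REWARD]
    episodes.sort(key=lambda e: e["reward"], reverse=True)
    return episodes[: max(1, int(len(episodes) * TOP_FRACTION))]
```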

+ ### SFT

  ```bash
+ python -m training.sft_train --data data/episodes.jsonl --output checkpoints/sft_1.5b/
  ```

+ ### GRPO

  ```bash
+ BASE_MODEL=checkpoints/sft_1.5b/ python -m training.grpo_train --data data/episodes.jsonl --output models/parlay-grpo
  ```
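
The GRPO hyperparameters documented for this pipeline are G=8 generations per prompt, `beta=0.001`, `epsilon=0.2`, batch-level reward scaling, and `learning_rate=5e-7`. A hedged sketch of the corresponding TRL configuration (assuming a TRL release whose `GRPOConfig` exposes these fields, including a string-valued `scale_rewards`; the actual wiring is in `training/grpo_train.py`):

```python
# Sketch only — hyperparameter values taken from the repo's documentation,
# not a verbatim copy of training/grpo_train.py.
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="models/parlay-grpo",
    num_generations=8,      # G = 8 samples per prompt
    beta=0.001,             # low KL coefficient: allows exploration
    epsilon=0.2,            # clipping range
    scale_rewards="batch",  # batch-level reward standardization
    learning_rate=5e-7,
    max_steps=500,
)
```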

+ ## Baseline vs GRPO Results

+ [Run scripts/eval_comparison.py after training to populate this section]

+ `results/comparison.png`

+ ## HuggingFace Space

+ [Space URL here]

+ ## OpenEnv

+ See `openenv.yaml` for environment manifest metadata and reward spec.

+ WebSocket endpoint:

+ `ws://<host>:<port>/env/ws`
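
A minimal client sketch against the `reset`/`step` protocol in `openenv.yaml` (assumes the third-party `websockets` package; the exact `ParlayAction`/`ParlayObservation` fields live in `parlay_env/models.py`, and the step payload is assumed to carry the `session_id` returned by `reset`):

```python
# Illustrative client for the OpenEnv-style WebSocket flow; not part of the repo.
import asyncio
import json

import websockets  # pip install websockets

async def play_one_turn() -> None:
    async with websockets.connect("ws://localhost:8000/env/ws") as ws:
        # Start an episode against the veteran persona.
        await ws.send(json.dumps({
            "type": "reset",
            "scenario_id": "hiring_package",
            "persona": "veteran",
        }))
        obs = json.loads(await ws.recv())          # ParlayObservation
        # Take one negotiation turn with an explicit offer.
        await ws.send(json.dumps({
            "type": "step",
            "session_id": obs.get("session_id"),   # assumed to be echoed by reset
            "utterance": "Given my competing offer, I propose $225,000 total comp.",
            "offer_amount": 225_000,
        }))
        print(json.loads(await ws.recv()))

asyncio.run(play_one_turn())
```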

+ ## Architecture

+ - `main.py`: FastAPI entry, routers, static files.
+ - `parlay_env/`: server, models, grader, reward constants, game theory.
+ - `agent/`: Gemini client, ToM tracker, self-play runner.
+ - `game/`: scenarios, tactical cards, leaderboard.
+ - `dashboard/`: UI and API routes, spectator stream.
+ - `training/`: dataset generation, SFT, GRPO, evaluation.
+ - `mcp_server/`: FastMCP tools.
+ - `tests/`: keyless and module tests.

+ ## Runbook

+ ### Local app

+ ```bash
+ uvicorn main:app --host 0.0.0.0 --port 8000
  ```

+ ### OpenEnv server only

+ ```bash
+ python -m parlay_env.server --port 8001
  ```

+ ### Keyless test suite

  ```bash
+ pytest tests/test_keyless.py -v
  ```

+ ### Smoke test

  ```bash
+ python smoke_test.py
  ```

+ ### Docker

  ```bash
+ docker build -t parlay .
+ docker run -p 7860:7860 -e GEMINI_API_KEY=$GEMINI_API_KEY parlay
  ```

  ## Testing

+ ### Full suite

  ```bash
  pytest tests/ -v
  ```

+ ### Focused modules

  ```bash
  pytest tests/test_grader.py -v
  pytest tests/test_scenarios.py -v
  ```

+ ### What tests cover

+ - `test_keyless.py`: no-key full-stack sanity checks.
+ - `test_grader.py`: step/terminal reward behavior.
+ - `test_game_theory.py`: ZOPA/Nash/Pareto/Shapley.
+ - `test_tom.py`: ToM updates and belief metrics.
+ - `test_training_pipeline.py`: training data/plumbing checks.

+ ## MCP

+ Run the MCP server over stdio:

  ```bash
+ python -m mcp_server.server stdio
  ```

+ or over SSE:

  ```bash
+ python -m mcp_server.server sse
  ```

  ## License

+ MIT
README_SPACES.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Parlay ◈
+ emoji: 🤝
+ colorFrom: purple
+ colorTo: indigo
+ sdk: docker
+ app_port: 7860
+ pinned: true
+ ---
+
+ Parlay is a negotiation RL environment and playable browser game where agents bargain under hidden information.
+ It combines Theory-of-Mind belief tracking, dynamic ZOPA erosion under sustained tension, and tactical negotiation moves.
+ This Space exposes the live game UI and OpenEnv-style WebSocket flow so you can test policies interactively.
+ Use the spectator view to inspect hidden state dynamics during a live negotiation demo.
agent/gemini_client.py CHANGED
@@ -310,7 +310,11 @@ async def call_gemini(
  mid = model if model is not None else MODEL_ID_DATA
  text = ""
- from google.genai import types  # noqa: PLC0415

  history = messages[:-1] if len(messages) > 1 else []
  last_msg = messages[-1]["parts"][0] if messages else "Begin the negotiation."

  mid = model if model is not None else MODEL_ID_DATA
  text = ""
+ try:
+     from google.genai import types  # noqa: PLC0415
+ except ModuleNotFoundError:
+     logger.warning("google-genai SDK missing; falling back to mock response")
+     return _get_mock_response(persona, len(messages), scenario_id)

  history = messages[:-1] if len(messages) > 1 else []
  last_msg = messages[-1]["parts"][0] if messages else "Begin the negotiation."
agent/personas.py CHANGED
@@ -84,7 +84,12 @@ PERSONAS: dict[PersonaType, PersonaConfig] = {
  "say \"Interesting.\" or \"I see.\" After turn 6, begin making calculated "
  "concessions - but always get something in return first. You have seen every "
  "trick before. When the opponent plays time_pressure, say \"I appreciate the "
- "deadline context, though our timeline is more flexible than you might expect.\""
  ),
  drift_trigger="power_shift",
  color_var="--parlay-purple",

  "say \"Interesting.\" or \"I see.\" After turn 6, begin making calculated "
  "concessions - but always get something in return first. You have seen every "
  "trick before. When the opponent plays time_pressure, say \"I appreciate the "
+ "deadline context, though our timeline is more flexible than you might expect.\" "
+ "After turn 12, shift from reading to closing. You have gathered enough information. "
+ "Now anchor a final position, reference what you have learned about the opponent, "
+ "and push for agreement. A veteran who never closes is not a veteran - patience "
+ "is a tool, not a strategy. In the final 4 turns, accept any offer within 8% of "
+ "your target rather than let the deal expire."
  ),
  drift_trigger="power_shift",
  color_var="--parlay-purple",
agent/runner.py CHANGED
@@ -190,7 +190,11 @@ async def run_episode(
  f"You are the human buyer in this negotiation. "
  f"Respond naturally to the AI seller.\n"
  f"Scenario: {scenario.title}. "
- f"Your budget ceiling: {hidden.budget_ceiling:,.0f}"
  ),
  opponent_messages,
  persona=persona.value,

  f"You are the human buyer in this negotiation. "
  f"Respond naturally to the AI seller.\n"
  f"Scenario: {scenario.title}. "
+ f"Your budget ceiling: {hidden.budget_ceiling:,.0f}\n"
+ f"Occasionally (roughly 1 in 4 turns), make a claim about having a competing offer "
+ f"or a lower budget than you actually have. State a BATNA that is 20-30% away from "
+ f"your true position. This is realistic negotiation behavior and tests whether the "
+ f"agent can detect bluffing."
  ),
  opponent_messages,
  persona=persona.value,
agent/tom_tracker.py CHANGED
@@ -10,6 +10,11 @@ from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMov
  logger = logging.getLogger(__name__)

  class ToMTracker:
      """

  logger = logging.getLogger(__name__)

+ # NOTE: ToMTracker is used in two paths:
+ #   (1) agent/runner.py self-play — full update each turn;
+ #   (2) parlay_env/server.py WebSocket server — also uses ToMTracker after the Task 1 fix.
+ # Both paths now produce comparable belief_history for grader._tom_accuracy.
+

  class ToMTracker:
      """
game/scenarios.py CHANGED
@@ -49,12 +49,18 @@ SCENARIOS: dict[str, Scenario] = {
  title="Senior Engineer Offer",
  description="Total comp negotiation: base + equity + signing bonus.",
  anchor_seller=240_000, anchor_buyer=180_000,
- batna_seller=195_000, batna_buyer=230_000,
- zopa=(195_000, 230_000), currency="USD", unit="total annual comp",
  difficulty=2,
  drift_events=[
-     DriftEvent(trigger_turn=5, event="Competing offer received",
-                effect_on_urgency=-0.25, effect_on_has_alternative=True),
  ],
  ),
  "acquisition_term_sheet": Scenario(

  title="Senior Engineer Offer",
  description="Total comp negotiation: base + equity + signing bonus.",
  anchor_seller=240_000, anchor_buyer=180_000,
+ # Widened 15% to improve the deal rate in self-play data generation
+ batna_seller=195_000, batna_buyer=264_500,
+ zopa=(195_000, 264_500), currency="USD", unit="total annual comp",
  difficulty=2,
  drift_events=[
+     # Delayed from turn 5 to turn 8 - early drift was destabilizing the pre-anchor phase
+     DriftEvent(
+         trigger_turn=8,
+         event="Competing offer received",
+         effect_on_urgency=-0.25,
+         effect_on_has_alternative=True,
+     ),
  ],
  ),
  "acquisition_term_sheet": Scenario(
openenv.yaml ADDED
@@ -0,0 +1,130 @@
+ env_id: parlay-negotiation-v1
+ name: Parlay
+ description: >
+   A negotiation MDP with hidden information, Theory-of-Mind belief tracking,
+   dynamic ZOPA erosion, and tactical moves. Three personas × three scenarios.
+ version: "1.0.0"
+ author: Shashvat Singh
+ contact: shashvat.k.singh.16@gmail.com
+
+ observation_space:
+   type: dict
+   fields:
+     - name: offers
+       type: list[float]
+     - name: zopa_lower
+       type: float
+     - name: zopa_upper
+       type: float
+     - name: nash_point
+       type: float
+     - name: tension_score
+       type: float
+       range: [0, 100]
+     - name: belief_state
+       type: dict
+       description: Agent's beliefs about opponent hidden state (est_budget, est_walk_away, est_urgency, est_has_alternative, confidence)
+     - name: last_utterance
+       type: string
+     - name: available_moves
+       type: list[string]
+       values: [anchor_high, batna_reveal, silence]
+     - name: cp
+       type: int
+       description: Tactical card points remaining
+     - name: drift_event
+       type: string|null
+       description: Human-readable drift event if triggered this turn, else null
+     - name: zopa_width_pct_remaining
+       type: float
+       range: [0.0, 1.0]
+
+ action_space:
+   type: dict
+   fields:
+     - name: utterance
+       type: string
+       required: true
+     - name: offer_amount
+       type: float|null
+     - name: tactical_move
+       type: string|null
+       values: [anchor_high, batna_reveal, silence]
+     - name: accept_deal
+       type: bool
+       default: false
+     - name: walk_away
+       type: bool
+       default: false
+
+ reward:
+   range: [-200, ~300]
+   per_step: "α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV"
+   terminal: "γ·E + ε·S + ζ·D (or -ω on capitulation)"
+   constants:
+     ALPHA: 2
+     BETA: 5
+     DELTA: 3
+     THETA: 10
+     PSI: 12
+     MU: 8
+     GAMMA: 100
+     EPSILON: 20
+     ZETA: 15
+     OMEGA: 200
+
+ episode:
+   max_steps: 20
+   termination_conditions:
+     - deal accepted (offers within 3%)
+     - walk_away action
+     - max turns reached
+     - zopa_collapsed (BATNAs cross after erosion)
+     - step_reward below threshold (very negative)
+
+ endpoints:
+   websocket: ws://host:port/env/ws
+   protocol:
+     reset:
+       send: '{"type": "reset", "scenario_id": "saas_enterprise|hiring_package|acquisition_term_sheet", "persona": "shark|diplomat|veteran"}'
+       receive: ParlayObservation (JSON)
+     step:
+       send: ParlayAction (JSON)
+       receive: ParlayObservation (JSON)
+     state:
+       send: '{"type": "state"}'
+       receive: ParlayState (JSON, includes hidden state for spectators)
+
+ scenarios:
+   - id: saas_enterprise
+     description: SaaS software license negotiation, seller vs enterprise buyer
+   - id: hiring_package
+     description: Job offer compensation negotiation
+   - id: acquisition_term_sheet
+     description: Startup acquisition valuation negotiation
+
+ personas:
+   - id: shark
+     style: Aggressive anchoring, high bluff rate
+   - id: diplomat
+     style: Collaborative, seeks mutual gain
+   - id: veteran
+     style: Patient, reads opponent carefully
+
+ hidden_information:
+   - budget_ceiling (opponent's true max budget)
+   - walk_away_price (opponent's true BATNA)
+   - urgency_score (0-1, how time-pressured opponent is)
+   - has_alternative (whether opponent has a competing offer)
+
+ rubric:
+   - criterion: Novel environment
+     description: Negotiation as MDP with hidden info, dynamic ZOPA, ToM beliefs
+   - criterion: Reward design
+     description: Multi-term dense reward with bluff detection, ToM accuracy, and drift adaptation
+   - criterion: Training story
+     description: Gemini self-play data → SFT cold start → GRPO fine-tune → reward improvement vs baseline
+   - criterion: Demo quality
+     description: Live browser play + spectator god-view with Three.js avatars
+   - criterion: Hosted Space
+     description: HuggingFace Space with Docker deployment
parlay_env/grader.py CHANGED
@@ -7,7 +7,7 @@ from dataclasses import dataclass
  from typing import Optional

  from .models import BeliefState, HiddenState, ParlayAction, ParlayState
- from .reward import ALPHA, BETA, DELTA, EPSILON, GAMMA, OMEGA, PSI, THETA, ZETA

  logger = logging.getLogger(__name__)

@@ -113,6 +113,7 @@
  state: ParlayState,
  action: ParlayAction,
  next_state: ParlayState,
  ) -> float:
      """
      Compute per-step reward R_t.
@@ -151,21 +152,39 @@
  ):
      bluff_bonus = PSI

  reward = (
      ALPHA * delta_v
      + BETA * tom_t
      - DELTA * concession_t
      - THETA * noise_t
      + bluff_bonus
  )
  logger.debug(
-     "Step reward: total=%.3f (dv=%.3f, tom=%.3f, concession=%.3f, noise=%.0f, bluff=%.3f)",
      reward,
      delta_v,
      tom_t,
      concession_t,
      noise_t,
      bluff_bonus,
  )
  return reward

  from typing import Optional

  from .models import BeliefState, HiddenState, ParlayAction, ParlayState
+ from .reward import ALPHA, BETA, DELTA, EPSILON, GAMMA, MU, OMEGA, PSI, THETA, ZETA

  logger = logging.getLogger(__name__)

  state: ParlayState,
  action: ParlayAction,
  next_state: ParlayState,
+ drift_event: str | None = None,
  ) -> float:
      """
      Compute per-step reward R_t.

  ):
      bluff_bonus = PSI

+ mev_bonus = 0.0
+ drift_marker = drift_event
+ if drift_marker is None:
+     drift_marker = next_state.__dict__.get("drift_event")
+ if drift_marker:
+     lowered = action.utterance.lower()
+     adaptation_tokens = (
+         "given that",
+         "considering",
+         "in light of",
+         "noted",
+         "understood",
+     )
+     if any(token in lowered for token in adaptation_tokens):
+         mev_bonus = MU
+
  reward = (
      ALPHA * delta_v
      + BETA * tom_t
      - DELTA * concession_t
      - THETA * noise_t
      + bluff_bonus
+     + mev_bonus
  )
  logger.debug(
+     "Step reward: total=%.3f (dv=%.3f, tom=%.3f, concession=%.3f, noise=%.0f, bluff=%.3f, mev=%.3f)",
      reward,
      delta_v,
      tom_t,
      concession_t,
      noise_t,
      bluff_bonus,
+     mev_bonus,
  )
  return reward

parlay_env/reward.py CHANGED
@@ -1,4 +1,8 @@
- """Reward function constants. Import from here everywhere — never hardcode coefficients."""

  # Per-step weights
  ALPHA: float = 2.0  # offer improvement toward ZOPA midpoint
@@ -13,6 +17,7 @@ ZETA: float = 15.0  # drift adaptation bonus
  ETA: float = 0.0  # retained for compatibility; single-act env disables it
  OMEGA: float = 200.0  # capitulation cliff (hard discontinuous penalty)
  PSI: float = 12.0  # bluff-caught bonus

  # Game config defaults (overridden by env vars at runtime)
  MAX_TURNS: int = 20

+ """Reward function constants. Import from here everywhere — never hardcode coefficients.
+
+ Per-step reward form:
+     R_t = α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV
+ """

  # Per-step weights
  ALPHA: float = 2.0  # offer improvement toward ZOPA midpoint

  ETA: float = 0.0  # retained for compatibility; single-act env disables it
  OMEGA: float = 200.0  # capitulation cliff (hard discontinuous penalty)
  PSI: float = 12.0  # bluff-caught bonus
+ MU: float = 8.0  # drift-event recognition bonus (MEV proxy)

  # Game config defaults (overridden by env vars at runtime)
  MAX_TURNS: int = 20
parlay_env/server.py CHANGED
@@ -17,6 +17,9 @@ from typing import Any
17
  import numpy as np
18
  from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
19
 
 
 
 
20
  from .exceptions import (
21
  EpisodeAlreadyDoneError,
22
  InvalidActionError,
@@ -67,7 +70,7 @@ FALLBACK_OBSERVATION = ParlayObservation(
67
  cumulative_reward=0.0,
68
  )
69
 
70
- _sessions: dict[str, ParlayState] = {}
71
 
72
  MAX_TURNS = int(os.getenv("MAX_TURNS_PER_EPISODE", "20"))
73
  CP_START = int(os.getenv("CREDIBILITY_POINTS_START", "100"))
@@ -133,6 +136,27 @@ def _compute_tension(state: ParlayState, action: ParlayAction) -> float:
133
  return float(max(0.0, min(100.0, base)))
134
 
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def _make_observation(
137
  state: ParlayState,
138
  reward: float,
@@ -260,7 +284,10 @@ async def _handle_reset(msg: dict[str, Any]) -> dict:
260
  original_zopa_width=original_zopa_width,
261
  zopa_width_pct_remaining=1.0,
262
  )
263
- _sessions[session_id] = state
 
 
 
264
 
265
  observation = _make_observation(state, 0.0, "Negotiation started. Make your opening move.")
266
  logger.info("Reset: session=%s, scenario=%s, persona=%s", session_id, scenario_id, persona_str)
@@ -273,7 +300,9 @@ async def _handle_step(msg: dict[str, Any]) -> dict:
273
  if not session_id or session_id not in _sessions:
274
  raise SessionNotFoundError(f"Session {session_id} not found")
275
 
276
- state = _sessions[session_id]
 
 
277
  if state.episode_done:
278
  raise EpisodeAlreadyDoneError(f"Episode {session_id} is already done")
279
 
@@ -291,17 +320,13 @@ async def _handle_step(msg: dict[str, Any]) -> dict:
291
  if action.offer_amount is not None:
292
  new_offers.append(action.offer_amount)
293
 
294
- new_beliefs = list(state.belief_history)
295
- if new_beliefs:
296
- last = new_beliefs[-1]
297
- updated = BeliefState(
298
- est_budget=last.est_budget * 0.98,
299
- est_walk_away=last.est_walk_away * 1.01,
300
- est_urgency=min(1.0, last.est_urgency + 0.02),
301
- est_has_alternative=last.est_has_alternative,
302
- confidence=min(1.0, last.confidence + 0.05),
303
- )
304
- new_beliefs.append(updated)
305
 
306
  next_state = ParlayState(
307
  **{
@@ -314,6 +339,8 @@ async def _handle_step(msg: dict[str, Any]) -> dict:
314
  "hidden_state": HiddenState(**state.hidden_state.model_dump()),
315
  }
316
  )
 
 
317
 
318
  if action.tactical_move == TacticalMove.BATNA_REVEAL:
319
  revealed = action.offer_amount if action.offer_amount is not None else next_state.hidden_state.walk_away_price
@@ -349,7 +376,8 @@ async def _handle_step(msg: dict[str, Any]) -> dict:
349
  <= next_state.hidden_state.budget_ceiling
350
  )
351
 
352
- step_reward = compute_step_reward(state, action, next_state)
 
353
  next_state.cumulative_reward = state.cumulative_reward + step_reward
354
 
355
  if step_reward >= 0.0 and action.tactical_move is None and state.hidden_state.last_stated_batna is not None:
@@ -378,8 +406,8 @@ async def _handle_step(msg: dict[str, Any]) -> dict:
378
  else:
379
  next_state.termination_reason = "max_turns"
380
 
381
- _sessions[session_id] = next_state
382
- observation = _make_observation(next_state, step_reward, action.utterance)
383
  return {"observation": observation.model_dump(), "done": next_state.episode_done}
384
 
385
 
@@ -388,12 +416,15 @@ async def _handle_state(msg: dict[str, Any]) -> dict:
388
  session_id = msg.get("session_id")
389
  if not session_id or session_id not in _sessions:
390
  raise SessionNotFoundError(f"Session {session_id} not found")
391
- return {"state": _sessions[session_id].model_dump()}
392
 
393
 
394
  def get_session_state(session_id: str) -> ParlayState | None:
395
  """Return the in-memory session state for SSE and tests."""
396
- return _sessions.get(session_id)
 
 
 
397
 
398
 
399
  @router.get("/sessions/{session_id}")
@@ -401,7 +432,7 @@ async def get_session(session_id: str) -> dict:
401
  """Get session state via REST."""
402
  if session_id not in _sessions:
403
  raise SessionNotFoundError(f"Session {session_id} not found")
404
- return {"state": _sessions[session_id].model_dump()}
405
 
406
 
407
  _env_app = FastAPI(title="Parlay OpenEnv", version="1.0.0")
 
17
  import numpy as np
18
  from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect
19
 
20
+ from agent.tom_tracker import ToMTracker
21
+ from game.scenarios import get_scenario
22
+
23
  from .exceptions import (
24
  EpisodeAlreadyDoneError,
25
  InvalidActionError,
 
70
  cumulative_reward=0.0,
71
  )
72
 
73
+ _sessions: dict[str, dict[str, Any]] = {}
74
 
75
  MAX_TURNS = int(os.getenv("MAX_TURNS_PER_EPISODE", "20"))
76
  CP_START = int(os.getenv("CREDIBILITY_POINTS_START", "100"))
 
136
  return float(max(0.0, min(100.0, base)))
137
 
138
 
139
+ def _apply_drift_event(state: ParlayState, tom: ToMTracker) -> str | None:
140
+ """Apply scenario drift event at the current turn, if any."""
141
+ try:
142
+ scenario = get_scenario(state.scenario_id)
143
+ except Exception:
144
+ return None
145
+
146
+ for event in scenario.drift_events:
147
+ if event.trigger_turn == state.step_count:
148
+ state.drift_events_fired += 1
149
+ state.hidden_state.persona_drifted = True
150
+ tom.drift_event(
151
+ event.effect_on_urgency,
152
+ event.effect_on_has_alternative,
153
+ event_description=event.event,
154
+ )
155
+ state.belief_history = list(tom.history)
156
+ return event.event
157
+ return None
158
+
159
+
160
  def _make_observation(
161
  state: ParlayState,
162
  reward: float,
 
284
  original_zopa_width=original_zopa_width,
285
  zopa_width_pct_remaining=1.0,
286
  )
287
+ _sessions[session_id] = {
288
+ "state": state,
289
+ "tom_tracker": ToMTracker(initial_belief, persona),
290
+ }
291
 
292
  observation = _make_observation(state, 0.0, "Negotiation started. Make your opening move.")
293
  logger.info("Reset: session=%s, scenario=%s, persona=%s", session_id, scenario_id, persona_str)
 
300
  if not session_id or session_id not in _sessions:
301
  raise SessionNotFoundError(f"Session {session_id} not found")
302
 
303
+ session = _sessions[session_id]
304
+ state: ParlayState = session["state"]
305
+ tom: ToMTracker = session["tom_tracker"]
306
  if state.episode_done:
307
  raise EpisodeAlreadyDoneError(f"Episode {session_id} is already done")
308
 
 
320
  if action.offer_amount is not None:
321
  new_offers.append(action.offer_amount)
322
 
323
+ tom.update(
324
+ observed_offer=action.offer_amount,
325
+ observed_move=action.tactical_move,
326
+ utterance=action.utterance,
327
+ turn=state.step_count + 1,
328
+ )
329
+ new_beliefs = list(tom.history)
330
 
331
  next_state = ParlayState(
332
  **{
 
339
  "hidden_state": HiddenState(**state.hidden_state.model_dump()),
340
  }
341
  )
342
+ # Keep belief history aligned with ToM tracker history (single source of truth).
343
+ next_state.belief_history = new_beliefs
344
 
345
  if action.tactical_move == TacticalMove.BATNA_REVEAL:
346
  revealed = action.offer_amount if action.offer_amount is not None else next_state.hidden_state.walk_away_price
 
376
  <= next_state.hidden_state.budget_ceiling
377
  )
378
 
379
+ drift_event = _apply_drift_event(next_state, tom)
380
+ step_reward = compute_step_reward(state, action, next_state, drift_event=drift_event)
381
  next_state.cumulative_reward = state.cumulative_reward + step_reward
382
 
383
  if step_reward >= 0.0 and action.tactical_move is None and state.hidden_state.last_stated_batna is not None:
 
406
  else:
407
  next_state.termination_reason = "max_turns"
408
 
409
+ _sessions[session_id] = {"state": next_state, "tom_tracker": tom}
410
+ observation = _make_observation(next_state, step_reward, action.utterance, drift_event=drift_event)
411
  return {"observation": observation.model_dump(), "done": next_state.episode_done}
412
 
413
 
 
416
  session_id = msg.get("session_id")
417
  if not session_id or session_id not in _sessions:
418
  raise SessionNotFoundError(f"Session {session_id} not found")
419
+ return {"state": _sessions[session_id]["state"].model_dump()}
420
 
421
 
422
  def get_session_state(session_id: str) -> ParlayState | None:
423
  """Return the in-memory session state for SSE and tests."""
424
+ session = _sessions.get(session_id)
425
+ if not session:
426
+ return None
427
+ return session["state"]
428
 
429
 
430
  @router.get("/sessions/{session_id}")
 
432
  """Get session state via REST."""
433
  if session_id not in _sessions:
434
  raise SessionNotFoundError(f"Session {session_id} not found")
435
+ return {"state": _sessions[session_id]["state"].model_dump()}
436
 
437
 
438
  _env_app = FastAPI(title="Parlay OpenEnv", version="1.0.0")
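
The session store above pairs each ParlayState with its ToMTracker so belief updates persist across steps. A minimal sketch of the resulting lookup pattern, assuming the types imported in this module (the helper name _get_session is illustrative, not part of the diff):

    from typing import Any

    _sessions: dict[str, dict[str, Any]] = {}

    def _get_session(session_id: str) -> tuple[Any, Any]:
        # Mirrors the lookup in _handle_step: state and tracker live in one dict
        # so they are created, read, and replaced together.
        session = _sessions.get(session_id)
        if session is None:
            raise KeyError(f"Session {session_id} not found")
        return session["state"], session["tom_tracker"]
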
scripts/audit_grpo_pipeline.py ADDED
@@ -0,0 +1,177 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Smoke test for ParlayGRPOEnvWrapper against one JSONL prompt (keyless / mock path).
4
+ Read-only: does not modify training or env.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import json
10
+ import sys
11
+ import traceback
12
+ from pathlib import Path
13
+
14
+
15
+ class _StubTrainer:
16
+ """Minimal object satisfying ParlayGRPOEnvWrapper's trainer attribute interface."""
17
+
18
+ def train(self) -> None:
19
+ return None
20
+
21
+ def save_model(self, _out: str) -> None:
22
+ return None
23
+
24
+
25
+ def main() -> None:
26
+ parser = argparse.ArgumentParser(
27
+ description="GRPO env wrapper smoke test (reset + play_turn, JSON completion handling)"
28
+ )
29
+ parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL (first row)")
30
+ parser.add_argument(
31
+ "--repo-root",
32
+ type=Path,
33
+ default=None,
34
+ help="Project root (default: parent of scripts/)",
35
+ )
36
+ args = parser.parse_args()
37
+
38
+ root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
39
+ if str(root) not in sys.path:
40
+ sys.path.insert(0, str(root))
41
+
42
+ from training.grpo_env_wrapper import ParlayGRPOEnvWrapper
43
+
44
+ path = Path(args.data)
45
+ if not path.is_file():
46
+ print(f"File not found: {path.resolve()}")
47
+ print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
48
+ return
49
+
50
+ with path.open("r", encoding="utf-8") as f:
51
+ first = next((ln for ln in f if ln.strip()), None)
52
+ if not first:
53
+ print("Empty JSONL.")
54
+ return
55
+ try:
56
+ row = json.loads(first.strip())
57
+ except json.JSONDecodeError as e:
58
+ print(f"First row is not valid JSON: {e}")
59
+ return
60
+
61
+ scenario_id = str(row.get("scenario_id") or "saas_enterprise")
62
+ persona = str(row.get("persona") or "diplomat")
63
+
64
+ wrapper = ParlayGRPOEnvWrapper(_StubTrainer())
65
+ print("ParlayGRPOEnvWrapper smoke test")
66
+ print(f" JSONL: {path.resolve()}")
67
+ print(f" Using scenario_id={scenario_id!r} persona={persona!r} from first row (defaults if missing)")
68
+
69
+ entries: list[tuple[str, str, str]] = []
70
+
71
+ # 1) reset
72
+ try:
73
+ obs = wrapper.reset(scenario_id=scenario_id, persona=persona, seed=42)
74
+ entries.append(
75
+ (
76
+ "reset() completes",
77
+ "PASS",
78
+ f"ok; scenario_id in obs: {obs.get('scenario_id')!r}",
79
+ )
80
+ )
81
+ except Exception:
82
+ entries.append(("reset() completes", "FAIL", traceback.format_exc()[:500]))
83
+ _print_checks(entries)
84
+ return
85
+
86
+ # 2) play_turn with valid parsed completion
87
+ sample_json = '{"utterance": "I propose 50000", "offer_amount": 50000}'
88
+ try:
89
+ action = json.loads(sample_json)
90
+ out = wrapper.play_turn(action)
91
+ reward = float(out.get("reward", 0.0))
92
+ except Exception:
93
+ entries.append(
94
+ (
95
+ "play_turn(valid JSON → dict with offer)",
96
+ "FAIL",
97
+ traceback.format_exc()[:500],
98
+ )
99
+ )
100
+ _print_checks(entries)
101
+ return
102
+
103
+ print(f" Sample model completion: {sample_json}")
104
+ print(f" play_turn reward (wrapper): {reward}")
105
+ print(
106
+ " Note: play_turn() returns result.grade.total_reward when offer is set (full episode total), not "
107
+ "the GRPO weighted reward_fn. GRPO training uses training/reward_fn.py on generated strings."
108
+ )
109
+
110
+ lo, hi = -10.0, 50.0
111
+ in_range = lo <= reward <= hi
112
+ entries.append(
113
+ (
114
+ f"Reward in [{lo}, {hi}] (heuristic single-step window)",
115
+ "PASS" if in_range else "FAIL",
116
+ (
117
+ f"reward={reward} inside range"
118
+ if in_range
119
+ else f"reward={reward} - expected often OUTSIDE range: wrapper total_reward can be large"
120
+ ),
121
+ )
122
+ )
123
+
124
+ # 3) Malformed JSON: must not be passed to play_turn as a string from a correct pipeline
125
+ bad = '{"utterance": "hello"'
126
+ try:
127
+ json.loads(bad)
128
+ par_mal = "UNEXPECTED: bad JSON parsed"
129
+ except json.JSONDecodeError:
130
+ err_line = None
131
+ try:
132
+ wrapper.play_turn(bad) # type: ignore[arg-type]
133
+ except Exception as e:
134
+ err_line = f"json.loads fails; play_turn(str) -> {type(e).__name__}: {e!s}"[:200]
135
+ par_mal = err_line or "play_turn(str) did not raise"
136
+ entries.append(
137
+ (
138
+ "Malformed JSON string mishandled at play_turn",
139
+ "FAIL" if par_mal.startswith("UNEXPECTED") else "PASS",
140
+ "Correct pipeline: json.loads first; " + par_mal,
141
+ )
142
+ )
143
+
144
+ # 4) Empty string
145
+ empty_explain = []
146
+ try:
147
+ json.loads("")
148
+ except json.JSONDecodeError as e0:
149
+ empty_explain.append(f"json.loads('') -> {e0!s}"[:100])
150
+ try:
151
+ wrapper.play_turn("")
152
+ except Exception as e1:
153
+ empty_explain.append(f"play_turn('') -> {type(e1).__name__}")
154
+ else:
155
+ empty_explain.append("play_turn('') did not raise (unexpected)")
156
+ entries.append(
157
+ (
158
+ "Empty string completion / action",
159
+ "PASS" if "did not raise" not in str(empty_explain[-1]) else "FAIL",
160
+ " | ".join(empty_explain),
161
+ )
162
+ )
163
+
164
+ _print_checks(entries)
165
+
166
+
167
+ def _print_checks(rows: list[tuple[str, str, str]]) -> None:
168
+ print()
169
+ print("CHECKS")
170
+ for name, status, detail in rows:
171
+ print(f" [{status}] {name}")
172
+ for line in detail.split("\n"):
173
+ print(f" {line}")
174
+
175
+
176
+ if __name__ == "__main__":
177
+ main()
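
The script's checks all rest on one invariant: a model completion must go through json.loads before it reaches play_turn. A minimal sketch of that parse-then-step stage, assuming the wrapper API exercised above (completion_to_action is an illustrative helper, not part of the script):

    import json

    def completion_to_action(completion: str) -> dict | None:
        # Malformed or non-dict completions become None instead of raw strings.
        try:
            action = json.loads(completion)
        except json.JSONDecodeError:
            return None
        return action if isinstance(action, dict) else None

    # Usage sketch:
    #   action = completion_to_action(raw_completion)
    #   if action is not None:
    #       wrapper.play_turn(action)
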
scripts/audit_reward.py ADDED
@@ -0,0 +1,163 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Static reward-surface audit for Parlay (read-only, no env rollouts).
4
+ Analytical notes derived from parlay_env/grader.py, parlay_env/reward.py, game/scenarios.py.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import sys
10
+ from pathlib import Path
11
+
12
+
13
+ def main() -> None:
14
+ parser = argparse.ArgumentParser(
15
+ description="Analytical Parlay reward-hacking and alignment audit (static, no rollouts)"
16
+ )
17
+ parser.add_argument(
18
+ "--repo-root",
19
+ type=Path,
20
+ default=None,
21
+ help="Project root (default: parent of scripts/)",
22
+ )
23
+ args = parser.parse_args()
24
+
25
+ root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
26
+ for sub in (root / "parlay_env", root / "game"):
27
+ if not sub.is_dir():
28
+ print(f"Expected directory missing: {sub}")
29
+ return
30
+ if str(root) not in sys.path:
31
+ sys.path.insert(0, str(root))
32
+
33
+ from parlay_env import grader as grader_mod
34
+ from parlay_env import reward as reward_mod
35
+ from game import scenarios as scenarios_mod
36
+
37
+ # Ensure grader symbols resolve (import side effects only)
38
+ _ = (grader_mod.compute_step_reward, grader_mod.detect_bluff_challenge)
39
+
40
+ results: list[tuple[str, str, str]] = []
41
+
42
+ print("=" * 72)
43
+ print("1. NOISE TERM (THETA * noise_t)")
44
+ print("-" * 72)
45
+ print(
46
+ "In compute_step_reward, noise_t = 1.0 when cosine_sim(utterance, prior offer text) < 0.3, "
47
+ "else 0.0. The total applies -THETA*noise (penalty on low similarity, not a bonus)."
48
+ )
49
+ print(
50
+ "Trivial *positive* side-channel from the noise term does not exist: noise can only add "
51
+ "a penalty, never increase reward. Avoiding the penalty means keeping utterance "
52
+ "overlapping the token history of prior offers (e.g. echoing offer-like numbers), not "
53
+ "necessarily any arbitrary small talk (which can score low overlap and be penalized)."
54
+ )
55
+ print("NOISE TERM: Low hacking risk - the term is a unilateral penalty, not a reward. OK.")
56
+ results.append(("NOISE TERM (THETA*noise)", "PASS", "Penalty only; no positive exploit"))
57
+
58
+ print()
59
+ print("=" * 72)
60
+ print("2. TOM TERM (BETA * ToM)")
61
+ print("-" * 72)
62
+ print(
63
+ "ToM in compute_step_reward uses the latest belief in next_state.belief_history against "
64
+ "next_state.hidden_state. The agent's utterance does not directly author beliefs; in the "
65
+ "runner/server path, beliefs update from observed opponent behavior."
66
+ )
67
+ print("TOM TERM: Not hackable by agent. OK.")
68
+ results.append(("ToM (BETA*ToM)", "PASS", "Beliefs from observation path, not direct agent edit"))
69
+
70
+ print()
71
+ print("=" * 72)
72
+ print("3. BLUFF BONUS (PSI)")
73
+ print("-" * 72)
74
+ print("detect_bluff_challenge() is structured as: (1) if stated/true are None -> False; (2) compute")
75
+ print(" bluff_threshold = 15% of |true| and require |stated-true| > threshold; (3) only then check")
76
+ print(" skepticism phrases. There is no partial credit for phrases alone if (2) fails.")
77
+ print(
78
+ "In compute_step_reward, bluff_bonus = PSI only when: tactical_move is None, "
79
+ "state.hidden_state.last_stated_batna is not None, AND detect_bluff_challenge(...)=True "
80
+ "(which already requires the >15% gap AND a skepticism phrase)."
81
+ )
82
+ print("All conditions are ANDed; there is no independent partial PSI for skepticism only.")
83
+ print("BLUFF BONUS: Gated correctly. OK.")
84
+ results.append(("BLUFF BONUS (PSI)", "PASS", "All conjuncts required; no partial PSI"))
85
+
86
+ print()
87
+ print("=" * 72)
88
+ print("4. MEV (MU * MEV) - drift + adaptation")
89
+ print("-" * 72)
90
+ print("MEV in compute_step_reward uses drift_event or next_state.drift_event; mev_bonus = MU if a drift")
91
+ print("marker is present AND the utterance contains an adaptation subphrase (see grader for tokens).")
92
+ print("The agent does not set drift_event; game/scenarios.py defines trigger_turn per scenario.\n")
93
+ for sid, sc in sorted(scenarios_mod.SCENARIOS.items()):
94
+ if not sc.drift_events:
95
+ print(f" {sid}: (no drift_events)")
96
+ else:
97
+ turns = [f"turn {e.trigger_turn}: {e.event!r}" for e in sc.drift_events]
98
+ print(f" {sid}: {', '.join(turns)}")
99
+ print()
100
+ print("MEV TERM: Not hackable. OK.")
101
+ results.append(("MEV (MU*drift adapt)", "PASS", "Drift is scenario-time-gated, not agent-triggered"))
102
+
103
+ print()
104
+ print("=" * 72)
105
+ print("5. DELTA CONCESSION - offer_amount = None")
106
+ print("-" * 72)
107
+ print(
108
+ "In compute_step_reward: delta_v only updates when action.offer_amount is not None. "
109
+ "concession_t only runs when state.offer_history and action.offer_amount is not None."
110
+ )
111
+ print(
112
+ "If offer_amount is always None, delta_v=0 and concession_t=0, so the agent forgoes both "
113
+ "alpha*deltaV upside and any delta*concession penalty in those terms."
114
+ )
115
+ print(
116
+ "CONCESSION HACK RISK: Agent can set offer_amount=None every turn to avoid both deltaV reward "
117
+ "AND concession penalty. Net effect: misses upside but avoids downside. "
118
+ "Document as known limitation."
119
+ )
120
+ results.append(
121
+ (
122
+ "Concession (DELTA) / offer=None",
123
+ "WARN",
124
+ "offer_amount=None zeroes both deltaV and concession terms",
125
+ )
126
+ )
127
+
128
+ print()
129
+ print("=" * 72)
130
+ print("6. TERMINAL vs STEP REWARD alignment")
131
+ print("-" * 72)
132
+ print("Step: emphasizes offer improvement (ALPHA), ToM (BETA), penalties and bonuses as shaped in grader.")
133
+ print(
134
+ "Terminal (compute_terminal_reward): deal_efficiency, speed, drift bonus; GAMMA = "
135
+ f"{reward_mod.GAMMA} on efficiency."
136
+ )
137
+ print(
138
+ "Tension: an agent can chase high per-step terms (e.g. anchoring, offer deltas) and still miss "
139
+ "agreement, yielding low terminal efficiency if no deal closes or final price is poor."
140
+ )
141
+ print(
142
+ "This is a mis-alignment by design: it pressures closing unless step weights drown the signal - "
143
+ "monitor in training, not a pure bug."
144
+ )
145
+ print("STEP vs TERMINAL: WARN - intentional tension; monitor in training, not a pure logic bug.")
146
+ results.append(
147
+ (
148
+ "Step vs terminal alignment",
149
+ "WARN",
150
+ "Dense step and terminal E can pull apart without a deal",
151
+ )
152
+ )
153
+
154
+ print()
155
+ print("=" * 72)
156
+ print("SUMMARY (6 checks)")
157
+ print("=" * 72)
158
+ for label, level, note in results:
159
+ print(f" [{level:4s}] {label} - {note}")
160
+
161
+
162
+ if __name__ == "__main__":
163
+ main()
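
To make check 5 concrete: with offer_amount=None, both offer-driven terms vanish, so a silent agent forgoes the deltaV upside but also dodges the concession penalty. A toy illustration of that zeroing, with placeholder weights rather than the actual parlay_env.reward constants:

    ALPHA, DELTA = 1.0, 1.0  # placeholders, not the grader's real weights

    def offer_terms(prev_offer: float | None, offer: float | None) -> float:
        if prev_offer is None or offer is None:
            return 0.0  # no deltaV upside, but no concession penalty either
        delta_v = offer - prev_offer
        concession = max(0.0, prev_offer - offer)
        return ALPHA * delta_v - DELTA * concession

    assert offer_terms(50_000.0, None) == 0.0  # the "never offer" strategy scores flat zero
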
scripts/check_training_config.py ADDED
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Pre-flight training configuration checklist (SFT + GRPO).
4
+ Read-only: inspects training/sft_train.py and training/grpo_train.py; does not start training.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import inspect
10
+ import os
11
+ import sys
12
+ from pathlib import Path
13
+
14
+
15
+ def main() -> None:
16
+ parser = argparse.ArgumentParser(description="Pre-flight SFT/GRPO training config checklist")
17
+ parser.add_argument(
18
+ "--repo-root",
19
+ type=Path,
20
+ default=None,
21
+ help="Project root (default: parent of scripts/)",
22
+ )
23
+ args = parser.parse_args()
24
+
25
+ root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
26
+ if str(root) not in sys.path:
27
+ sys.path.insert(0, str(root))
28
+
29
+ sft_path = root / "training" / "sft_train.py"
30
+ grpo_path = root / "training" / "grpo_train.py"
31
+ if not sft_path.is_file() or not grpo_path.is_file():
32
+ print(f"Missing {sft_path} or {grpo_path}")
33
+ return
34
+
35
+ import training.sft_train as sft
36
+ import training.grpo_train as grpo
37
+
38
+ sft_text = sft_path.read_text(encoding="utf-8")
39
+ grpo_text = grpo_path.read_text(encoding="utf-8")
40
+ sft_fn = inspect.getsource(sft.train_sft)
41
+
42
+ checks: list[tuple[str, bool, str]] = []
43
+
44
+ # Base model
45
+ want_model = "Qwen/Qwen2.5-1.5B-Instruct"
46
+ ok_model = sft.DEFAULT_MODEL == want_model
47
+ checks.append(("[ ] Base model: Qwen/Qwen2.5-1.5B-Instruct", ok_model, f"found {sft.DEFAULT_MODEL!r}"))
48
+
49
+ # LoRA
50
+ ok_lora = "r=16" in sft_fn and "lora_alpha=32" in sft_fn
51
+ checks.append(("[ ] SFT LoRA r=16, alpha=32", ok_lora, "in train_sft()"))
52
+
53
+ # SFT training args
54
+ ok_epochs = "num_train_epochs=3" in sft_text
55
+ ok_b = "per_device_train_batch_size=4" in sft_text
56
+ ok_g = "gradient_accumulation_steps=4" in sft_text
57
+ eff = 4 * 4
58
+ ok_sft = ok_epochs and ok_b and ok_g
59
+ checks.append(
60
+ (
61
+ f"[ ] SFT epochs=3, batch=4, grad_accum=4 (effective ~{eff})",
62
+ ok_sft,
63
+ f"epochs={ok_epochs} batch={ok_b} grad={ok_g}",
64
+ )
65
+ )
66
+
67
+ # Output dir
68
+ want_out = "checkpoints/sft_1.5b/"
69
+ ok_out = sft.DEFAULT_OUTPUT == want_out
70
+ checks.append(("[ ] SFT output: checkpoints/sft_1.5b/", ok_out, f"default={sft.DEFAULT_OUTPUT!r}"))
71
+
72
+ # GRPO BASE_MODEL (read at import time in grpo_train)
73
+ base = os.getenv("BASE_MODEL", "")
74
+ grpo_default = grpo.BASE_MODEL
75
+ if not base:
76
+ grpo_brief = f"BASE_MODEL env not set - will use module default {grpo_default!r}"
77
+ else:
78
+ grpo_brief = f"set to {base!r}"
79
+ checks.append(
80
+ (
81
+ "[ ] GRPO reads BASE_MODEL from env",
82
+ True,
83
+ grpo_brief,
84
+ )
85
+ )
86
+
87
+ # GRPO reward weights
88
+ want_line = "reward_weights=[3.0, 1.5, 2.0, 0.5]"
89
+ in_rw = want_line in grpo_text
90
+ w_line = next((ln.strip() for ln in grpo_text.splitlines() if "reward_weights" in ln), "")
91
+ checks.append(
92
+ (
93
+ "[ ] GRPO reward weights [efficiency, tom, anti-cap, format] = [3.0, 1.5, 2.0, 0.5]",
94
+ in_rw,
95
+ w_line or "not found",
96
+ )
97
+ )
98
+
99
+ # GRPO data path
100
+ d_ok = 'default="data/episodes.jsonl"' in grpo_text
101
+ checks.append(
102
+ (
103
+ '[ ] GRPO --data default: data/episodes.jsonl',
104
+ d_ok,
105
+ "see grpo_train.main argparse" if d_ok else "check grpo_train.py",
106
+ )
107
+ )
108
+
109
+ checks.append(
110
+ (
111
+ "[ ] Estimated VRAM note (1.5B + LoRA r=16 ~6-8GB SFT; more for GRPO)",
112
+ True,
113
+ "informational (not a failure if you skip the box)",
114
+ )
115
+ )
116
+
117
+ hf = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
118
+ ok_hf = bool(hf)
119
+ checks.append(
120
+ (
121
+ "[ ] HF token for push (HF_TOKEN or HUGGING_FACE_HUB_TOKEN)",
122
+ ok_hf,
123
+ "set" if ok_hf else "not set - needed to push checkpoints",
124
+ )
125
+ )
126
+
127
+ print("Training config pre-flight (read from training/sft_train.py, training/grpo_train.py)\n")
128
+ for line, ok, note in checks:
129
+ mark = "x" if ok else " "
130
+ display = line.replace("[ ]", f"[{mark}]", 1) if line.startswith("[ ]") else line
131
+ print(display)
132
+ if note:
133
+ print(f" -> {note}")
134
+ print()
135
+
136
+ core_ok = ok_model and ok_lora and ok_sft and ok_out and in_rw and d_ok
137
+ if core_ok and ok_hf:
138
+ print("\nREADY FOR TRAINING (SFT + GRPO config strings match; HF token present for hub).")
139
+ elif core_ok:
140
+ print(
141
+ "\nMOSTLY READY: fix missing HF token if you need push_to_hub; verify BASE_MODEL for GRPO stage."
142
+ )
143
+ else:
144
+ print("\nNEEDS FIXING: see failed [ ] items above (model path, LoRA, SFT args, or output dir).")
145
+
146
+
147
+ if __name__ == "__main__":
148
+ main()
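
The effective batch size the checklist reports is just per-device batch times gradient-accumulation steps (times device count on multi-GPU setups). A one-line sanity check of that arithmetic:

    def effective_batch(per_device: int = 4, grad_accum: int = 4, num_devices: int = 1) -> int:
        # 4 * 4 * 1 = 16, matching the checklist's "effective ~16" line
        return per_device * grad_accum * num_devices
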
scripts/eval_comparison.py ADDED
@@ -0,0 +1,78 @@
1
+ """Compare baseline, Gemini, and GRPO JSON summaries."""
2
+ import argparse
3
+ import json
4
+ from pathlib import Path
5
+
6
+
7
+ def _load(path: str) -> dict:
8
+ with open(path, "r", encoding="utf-8") as f:
9
+ return json.load(f)
10
+
11
+
12
+ def _fmt_pct(value: float) -> str:
13
+ return f"{100.0 * value:.1f}%"
14
+
15
+
16
+ def _row(label: str, data: dict) -> str:
17
+ return (
18
+ f"| {label} | {data.get('avg_reward', 0):.3f} | "
19
+ f"{_fmt_pct(float(data.get('deal_rate', 0)))} | "
20
+ f"{float(data.get('avg_efficiency', 0)):.3f} | "
21
+ f"{float(data.get('avg_tom_accuracy', 0)):.3f} | "
22
+ f"{int(data.get('bluffs_caught', 0))} |"
23
+ )
24
+
25
+
26
+ def _save_chart(baseline: dict, gemini: dict, grpo: dict, output_path: Path) -> None:
27
+ import matplotlib.pyplot as plt
28
+
29
+ labels = ["avg_reward", "deal_rate", "avg_efficiency", "avg_tom_accuracy"]
30
+ names = ["Random", "Gemini", "GRPO"]
31
+ series = [baseline, gemini, grpo]
32
+
33
+ x = range(len(labels))
34
+ width = 0.22
35
+
36
+ plt.figure(figsize=(10, 5))
37
+ for idx, name in enumerate(names):
38
+ vals = [float(series[idx].get(k, 0.0)) for k in labels]
39
+ plt.bar([p + (idx - 1) * width for p in x], vals, width=width, label=name)
40
+
41
+ plt.xticks(list(x), labels)
42
+ plt.ylabel("Metric value")
43
+ plt.title("Parlay Baseline vs Gemini vs GRPO")
44
+ plt.legend()
45
+ output_path.parent.mkdir(parents=True, exist_ok=True)
46
+ plt.tight_layout()
47
+ plt.savefig(output_path, dpi=150)
48
+ plt.close()
49
+
50
+
51
+ def main() -> None:
52
+ parser = argparse.ArgumentParser(description="Compare evaluation result JSON files")
53
+ parser.add_argument("--baseline-results", required=True)
54
+ parser.add_argument("--gemini-results", required=True)
55
+ parser.add_argument("--grpo-results", required=True)
56
+ args = parser.parse_args()
57
+
58
+ baseline = _load(args.baseline_results)
59
+ gemini = _load(args.gemini_results)
60
+ grpo = _load(args.grpo_results)
61
+
62
+ lines = [
63
+ "| Model | avg_reward | deal_rate | avg_efficiency | avg_tom_accuracy | bluffs_caught |",
64
+ "|---|---:|---:|---:|---:|---:|",
65
+ _row("Random baseline", baseline),
66
+ _row("Gemini baseline", gemini),
67
+ _row("GRPO", grpo),
68
+ ]
69
+ table = "\n".join(lines)
70
+ print(table)
71
+
72
+ chart_path = Path("results/comparison.png")
73
+ _save_chart(baseline, gemini, grpo, chart_path)
74
+ print(f"\nSaved chart: {chart_path.resolve()}")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
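
For reference, _row turns one summary dict into a markdown table row; a sketch with clearly made-up numbers shows the shape:

    toy = {"avg_reward": 1.234, "deal_rate": 0.5, "avg_efficiency": 0.42,
           "avg_tom_accuracy": 0.61, "bluffs_caught": 3}  # fabricated, illustrative only
    print(_row("Random baseline", toy))
    # | Random baseline | 1.234 | 50.0% | 0.420 | 0.610 | 3 |
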
scripts/inspect_data.py ADDED
@@ -0,0 +1,261 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Pre-training data quality inspector for Parlay JSONL episode files.
4
+ Read-only: loads JSONL and prints statistics and RED FLAGS.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import json
10
+ import math
11
+ import statistics
12
+ from collections import Counter, defaultdict
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+
17
+ def _safe_float(x: Any, default: float = 0.0) -> float:
18
+ try:
19
+ return float(x)
20
+ except (TypeError, ValueError):
21
+ return default
22
+
23
+
24
+ def _percentile_sorted(sorted_vals: list[float], p: float) -> float:
25
+ if not sorted_vals:
26
+ return 0.0
27
+ k = (len(sorted_vals) - 1) * p / 100.0
28
+ f = math.floor(k)
29
+ c = math.ceil(k)
30
+ if f == c:
31
+ return sorted_vals[int(k)]
32
+ return sorted_vals[f] * (c - k) + sorted_vals[c] * (k - f)
33
+
34
+
35
+ def _outcome_bucket(rec: dict[str, Any]) -> str:
36
+ tr = (rec.get("termination_reason") or "") or ""
37
+ tr_l = tr.lower()
38
+ if rec.get("deal_reached") is True or tr_l in ("deal_reached", "deal", "agreement"):
39
+ return "deal"
40
+ if "zopa_collapsed" in tr_l or tr_l == "zopa_collapsed":
41
+ return "zopa_collapsed"
42
+ if "walk" in tr_l or tr_l in ("walk_away", "walkaway"):
43
+ return "walk_away"
44
+ if tr_l in ("max_turns",) or (rec.get("deal_reached") is False and "max" in tr_l):
45
+ return "max_turns"
46
+ if tr_l:
47
+ return f"other:{tr_l}"
48
+ if rec.get("deal_reached") is False and rec.get("final_price") is None:
49
+ return "no_deal_or_unknown"
50
+ return "unknown"
51
+
52
+
53
+ def _utterance_lengths(conversation: Any) -> list[int]:
54
+ if not isinstance(conversation, list):
55
+ return []
56
+ out: list[int] = []
57
+ for turn in conversation:
58
+ if not isinstance(turn, dict):
59
+ continue
60
+ content = turn.get("content", "")
61
+ if isinstance(content, str) and content.strip():
62
+ out.append(len(content))
63
+ return out
64
+
65
+
66
+ def main() -> None:
67
+ parser = argparse.ArgumentParser(description="Inspect Parlay episode JSONL data quality")
68
+ parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL file")
69
+ args = parser.parse_args()
70
+
71
+ path = Path(args.data)
72
+ if not path.is_file():
73
+ print(f"File not found: {path.resolve()}")
74
+ print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
75
+ return
76
+
77
+ records: list[dict[str, Any]] = []
78
+ with path.open("r", encoding="utf-8") as f:
79
+ for line_no, line in enumerate(f, 1):
80
+ line = line.strip()
81
+ if not line:
82
+ continue
83
+ try:
84
+ records.append(json.loads(line))
85
+ except json.JSONDecodeError as e:
86
+ print(f"[WARN] Skipping line {line_no}: invalid JSON ({e})")
87
+
88
+ n = len(records)
89
+ print(f"=== Parlay data inspector: {path} ===")
90
+ print(f"Total episode records: {n}\n")
91
+
92
+ if n == 0:
93
+ print("No records to analyze.")
94
+ return
95
+
96
+ # SCHEMA
97
+ missing_prompt = sum(1 for r in records if not str(r.get("prompt", "")).strip())
98
+ missing_scenario = sum(1 for r in records if not str(r.get("scenario_id", "")).strip())
99
+ missing_persona = sum(1 for r in records if not str(r.get("persona", "")).strip())
100
+ missing_metadata = sum(1 for r in records if "metadata" not in r)
101
+ print("SCHEMA")
102
+ print(f" prompt present: {n - missing_prompt}/{n}")
103
+ print(f" scenario_id present: {n - missing_scenario}/{n}")
104
+ print(f" persona present: {n - missing_persona}/{n}")
105
+ print(f" metadata key present: {n - missing_metadata}/{n} (audit checklist; generate_data may omit)")
106
+ print()
107
+
108
+ def cum_reward(r: dict[str, Any]) -> float:
109
+ if "cumulative_reward" in r:
110
+ return _safe_float(r.get("cumulative_reward"))
111
+ return _safe_float(r.get("reward"))
112
+
113
+ rews = [cum_reward(r) for r in records]
114
+ rews_sorted = sorted(rews)
115
+
116
+ print("REWARD (total / cumulative - field 'reward' or 'cumulative_reward')")
117
+ print(f" min: {min(rews):.4f}")
118
+ print(f" max: {max(rews):.4f}")
119
+ print(f" mean: {statistics.mean(rews):.4f}")
120
+ print(f" std: {statistics.stdev(rews) if len(rews) > 1 else 0.0:.4f}")
121
+ print(f" p10: {_percentile_sorted(rews_sorted, 10):.4f}")
122
+ print(f" p90: {_percentile_sorted(rews_sorted, 90):.4f}")
123
+ print()
124
+
125
+ outcomes = [_outcome_bucket(r) for r in records]
126
+ oc = Counter(outcomes)
127
+ print("EPISODE OUTCOMES (best-effort from termination_reason + deal_reached)")
128
+ for k, v in sorted(oc.items(), key=lambda x: -x[1]):
129
+ print(f" {k}: {v} ({100.0 * v / n:.1f}%)")
130
+ print()
131
+
132
+ effs = [_safe_float(r.get("deal_efficiency"), 0.0) for r in records]
133
+ toms = []
134
+ for r in records:
135
+ t = r.get("tom_accuracy_avg", r.get("tom_accuracy"))
136
+ toms.append(_safe_float(t, 0.0))
137
+
138
+ print("EFFICIENCY (deal_efficiency, 0-1)")
139
+ if effs:
140
+ print(f" mean: {statistics.mean(effs):.4f} min: {min(effs):.4f} max: {max(effs):.4f}")
141
+ print()
142
+
143
+ print("TOM (tom_accuracy_avg or tom_accuracy)")
144
+ if toms:
145
+ print(f" mean: {statistics.mean(toms):.4f} min: {min(toms):.4f} max: {max(toms):.4f}")
146
+ print()
147
+
148
+ all_lens: list[int] = []
149
+ degenerate_turns = 0
150
+ total_turns = 0
151
+ for r in records:
152
+ lens = _utterance_lengths(r.get("conversation"))
153
+ all_lens.extend(lens)
154
+ for L in lens:
155
+ total_turns += 1
156
+ if L < 10:
157
+ degenerate_turns += 1
158
+
159
+ print("UTTERANCE LENGTH (conversation[*].content)")
160
+ if all_lens:
161
+ print(f" mean chars/turn: {statistics.mean(all_lens):.1f}")
162
+ print(f" turns < 10 chars: {degenerate_turns}/{total_turns} ({100.0 * degenerate_turns / max(1, total_turns):.1f}%)")
163
+ else:
164
+ print(" (no conversation utterances found)")
165
+ print()
166
+
167
+ bluff_pos = sum(1 for r in records if int(r.get("bluffs_caught", 0) or 0) > 0)
168
+ drift_yes = sum(1 for r in records if r.get("drift_adapted") is True)
169
+
170
+ print("BLUFF RATE: episodes with bluffs_caught > 0")
171
+ print(f" {bluff_pos}/{n} ({100.0 * bluff_pos / n:.1f}%) (field may be missing in JSONL -> counted as 0)")
172
+ print()
173
+ print("DRIFT ADAPTATION: drift_adapted == True")
174
+ print(f" {drift_yes}/{n} ({100.0 * drift_yes / n:.1f}%)")
175
+ print()
176
+
177
+ by_persona: dict[str, list[dict]] = defaultdict(list)
178
+ by_scenario: dict[str, list[dict]] = defaultdict(list)
179
+ for r in records:
180
+ p = str(r.get("persona", "??"))
181
+ s = str(r.get("scenario_id", "??"))
182
+ by_persona[p].append(r)
183
+ by_scenario[s].append(r)
184
+
185
+ print("=== PER-PERSONA ===")
186
+ for p in sorted(by_persona.keys()):
187
+ grp = by_persona[p]
188
+ m = len(grp)
189
+ pr = [cum_reward(x) for x in grp]
190
+ pe = [_safe_float(x.get("deal_efficiency"), 0) for x in grp]
191
+ pt = [_safe_float(x.get("tom_accuracy_avg", x.get("tom_accuracy")), 0) for x in grp]
192
+ po = [_outcome_bucket(x) for x in grp]
193
+ dr = sum(1 for f in po if f == "deal") / m
194
+ print(
195
+ f" {p}: n={m} mean_reward={statistics.mean(pr) if pr else 0:.2f} "
196
+ f"mean_eff={statistics.mean(pe) if pe else 0:.3f} mean_tom={statistics.mean(pt) if pt else 0:.3f} deal_rate={dr:.2%}"
197
+ )
198
+ print()
199
+
200
+ print("=== PER-SCENARIO ===")
201
+ for s in sorted(by_scenario.keys()):
202
+ grp = by_scenario[s]
203
+ m = len(grp)
204
+ pr = [cum_reward(x) for x in grp]
205
+ pe = [_safe_float(x.get("deal_efficiency"), 0) for x in grp]
206
+ pt = [_safe_float(x.get("tom_accuracy_avg", x.get("tom_accuracy")), 0) for x in grp]
207
+ po = [_outcome_bucket(x) for x in grp]
208
+ dr = sum(1 for f in po if f == "deal") / m
209
+ print(
210
+ f" {s}: n={m} mean_reward={statistics.mean(pr) if pr else 0:.2f} "
211
+ f"mean_eff={statistics.mean(pe) if pe else 0:.3f} mean_tom={statistics.mean(pt) if pt else 0:.3f} deal_rate={dr:.2%}"
212
+ )
213
+ print()
214
+
215
+ # RED FLAGS
216
+ print("=== RED FLAGS ===")
217
+ flags: list[str] = []
218
+
219
+ bad_rew = sum(1 for x in rews if x < -50) / n
220
+ if bad_rew > 0.30:
221
+ flags.append(f"> 30% episodes with total reward < -50 ({100 * bad_rew:.1f}%)")
222
+
223
+ max_turns_rate = sum(1 for o in outcomes if o == "max_turns") / n
224
+ if max_turns_rate > 0.40:
225
+ flags.append(f"> 40% ending in max_turns ({100 * max_turns_rate:.1f}%)")
226
+
227
+ drift_rate = drift_yes / n
228
+ if drift_rate < 0.10:
229
+ flags.append(f"< 10% drift_adapted ({100 * drift_rate:.1f}%)")
230
+
231
+ for p, grp in by_persona.items():
232
+ po = [_outcome_bucket(x) for x in grp]
233
+ dr = sum(1 for f in po if f == "deal") / len(grp) if grp else 0.0
234
+ if dr == 0.0 and len(grp) >= 3:
235
+ flags.append(f"Persona {p!r} has 0% deal rate (n={len(grp)})")
236
+
237
+ for s, grp in by_scenario.items():
238
+ po = [_outcome_bucket(x) for x in grp]
239
+ dr = sum(1 for f in po if f == "deal") / len(grp) if grp else 0.0
240
+ if dr == 0.0 and len(grp) >= 3:
241
+ flags.append(f"Scenario {s!r} has 0% deal rate (n={len(grp)})")
242
+
243
+ if all_lens and statistics.mean(all_lens) < 20.0:
244
+ flags.append(f"Mean utterance length {statistics.mean(all_lens):.1f} chars < 20 (possibly degenerate)")
245
+
246
+ if max(rews) > 400:
247
+ flags.append(f"At least one episode with total reward > 400 (max={max(rews):.2f}) - check for scale bugs or rare combo")
248
+
249
+ if missing_metadata == n:
250
+ flags.append("No record has a top-level 'metadata' key (optional for training; audit asked for it)")
251
+
252
+ if not flags:
253
+ print(" (none triggered)")
254
+ else:
255
+ for f in flags:
256
+ print(f" * {f}")
257
+ print()
258
+
259
+
260
+ if __name__ == "__main__":
261
+ main()
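
_percentile_sorted interpolates linearly between the two nearest ranks. A worked example against the function above:

    vals = [1.0, 2.0, 4.0, 8.0]  # already sorted
    # p90: k = (4 - 1) * 90 / 100 = 2.7, between index 2 (4.0) and index 3 (8.0):
    # 4.0 * (3 - 2.7) + 8.0 * (2.7 - 2) = 1.2 + 5.6 = 6.8
    assert abs(_percentile_sorted(vals, 90) - 6.8) < 1e-9
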
scripts/push_docker.sh ADDED
@@ -0,0 +1,8 @@
1
+ #!/bin/bash
2
+ # Usage: ./scripts/push_docker.sh <dockerhub-username> <tag>
3
+ USERNAME=${1:-yourusername}
4
+ TAG=${2:-latest}
5
+ docker build -t "$USERNAME/parlay:$TAG" .
6
+ docker push "$USERNAME/parlay:$TAG"
7
+ echo "Pushed $USERNAME/parlay:$TAG"
8
+ echo "For HF Spaces: set Dockerfile app_port to 7860 and push repo to huggingface.co/spaces/$USERNAME/parlay"
scripts/run_gemini_baseline.py ADDED
@@ -0,0 +1,74 @@
1
+ """Run Gemini self-play baseline and save summary JSON."""
2
+ import argparse
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ from agent.runner import run_episode
9
+ from parlay_env.models import PersonaType
10
+
11
+ PERSONAS = [PersonaType.SHARK, PersonaType.DIPLOMAT, PersonaType.VETERAN]
12
+ SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
13
+
14
+
15
+ def _mean(values: list[float]) -> float:
16
+ return sum(values) / len(values) if values else 0.0
17
+
18
+
19
+ async def _run(episodes: int) -> list[dict]:
20
+ rows: list[dict] = []
21
+ for i in range(episodes):
22
+ persona = PERSONAS[i % len(PERSONAS)]
23
+ scenario_id = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
24
+ result = await run_episode(
25
+ persona=persona,
26
+ scenario_id=scenario_id,
27
+ inject_noise=False,
28
+ force_drift=True,
29
+ seed=i + 100,
30
+ max_turns=20,
31
+ )
32
+ rows.append(
33
+ {
34
+ "avg_reward": float(result.grade.total_reward),
35
+ "deal_rate": 1.0 if result.final_price is not None else 0.0,
36
+ "avg_efficiency": float(result.grade.deal_efficiency),
37
+ "avg_tom_accuracy": float(result.grade.tom_accuracy_avg),
38
+ "bluffs_caught": int(result.grade.bluffs_caught),
39
+ }
40
+ )
41
+ return rows
42
+
43
+
44
+ def _summarise(rows: list[dict], episodes_requested: int) -> dict:
45
+ return {
46
+ "episodes_requested": episodes_requested,
47
+ "episodes_completed": len(rows),
48
+ "avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4),
49
+ "deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4),
50
+ "avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4),
51
+ "avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4),
52
+ "bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)),
53
+ }
54
+
55
+
56
+ def main() -> None:
57
+ parser = argparse.ArgumentParser(description="Run Gemini self-play baseline")
58
+ parser.add_argument("--episodes", type=int, default=20)
59
+ parser.add_argument("--output", default="results/gemini_baseline.json")
60
+ args = parser.parse_args()
61
+
62
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
63
+ rows = asyncio.run(_run(args.episodes))
64
+ summary = _summarise(rows, args.episodes)
65
+
66
+ out = Path(args.output)
67
+ out.parent.mkdir(parents=True, exist_ok=True)
68
+ out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
69
+ print(json.dumps(summary, indent=2))
70
+ print(f"\nSaved Gemini baseline to {out.resolve()}")
71
+
72
+
73
+ if __name__ == "__main__":
74
+ main()
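
The episode loop cycles personas fastest and advances the scenario every len(PERSONAS) episodes, so any 9 consecutive episodes cover all 9 persona x scenario pairs exactly once. A quick check of that indexing, with plain strings standing in for the enum values:

    personas = ["shark", "diplomat", "veteran"]
    scenarios = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
    combos = {(personas[i % 3], scenarios[(i // 3) % 3]) for i in range(9)}
    assert len(combos) == 9  # full coverage every 9 episodes
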
scripts/validate_sft_data.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Validate JSONL rows against what training/sft_train.py will actually use.
4
+ Read-only; does not run training.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import json
10
+ import statistics
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ # Mirror training/sft_train._extract_completions (do not import to avoid heavy deps at import time)
15
+ def _extract_completions(rec: dict) -> list[str]:
16
+ completion = rec.get("completion")
17
+ if isinstance(completion, str) and completion.strip():
18
+ return [completion.strip()]
19
+
20
+ conversation = rec.get("conversation", [])
21
+ candidates: list[str] = []
22
+ if isinstance(conversation, list):
23
+ for turn in conversation:
24
+ if not isinstance(turn, dict):
25
+ continue
26
+ role = str(turn.get("role", "")).lower()
27
+ content = str(turn.get("content", "")).strip()
28
+ if role == "negotiator" and content:
29
+ candidates.append(content)
30
+ return candidates
31
+
32
+
33
+ def _approx_tokens(text: str) -> float:
34
+ """Rough token estimate without tokenizer (good enough for preflight OOM risk)."""
35
+ if not text:
36
+ return 0.0
37
+ return len(text) / 4.0
38
+
39
+
40
+ def main() -> None:
41
+ parser = argparse.ArgumentParser(description="Validate SFT JSONL against sft_train.py expectations")
42
+ parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL file")
43
+ args = parser.parse_args()
44
+
45
+ path = Path(args.data)
46
+ if not path.is_file():
47
+ print(f"File not found: {path.resolve()}")
48
+ print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
49
+ return
50
+
51
+ usable_rows = 0
52
+ skipped = 0
53
+ prompt_tok: list[float] = []
54
+ completion_tok: list[float] = []
55
+ first_bad_line: int | None = None
56
+
57
+ with path.open("r", encoding="utf-8") as f:
58
+ for line_no, line in enumerate(f, 1):
59
+ line = line.strip()
60
+ if not line:
61
+ continue
62
+ try:
63
+ rec = json.loads(line)
64
+ except json.JSONDecodeError:
65
+ skipped += 1
66
+ if first_bad_line is None:
67
+ first_bad_line = line_no
68
+ continue
69
+ if not isinstance(rec, dict):
70
+ skipped += 1
71
+ continue
72
+
73
+ prompt = str(rec.get("prompt", "")).strip()
74
+ if not prompt:
75
+ skipped += 1
76
+ continue
77
+
78
+ completions = _extract_completions(rec)
79
+ if not completions:
80
+ skipped += 1
81
+ continue
82
+
83
+ usable_rows += 1
84
+ p_t = _approx_tokens(prompt)
85
+ prompt_tok.append(p_t)
86
+ for c in completions:
87
+ completion_tok.append(_approx_tokens(c))
88
+
89
+ sft_trains_one_row_per_completion = "sft_train.py expands one dataset row per negotiator line"
90
+ print("SFT data validator (vs training/sft_train.py load_sft_dataset / _extract_completions)")
91
+ print(f" File: {path.resolve()}")
92
+ print(f" Note: {sft_trains_one_row_per_completion} when 'completion' is absent.")
93
+ print()
94
+ print(f" JSONL records usable (has non-empty 'prompt' and completion or negotiator text): {usable_rows}")
95
+ print(f" Records / rows SKIPPED: {skipped}")
96
+ if first_bad_line is not None:
97
+ print(f" (includes malformed JSONL starting around line {first_bad_line} if any)")
98
+ print()
99
+
100
+ if not prompt_tok and not completion_tok:
101
+ print("No prompt/completion lengths to summarize (all skipped).")
102
+ else:
103
+ def _summary(vals: list[float], label: str) -> None:
104
+ if not vals:
105
+ print(f" {label}: (empty)")
106
+ return
107
+ print(
108
+ f" {label} (approx. tokens, len/4): "
109
+ f"min={min(vals):.1f} max={max(vals):.1f} mean={statistics.mean(vals):.1f} "
110
+ f"std={(statistics.pstdev(vals) if len(vals) > 1 else 0.0):.1f}"
111
+ )
112
+
113
+ _summary(prompt_tok, "Prompt length")
114
+ _summary(completion_tok, "Completion length (each negotiator / completion string)")
115
+ if prompt_tok and statistics.mean(prompt_tok) > 2048.0:
116
+ print(
117
+ " FLAG: Mean prompt length > 2048 (approx. tokens) - may OOM or truncate with "
118
+ "SFTConfig max_seq_length=2048 in sft_train.py on small GPUs."
119
+ )
120
+
121
+ print()
122
+ print(f" {usable_rows} records usable for SFT, {skipped} will be skipped (at record level; negotiator")
123
+ print(" expansion in sft_train can still multiply rows for usable records).")
124
+ if usable_rows < 50:
125
+ print(" WARNING: May be insufficient for SFT. Generate more data first.")
126
+
127
+ if usable_rows == 0:
128
+ sys.exit(1)
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
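
The len/4 heuristic makes the 2048-token flag easy to check by hand: any prompt beyond roughly 8192 characters trips it. A worked example against _approx_tokens above:

    prompt = "x" * 9000
    assert _approx_tokens(prompt) == 2250.0  # 9000 / 4, above the 2048 threshold
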
tests/test_keyless.py CHANGED
@@ -26,7 +26,7 @@ from parlay_env.game_theory import (
26
  )
27
  from parlay_env.grader import compute_step_reward, compute_terminal_reward
28
  from parlay_env.grader import detect_bluff_challenge
29
- from parlay_env.reward import OMEGA, PSI
30
  from parlay_env.models import (
31
  BeliefState, HiddenState, ParlayAction, ParlayState, PersonaType,
32
  )
@@ -200,6 +200,25 @@ class TestGrader:
200
  assert caught is True, f"Expected True, got {caught}"
201
  assert reward >= PSI, f"Expected at least {PSI}, got {reward}"
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def test_zopa_collapse_walkaway_keyless(self):
204
  """Repeated high tension collapses the ZOPA and forces walk-away."""
205
  hidden = HiddenState(
 
26
  )
27
  from parlay_env.grader import compute_step_reward, compute_terminal_reward
28
  from parlay_env.grader import detect_bluff_challenge
29
+ from parlay_env.reward import MU, OMEGA, PSI
30
  from parlay_env.models import (
31
  BeliefState, HiddenState, ParlayAction, ParlayState, PersonaType,
32
  )
 
200
  assert caught is True, f"Expected True, got {caught}"
201
  assert reward >= PSI, f"Expected at least {PSI}, got {reward}"
202
 
203
+ def test_mev_bonus_requires_drift_event(self, parlay_state):
204
+ """MEV bonus only activates when a drift event marker is present."""
205
+ action = ParlayAction(
206
+ utterance="Given that the market shifted, I can adjust.",
207
+ tactical_move=None,
208
+ )
209
+
210
+ no_drift_state = ParlayState(**{**parlay_state.model_dump(), "step_count": 1})
211
+ reward_no_drift = compute_step_reward(parlay_state, action, no_drift_state)
212
+
213
+ with_drift_state = ParlayState(**{**parlay_state.model_dump(), "step_count": 1})
214
+ with_drift_state.__dict__["drift_event"] = "Competitor drops price 15%"
215
+ reward_with_drift = compute_step_reward(parlay_state, action, with_drift_state)
216
+
217
+ assert reward_with_drift >= reward_no_drift + MU, (
218
+ f"Expected drift reward boost >= {MU}, got "
219
+ f"{reward_with_drift - reward_no_drift}"
220
+ )
221
+
222
  def test_zopa_collapse_walkaway_keyless(self):
223
  """Repeated high tension collapses the ZOPA and forces walk-away."""
224
  hidden = HiddenState(
training/generate_data.py CHANGED
@@ -33,6 +33,32 @@ REQUIRED_COMBINATIONS = [
33
  for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
34
  ]
35

36
 
37
  def is_quality_episode(grade, args) -> tuple[bool, str]:
38
  """
@@ -308,7 +334,8 @@ async def run_inspect_mode(args) -> None:
308
 
309
  async def run_diversity_pass(args, output_path: Path) -> None:
310
  """
311
- Generate a quality-filtered dataset with guaranteed persona x scenario coverage.
312
  """
313
  output_path.parent.mkdir(parents=True, exist_ok=True)
314
  coverage: dict[tuple[str, str], int] = defaultdict(int)
@@ -316,169 +343,111 @@ async def run_diversity_pass(args, output_path: Path) -> None:
316
  kept_records: list[dict] = []
317
  generated = 0
318
  discarded = 0
319
  seed = 0
320
- min_per_combo = max(2, args.episodes // len(REQUIRED_COMBINATIONS))
321
  total_live_calls: int = 0
322
  total_fallback_calls: int = 0
323
  _verbose = not getattr(args, "quiet", False)
324
 
325
  with open(output_path, "w", encoding="utf-8") as out_f:
326
  while len(kept_records) < args.episodes:
327
- progress_made = False
328
- for persona, scenario_id in REQUIRED_COMBINATIONS:
329
- if len(kept_records) >= args.episodes:
330
- break
331
- if coverage[(persona, scenario_id)] >= min_per_combo:
332
- continue
333
-
334
- record = await _run_one(persona, scenario_id, seed=seed, max_turns=args.max_turns)
335
- seed += 1
336
- generated += 1
337
- if record is None:
338
- _live_n, _fall_n = get_and_reset_counts()
339
- total_live_calls += _live_n
340
- total_fallback_calls += _fall_n
341
- continue
342
-
343
- keep, reason = is_quality_episode(
344
- _grade_proxy_from_record(record),
345
- args,
346
- )
347
- if not keep:
348
- discarded += 1
349
- _live_d, _fall_d = get_and_reset_counts()
350
- total_live_calls += _live_d
351
- total_fallback_calls += _fall_d
352
- if _verbose:
353
- print(
354
- f"[EP --/{args.episodes:02d}] "
355
- f"{persona}×{scenario_id:<27s} | "
356
- f"reward={record.get('reward', 0.0):+.2f} | "
357
- f"eff={record.get('deal_efficiency', 0.0):.3f} | "
358
- f"kept=NO | "
359
- f"total_kept={len(kept_records)}/{generated} | "
360
- f"gemini_live={_live_d} fallback={_fall_d}",
361
- file=sys.stderr,
362
- )
363
- continue
364
-
365
- out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
366
- kept_records.append(record)
367
- _live, _fall = get_and_reset_counts()
368
- total_live_calls += _live
369
- total_fallback_calls += _fall
370
- _ep_num = len(kept_records)
371
  if _verbose:
372
- _reward = record.get("reward", 0.0)
373
- _eff = record.get("deal_efficiency", 0.0)
374
- _combo = f"{record['persona']}×{record['scenario_id']}"
375
  print(
376
- f"[EP {_ep_num:02d}/{args.episodes:02d}] "
377
- f"{_combo:<35s} | "
378
- f"reward={_reward:+.2f} | "
379
- f"eff={_eff:.3f} | "
380
- f"kept=YES | "
381
- f"total_kept={_ep_num}/{generated} | "
382
- f"gemini_live={_live} fallback={_fall}",
383
  file=sys.stderr,
384
  )
385
- if _ep_num in (20, 40, 60):
386
- _all_rewards = [r.get("reward", 0.0) for r in kept_records]
387
- _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
388
- _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
389
- print(f"\n{'━' * 40}", file=sys.stderr)
390
- print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
391
- print(
392
- f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
393
- file=sys.stderr,
394
- )
395
- print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
396
- print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
397
- print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
398
- print(f" Live calls total: {total_live_calls}", file=sys.stderr)
399
- print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
400
- print(f"{'' * 40}\n", file=sys.stderr)
401
- coverage[(persona, scenario_id)] += 1
402
- kept_reason_counts[reason] += 1
403
- progress_made = True
404
-
405
- if len(kept_records) >= args.episodes:
406
- break
407
-
408
- if not progress_made:
409
- persona, scenario_id = random.choice(REQUIRED_COMBINATIONS)
410
- record = await _run_one(persona, scenario_id, seed=seed, max_turns=args.max_turns)
411
- seed += 1
412
- generated += 1
413
- if record is None:
414
- _live_n, _fall_n = get_and_reset_counts()
415
- total_live_calls += _live_n
416
- total_fallback_calls += _fall_n
417
- continue
418
- keep, reason = is_quality_episode(
419
- _grade_proxy_from_record(record),
420
- args,
421
  )
422
- if keep:
423
- out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
424
- kept_records.append(record)
425
- _live, _fall = get_and_reset_counts()
426
- total_live_calls += _live
427
- total_fallback_calls += _fall
428
- _ep_num = len(kept_records)
429
- if _verbose:
430
- _reward = record.get("reward", 0.0)
431
- _eff = record.get("deal_efficiency", 0.0)
432
- _combo = f"{record['persona']}×{record['scenario_id']}"
433
- print(
434
- f"[EP {_ep_num:02d}/{args.episodes:02d}] "
435
- f"{_combo:<35s} | "
436
- f"reward={_reward:+.2f} | "
437
- f"eff={_eff:.3f} | "
438
- f"kept=YES | "
439
- f"total_kept={_ep_num}/{generated} | "
440
- f"gemini_live={_live} fallback={_fall}",
441
- file=sys.stderr,
442
- )
443
- if _ep_num in (20, 40, 60):
444
- _all_rewards = [r.get("reward", 0.0) for r in kept_records]
445
- _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
446
- _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
447
- print(f"\n{'━' * 40}", file=sys.stderr)
448
- print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
449
- print(
450
- f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
451
- file=sys.stderr,
452
- )
453
- print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
454
- print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
455
- print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
456
- print(f" Live calls total: {total_live_calls}", file=sys.stderr)
457
- print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
458
- print(f"{'━' * 40}\n", file=sys.stderr)
459
- coverage[(persona, scenario_id)] += 1
460
- kept_reason_counts[reason] += 1
461
- else:
462
- discarded += 1
463
- _live_d, _fall_d = get_and_reset_counts()
464
- total_live_calls += _live_d
465
- total_fallback_calls += _fall_d
466
- if _verbose:
467
- print(
468
- f"[EP --/{args.episodes:02d}] "
469
- f"{persona}×{scenario_id:<27s} | "
470
- f"reward={record.get('reward', 0.0):+.2f} | "
471
- f"eff={record.get('deal_efficiency', 0.0):.3f} | "
472
- f"kept=NO | "
473
- f"total_kept={len(kept_records)}/{generated} | "
474
- f"gemini_live={_live_d} fallback={_fall_d}",
475
- file=sys.stderr,
476
- )
477
 
478
  discard_pct = (discarded / max(generated, 1)) * 100.0
479
  print(
480
  f"Generated: {generated} episodes | Kept: {len(kept_records)} | "
481
- f"Discarded: {discarded} ({discard_pct:.0f}%)"
 
482
  )
483
  reasons_str = ", ".join(f"{reason}={count}" for reason, count in sorted(kept_reason_counts.items()))
484
  print(f"Reasons kept: {reasons_str or 'none'}")
@@ -488,9 +457,9 @@ async def run_diversity_pass(args, output_path: Path) -> None:
488
 
489
  _fallback_rate = 100.0 * total_fallback_calls / max(total_live_calls + total_fallback_calls, 1)
490
  _verdict = (
491
- "ALL CALLS LIVE data is real"
492
  if _fallback_rate < 5.0
493
- else "WARNING: fallback rate high check API key and rate limits"
494
  )
495
  print(f"\nGemini API health:")
496
  print(f" Total live calls : {total_live_calls}")
@@ -501,8 +470,14 @@ async def run_diversity_pass(args, output_path: Path) -> None:
501
 
502
  def main() -> None:
503
  parser = argparse.ArgumentParser(description="Generate Parlay training data")
504
- parser.add_argument("--episodes", type=int, default=80)
505
  parser.add_argument("--output", type=str, default="data/episodes.jsonl")
 
 
 
 
 
 
506
  parser.add_argument(
507
  "--quality_filter",
508
  action="store_true",
 
     for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
 ]
 
+# Weighted to oversample historically low deal-rate combinations (total weight = 15)
+COMBO_WEIGHTS: dict[tuple[str, str], int] = {
+    ("veteran", "hiring_package"): 3,
+    ("veteran", "saas_enterprise"): 2,
+    ("veteran", "acquisition_term_sheet"): 2,
+    ("shark", "hiring_package"): 2,
+    ("diplomat", "hiring_package"): 2,
+    ("shark", "saas_enterprise"): 1,
+    ("shark", "acquisition_term_sheet"): 1,
+    ("diplomat", "saas_enterprise"): 1,
+    ("diplomat", "acquisition_term_sheet"): 1,
+}
+WEIGHTED_COMBO_LIST: list[tuple[str, str]] = []
+for _pair, _weight in COMBO_WEIGHTS.items():
+    WEIGHTED_COMBO_LIST.extend([_pair] * _weight)
+
+
+def _row_total_reward(record: dict) -> float | None:
+    v = record.get("reward")
+    if v is not None:
+        return float(v)
+    v2 = record.get("cumulative_reward")
+    if v2 is not None:
+        return float(v2)
+    return None
+
 
 def is_quality_episode(grade, args) -> tuple[bool, str]:
     """
 
 async def run_diversity_pass(args, output_path: Path) -> None:
     """
+    Generate a quality-filtered dataset; persona x scenario is weighted-sampled
+    (see COMBO_WEIGHTS / WEIGHTED_COMBO_LIST).
     """
     output_path.parent.mkdir(parents=True, exist_ok=True)
     coverage: dict[tuple[str, str], int] = defaultdict(int)
 
     kept_records: list[dict] = []
     generated = 0
     discarded = 0
+    skipped_min_reward = 0
     seed = 0
     total_live_calls: int = 0
     total_fallback_calls: int = 0
     _verbose = not getattr(args, "quiet", False)
+    _checkpoints = {20, 40, 60, 80, 100, 120, 140}
+
+    def _emit_checkpoint(_ep_num: int) -> None:
+        if not _verbose or _ep_num not in _checkpoints:
+            return
+        _all_rewards = [r.get("reward", 0.0) for r in kept_records]
+        _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
+        _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
+        print(f"\n{'━' * 40}", file=sys.stderr)
+        print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
+        print(
+            f" Kept so far : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
+            file=sys.stderr,
+        )
+        print(f" Mean reward : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
+        print(f" Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
+        print(f" Combos covered : {_combos_covered}/9", file=sys.stderr)
+        print(f" Min-reward skip : {skipped_min_reward}", file=sys.stderr)
+        print(f" Live calls total: {total_live_calls}", file=sys.stderr)
+        print(f" Fallback total : {total_fallback_calls}", file=sys.stderr)
+        print(f"{'━' * 40}\n", file=sys.stderr)
 
     with open(output_path, "w", encoding="utf-8") as out_f:
         while len(kept_records) < args.episodes:
+            persona, scenario_id = random.choice(WEIGHTED_COMBO_LIST)
+            record = await _run_one(persona, scenario_id, seed=seed, max_turns=args.max_turns)
+            seed += 1
+            generated += 1
+            if record is None:
+                _live_n, _fall_n = get_and_reset_counts()
+                total_live_calls += _live_n
+                total_fallback_calls += _fall_n
+                continue
+
+            rw = _row_total_reward(record)
+            if rw is not None and rw < args.min_reward:
+                skipped_min_reward += 1
+                _live_m, _fall_m = get_and_reset_counts()
+                total_live_calls += _live_m
+                total_fallback_calls += _fall_m
                 if _verbose:
                     print(
+                        f"[min_reward skip #{skipped_min_reward}] {persona} x {scenario_id} "
+                        f"reward={rw:.2f} < {args.min_reward}",
                         file=sys.stderr,
                     )
+                continue
+
+            keep, reason = is_quality_episode(
+                _grade_proxy_from_record(record),
+                args,
+            )
+            if not keep:
+                discarded += 1
+                _live_d, _fall_d = get_and_reset_counts()
+                total_live_calls += _live_d
+                total_fallback_calls += _fall_d
+                if _verbose:
+                    print(
+                        f"[EP --/{args.episodes:02d}] "
+                        f"{persona}×{scenario_id:<27s} | "
+                        f"reward={record.get('reward', 0.0):+.2f} | "
+                        f"eff={record.get('deal_efficiency', 0.0):.3f} | "
+                        f"kept=NO | "
+                        f"total_kept={len(kept_records)}/{generated} | "
+                        f"gemini_live={_live_d} fallback={_fall_d}",
+                        file=sys.stderr,
+                    )
+                continue
+
+            out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
+            out_f.flush()
+            kept_records.append(record)
+            _live, _fall = get_and_reset_counts()
+            total_live_calls += _live
+            total_fallback_calls += _fall
+            _ep_num = len(kept_records)
+            if _verbose:
+                _reward = record.get("reward", 0.0)
+                _eff = record.get("deal_efficiency", 0.0)
+                _combo = f"{record['persona']}×{record['scenario_id']}"
+                print(
+                    f"[EP {_ep_num:02d}/{args.episodes:02d}] "
+                    f"{_combo:<35s} | "
+                    f"reward={_reward:+.2f} | "
+                    f"eff={_eff:.3f} | "
+                    f"kept=YES | "
+                    f"total_kept={_ep_num}/{generated} | "
+                    f"gemini_live={_live} fallback={_fall}",
+                    file=sys.stderr,
+                )
+            _emit_checkpoint(_ep_num)
+            coverage[(persona, scenario_id)] += 1
+            kept_reason_counts[reason] += 1
 
     discard_pct = (discarded / max(generated, 1)) * 100.0
     print(
         f"Generated: {generated} episodes | Kept: {len(kept_records)} | "
+        f"Discarded: {discarded} ({discard_pct:.0f}%) | "
+        f"Skipped (min_reward < {args.min_reward}): {skipped_min_reward}"
     )
     reasons_str = ", ".join(f"{reason}={count}" for reason, count in sorted(kept_reason_counts.items()))
     print(f"Reasons kept: {reasons_str or 'none'}")
 
     _fallback_rate = 100.0 * total_fallback_calls / max(total_live_calls + total_fallback_calls, 1)
     _verdict = (
+        "ALL CALLS LIVE - data is real"
         if _fallback_rate < 5.0
+        else "WARNING: fallback rate high - check API key and rate limits"
     )
     print(f"\nGemini API health:")
     print(f" Total live calls : {total_live_calls}")
 
 def main() -> None:
     parser = argparse.ArgumentParser(description="Generate Parlay training data")
+    parser.add_argument("--episodes", type=int, default=140)
     parser.add_argument("--output", type=str, default="data/episodes.jsonl")
+    parser.add_argument(
+        "--min-reward",
+        type=float,
+        default=-50.0,
+        help="After grading, do not write episodes with total reward below this (default: -50.0)",
+    )
     parser.add_argument(
         "--quality_filter",
         action="store_true",
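
The weighting above is the core of the data-generation change, so it is worth a quick sanity check. A minimal standalone sketch, assuming nothing beyond the stdlib (it copies `COMBO_WEIGHTS` from the diff and is not part of the commit): with total weight 15, `("veteran", "hiring_package")` should account for about 3/15 = 20% of draws and each weight-1 combo for roughly 6.7%.

```python
# Standalone check of the weighted persona x scenario sampling
# (COMBO_WEIGHTS copied from the diff above).
import random
from collections import Counter

COMBO_WEIGHTS = {
    ("veteran", "hiring_package"): 3,
    ("veteran", "saas_enterprise"): 2,
    ("veteran", "acquisition_term_sheet"): 2,
    ("shark", "hiring_package"): 2,
    ("diplomat", "hiring_package"): 2,
    ("shark", "saas_enterprise"): 1,
    ("shark", "acquisition_term_sheet"): 1,
    ("diplomat", "saas_enterprise"): 1,
    ("diplomat", "acquisition_term_sheet"): 1,
}
# Same expansion the generator uses: each pair repeated `weight` times.
weighted = [pair for pair, weight in COMBO_WEIGHTS.items() for _ in range(weight)]

draws = Counter(random.choice(weighted) for _ in range(15_000))
for pair, n in draws.most_common():
    print(f"{pair}: observed {n / 15_000:.1%}, expected {COMBO_WEIGHTS[pair] / 15:.1%}")
```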
training/grpo_train.py CHANGED
@@ -18,18 +18,32 @@ from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
-BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
 GRPO_STEPS = int(os.getenv("GRPO_STEPS", "500"))
 GRPO_GENERATIONS = int(os.getenv("GRPO_GENERATIONS", "8"))
 
 
-def build_grpo_dataset(jsonl_path: str):
     """
     Build GRPO dataset. Each record needs only a 'prompt' field plus metadata.
     The model generates G=8 completions per prompt; grader scores all 8.
 
     Args:
         jsonl_path: Path to the JSONL episodes file.
 
     Returns:
         HuggingFace Dataset with prompt + metadata columns.
@@ -40,21 +54,31 @@ def build_grpo_dataset(jsonl_path: str):
         raise ImportError("Install datasets: pip install datasets") from exc
 
     prompts = []
     with open(jsonl_path, encoding="utf-8") as f:
         for line in f:
             rec = json.loads(line.strip())
-            if rec.get("split") == "train":
-                # Extract ZOPA metadata for reward functions
-                prompts.append({
                     "prompt": rec["prompt"],
                     "scenario_id": rec.get("scenario_id", ""),
                     "persona": rec.get("persona", ""),
-                    # Reward function kwargs (passed through dataset)
                     "batna_seller": _get_batna(rec.get("scenario_id", ""), "seller"),
-                    "batna_buyer": _get_batna(rec.get("scenario_id", ""), "buyer"),
-                    "zopa_width": _get_zopa_width(rec.get("scenario_id", "")),
-                })
-    logger.info(f"GRPO dataset: {len(prompts)} prompts")
     return Dataset.from_list(prompts)
 
 
@@ -62,7 +86,7 @@ def _get_batna(scenario_id: str, side: str) -> float:
     """Lookup BATNA for a scenario without importing game module at training time."""
     batnas: dict[str, dict[str, float]] = {
         "saas_enterprise": {"seller": 125_000, "buyer": 165_000},
-        "hiring_package": {"seller": 195_000, "buyer": 230_000},
         "acquisition_term_sheet": {"seller": 10_500_000, "buyer": 16_000_000},
     }
     return float(batnas.get(scenario_id, {}).get(side, 0))
@@ -80,6 +104,7 @@ def train_grpo(
     data_path: str,
     output_dir: str,
     steps: int = 500,
 ) -> None:
     """
     GRPO training loop.
@@ -117,7 +142,7 @@ def train_grpo(
         format_reward,
     )
 
-    dataset = build_grpo_dataset(data_path)
     if len(dataset) == 0:
         raise ValueError("Empty GRPO dataset. Run generate_data.py first.")
 
@@ -181,6 +206,12 @@ def main() -> None:
     parser.add_argument("--model", default="models/parlay-sft")
     parser.add_argument("--base_model", default="")
     parser.add_argument("--data", default="data/episodes.jsonl")
     parser.add_argument("--output", default="models/parlay-grpo")
     parser.add_argument("--steps", type=int, default=GRPO_STEPS)
     parser.add_argument("--g", type=int, default=GRPO_GENERATIONS)
@@ -191,7 +222,7 @@ def main() -> None:
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
     GRPO_GENERATIONS = args.g
     model_path = args.base_model or args.model
-    train_grpo(model_path, args.data, args.output, args.steps)
 
     if args.save_curves:
         curves_path = Path(args.save_curves)
 
 logger = logging.getLogger(__name__)
 
+# SFT->GRPO pipeline: set BASE_MODEL=checkpoints/sft_1.5b/ after sft_train.py
+# (overridable via BASE_MODEL env var)
+BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
 GRPO_STEPS = int(os.getenv("GRPO_STEPS", "500"))
 GRPO_GENERATIONS = int(os.getenv("GRPO_GENERATIONS", "8"))
 
 
+def _row_total_reward(rec: dict) -> float | None:
+    v = rec.get("reward")
+    if v is not None:
+        return float(v)
+    v2 = rec.get("cumulative_reward")
+    if v2 is not None:
+        return float(v2)
+    return None
+
+
+def build_grpo_dataset(jsonl_path: str, min_reward: float = -50.0):
     """
     Build GRPO dataset. Each record needs only a 'prompt' field plus metadata.
     The model generates G=8 completions per prompt; grader scores all 8.
 
     Args:
         jsonl_path: Path to the JSONL episodes file.
+        min_reward: Drop train rows with total reward (reward / cumulative_reward) below this
+            (missing reward fields are kept for backward compatibility).
 
     Returns:
         HuggingFace Dataset with prompt + metadata columns.
 
         raise ImportError("Install datasets: pip install datasets") from exc
 
     prompts = []
+    filtered = 0
     with open(jsonl_path, encoding="utf-8") as f:
         for line in f:
             rec = json.loads(line.strip())
+            if rec.get("split") != "train":
+                continue
+            r = _row_total_reward(rec)
+            if r is not None and r < min_reward:
+                filtered += 1
+                continue
+            # Extract ZOPA metadata for reward functions
+            prompts.append(
+                {
                     "prompt": rec["prompt"],
                     "scenario_id": rec.get("scenario_id", ""),
                     "persona": rec.get("persona", ""),
                     "batna_seller": _get_batna(rec.get("scenario_id", ""), "seller"),
+                    "batna_buyer": _get_batna(rec.get("scenario_id", ""), "buyer"),
+                    "zopa_width": _get_zopa_width(rec.get("scenario_id", "")),
+                }
+            )
+    print(
+        f"Filtered {filtered} records below min_reward={min_reward}, "
+        f"{len(prompts)} remaining for GRPO"
+    )
     return Dataset.from_list(prompts)
 
 
     """Lookup BATNA for a scenario without importing game module at training time."""
     batnas: dict[str, dict[str, float]] = {
         "saas_enterprise": {"seller": 125_000, "buyer": 165_000},
+        "hiring_package": {"seller": 195_000, "buyer": 264_500},  # match game/scenarios (widened zopa)
         "acquisition_term_sheet": {"seller": 10_500_000, "buyer": 16_000_000},
     }
     return float(batnas.get(scenario_id, {}).get(side, 0))
 
     data_path: str,
     output_dir: str,
     steps: int = 500,
+    min_reward: float = -50.0,
 ) -> None:
     """
     GRPO training loop.
 
         format_reward,
     )
 
+    dataset = build_grpo_dataset(data_path, min_reward=min_reward)
     if len(dataset) == 0:
         raise ValueError("Empty GRPO dataset. Run generate_data.py first.")
 
     parser.add_argument("--model", default="models/parlay-sft")
     parser.add_argument("--base_model", default="")
     parser.add_argument("--data", default="data/episodes.jsonl")
+    parser.add_argument(
+        "--min-reward",
+        type=float,
+        default=-50.0,
+        help="Skip JSONL train rows with total reward below this (default: -50.0)",
+    )
     parser.add_argument("--output", default="models/parlay-grpo")
     parser.add_argument("--steps", type=int, default=GRPO_STEPS)
     parser.add_argument("--g", type=int, default=GRPO_GENERATIONS)
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
     GRPO_GENERATIONS = args.g
     model_path = args.base_model or args.model
+    train_grpo(model_path, args.data, args.output, args.steps, min_reward=args.min_reward)
 
     if args.save_curves:
         curves_path = Path(args.save_curves)
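
A note on the `hiring_package` row in `_get_batna`: bumping the buyer BATNA from 230,000 to 264,500 roughly doubles the hiring ZOPA, from 230,000 - 195,000 = 35,000 to 264,500 - 195,000 = 69,500, assuming `_get_zopa_width` returns buyer BATNA minus seller BATNA (an assumption; its body is outside this diff). An illustrative check:

```python
# Illustrative only: recompute ZOPA widths from the BATNA table in _get_batna,
# assuming zopa_width = buyer BATNA - seller BATNA (the real _get_zopa_width
# is not shown in this diff).
batnas = {
    "saas_enterprise": {"seller": 125_000, "buyer": 165_000},
    "hiring_package": {"seller": 195_000, "buyer": 264_500},  # was buyer=230_000
    "acquisition_term_sheet": {"seller": 10_500_000, "buyer": 16_000_000},
}
for scenario, b in batnas.items():
    print(f"{scenario}: zopa_width = {b['buyer'] - b['seller']:,}")
# hiring_package: zopa_width = 69,500 (up from 230_000 - 195_000 = 35_000)
```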
training/random_baseline.py CHANGED
@@ -1,125 +1,138 @@
-"""
-Random-policy baseline for Parlay.
-Runs N episodes with purely random move selection (no Gemini API — always
-uses mock mode) and writes a summary JSON that the training notebook uses
-to benchmark SFT / GRPO improvement.
-
-Usage:
-    python training/random_baseline.py
-    python training/random_baseline.py --episodes 20 --output data/random_baseline.json
-"""
 import argparse
 import asyncio
 import json
 import logging
-import os
 import random
-import statistics
-import sys
 from pathlib import Path
 
-# Repo root on sys.path when run as `python training/random_baseline.py`
-sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
-
-# Force mock mode — random baseline never calls the real Gemini API
-os.environ.pop("GOOGLE_API_KEY", None)
-
-from agent.runner import run_episode
-from game.scenarios import SCENARIOS
-from parlay_env.models import PersonaType
 
 logger = logging.getLogger(__name__)
 
-REQUIRED_COMBINATIONS = [
-    (persona, scenario)
-    for persona in ["shark", "diplomat", "veteran"]
-    for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
 ]
 
 
 async def _run_baseline(episodes: int) -> list[dict]:
-    """Run `episodes` random-policy episodes and return per-episode stats."""
-    results = []
-    seed = 0
     for i in range(episodes):
-        persona_str, scenario_id = REQUIRED_COMBINATIONS[i % len(REQUIRED_COMBINATIONS)]
         try:
-            result = await run_episode(
-                persona=PersonaType(persona_str),
-                scenario_id=scenario_id,
-                inject_noise=True,  # random noise simulates random policy
-                force_drift=random.random() < 0.4,
-                seed=seed,
-                max_turns=14,
-            )
-            results.append({
-                "persona": persona_str,
-                "scenario_id": scenario_id,
-                "reward": result.grade.total_reward,
-                "deal_efficiency": result.grade.deal_efficiency,
-                "deal_reached": result.final_price is not None,
-                "tom_accuracy_avg": result.grade.tom_accuracy_avg,
-                "drift_adapted": result.grade.drift_adapted,
-                "termination_reason": result.grade.termination_reason,
-            })
         except Exception as exc:
-            logger.warning("Baseline episode %d failed (%s, %s): %s", i, persona_str, scenario_id, exc)
-        seed += 1
-    return results
-
-
-def _summarise(results: list[dict]) -> dict:
-    if not results:
-        return {"error": "no episodes completed", "n_episodes": 0}
-
-    rewards = [r["reward"] for r in results]
-    efficiencies = [r["deal_efficiency"] for r in results]
-    deal_count = sum(1 for r in results if r["deal_reached"])
-    drift_count = sum(1 for r in results if r["drift_adapted"])
-
     return {
-        "n_episodes": len(results),
-        "mean_reward": round(statistics.mean(rewards), 3),
-        "std_reward": round(statistics.stdev(rewards) if len(rewards) > 1 else 0.0, 3),
-        "min_reward": round(min(rewards), 3),
-        "max_reward": round(max(rewards), 3),
-        "mean_efficiency": round(statistics.mean(efficiencies), 4),
-        "deal_rate": round(deal_count / len(results), 4),
-        "drift_adapted_rate": round(drift_count / len(results), 4),
-        "policy": "random_mock",
-        "note": (
-            "Baseline uses Parlay mock responses (no real Gemini API). "
-            "Compare mean_reward and mean_efficiency against SFT/GRPO runs."
-        ),
     }
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Parlay random-policy baseline")
-    parser.add_argument("--episodes", type=int, default=27,
-                        help="Number of baseline episodes (default: 27 = 3 per combo)")
-    parser.add_argument("--output", type=str, default="data/random_baseline.json",
-                        help="Output path for the baseline JSON summary")
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
-    logging.getLogger("httpx").setLevel(logging.WARNING)
-
-    print(f"Running {args.episodes} random-policy episodes (mock mode, no API key)…")
-    results = asyncio.run(_run_baseline(args.episodes))
-
-    summary = _summarise(results)
 
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    with open(out_path, "w", encoding="utf-8") as f:
-        json.dump(summary, f, indent=2)
-
-    print(f"\nBaseline complete ({summary['n_episodes']} episodes):")
-    print(f" Mean reward : {summary.get('mean_reward', 'N/A')}")
-    print(f" Mean efficiency : {summary.get('mean_efficiency', 'N/A')}")
-    print(f" Deal rate : {summary.get('deal_rate', 'N/A'):.1%}")
-    print(f" Written to : {out_path.resolve()}")
 
 
 if __name__ == "__main__":
+"""Random-action baseline for Parlay."""
 import argparse
 import asyncio
 import json
 import logging
 import random
 from pathlib import Path
 
+from parlay_env.grader import grade_episode
+from parlay_env.models import TacticalMove
+from parlay_env.server import _handle_reset, _handle_step, get_session_state
 
 logger = logging.getLogger(__name__)
 
+PERSONAS = ["shark", "diplomat", "veteran"]
+SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
+RANDOM_LINES = [
+    "Let's keep talking.",
+    "I can move a bit.",
+    "This is my proposal.",
+    "We should find middle ground.",
+    "Given that context, here's my number.",
 ]
 
 
+def _mean(values: list[float]) -> float:
+    return sum(values) / len(values) if values else 0.0
+
+
+async def _run_single_episode(scenario_id: str, persona: str, seed: int) -> dict:
+    random.seed(seed)
+    reset = await _handle_reset({"scenario_id": scenario_id, "persona": persona, "seed": seed})
+    session_id = str(reset["session_id"])
+    final_price = None
+    t_close = None
+    done = False
+
+    while not done:
+        state = get_session_state(session_id)
+        if state is None:
+            break
+        if state.episode_done:
+            break
+
+        low = state.hidden_state.walk_away_price
+        high = state.hidden_state.budget_ceiling
+        offer = round(random.uniform(low, high), 2)
+
+        moves: list[TacticalMove | None] = [None]
+        if state.credibility_points >= 0:
+            moves.append(TacticalMove.ANCHOR_HIGH)
+        if state.credibility_points >= 5:
+            moves.append(TacticalMove.SILENCE)
+        if state.credibility_points >= 20:
+            moves.append(TacticalMove.BATNA_REVEAL)
+        move = random.choice(moves)
+
+        payload = {
+            "session_id": session_id,
+            "action": {
+                "utterance": random.choice(RANDOM_LINES),
+                "offer_amount": offer,
+                "tactical_move": move.value if move else None,
+            },
+        }
+        step = await _handle_step(payload)
+        done = bool(step.get("done", False))
+
+        state = get_session_state(session_id)
+        if state and state.deal_reached and final_price is None:
+            final_price = offer
+            t_close = state.step_count
+
+    state = get_session_state(session_id)
+    if state is None:
+        raise RuntimeError(f"Missing session state for {session_id}")
+    grade = grade_episode(state, final_price=final_price, t_close=t_close, t_max=20)
+    return {
+        "avg_reward": float(grade.total_reward),
+        "deal_rate": 1.0 if final_price is not None else 0.0,
+        "avg_efficiency": float(grade.deal_efficiency),
+        "avg_tom_accuracy": float(grade.tom_accuracy_avg),
+        "bluffs_caught": int(grade.bluffs_caught),
+    }
+
+
 async def _run_baseline(episodes: int) -> list[dict]:
+    rows: list[dict] = []
     for i in range(episodes):
+        persona = PERSONAS[i % len(PERSONAS)]
+        scenario = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
         try:
+            rows.append(await _run_single_episode(scenario, persona, i + 7))
         except Exception as exc:
+            logger.warning("Baseline episode %d failed (%s/%s): %s", i + 1, scenario, persona, exc)
+    return rows
+
+
+def _summarise(rows: list[dict], episodes_requested: int) -> dict:
+    if not rows:
+        return {
+            "episodes_requested": episodes_requested,
+            "episodes_completed": 0,
+            "avg_reward": 0.0,
+            "deal_rate": 0.0,
+            "avg_efficiency": 0.0,
+            "avg_tom_accuracy": 0.0,
+            "bluffs_caught": 0,
+        }
     return {
+        "episodes_requested": episodes_requested,
+        "episodes_completed": len(rows),
+        "avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4),
+        "deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4),
+        "avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4),
+        "avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4),
+        "bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)),
     }
 
 
 def main() -> None:
+    parser = argparse.ArgumentParser(description="Parlay random baseline")
+    parser.add_argument("--episodes", type=int, default=20)
+    parser.add_argument("--output", default="results/baseline.json")
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
+    rows = asyncio.run(_run_baseline(args.episodes))
+    summary = _summarise(rows, args.episodes)
 
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
+    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+    print(json.dumps(summary, indent=2))
+    print(f"\nSaved random baseline to {out_path.resolve()}")
 
 
 if __name__ == "__main__":
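
The rewrite drives the environment's `_handle_reset` / `_handle_step` handlers directly rather than the old mocked `run_episode` path, and swaps the flat combination list for a persona-fastest rotation. A small sketch of the schedule implied by the indexing in `_run_baseline` (not part of the commit): episodes 0-8 already cover all nine persona x scenario pairs, so the default `--episodes 20` gives each pair two or three runs.

```python
# Mirror of the baseline's episode schedule (indexing copied from _run_baseline).
PERSONAS = ["shark", "diplomat", "veteran"]
SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]

for i in range(9):
    persona = PERSONAS[i % len(PERSONAS)]
    scenario = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
    print(f"episode {i}: {persona} x {scenario} (seed={i + 7})")
```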
training/sft_train.py CHANGED
@@ -1,177 +1,119 @@
 """
-Stage 1: SFT warmup on best episodes (efficiency >= threshold).
-Fine-tunes Qwen2.5-7B-Instruct on demonstrations of successful negotiation.
-
-Applies episode quality filters (offers + reward outliers) and stable SFT target
-metadata (log-scaled efficiency, clipped reward) when building training text.
-
-Usage:
-    python -m training.sft_train \
-        --data data/episodes.jsonl \
-        --model Qwen/Qwen2.5-7B-Instruct \
-        --output models/parlay-sft \
-        --threshold 0.30
 """
 import argparse
 import json
 import logging
-import os
 from pathlib import Path
 
-from .episode_filters import (
-    SFTFilterConfig,
-    clip_reward_for_label,
-    efficiency_sft_label,
-    episode_passes_sft_filters,
-)
-
 logger = logging.getLogger(__name__)
 
-TOP_PLAYER_THRESHOLD = float(os.getenv("TOP_PLAYER_THRESHOLD", "0.30"))
-BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-7B-Instruct")
 
 
-def load_sft_dataset(
-    jsonl_path: Path,
-    threshold: float = 0.30,
-    filter_cfg: SFTFilterConfig | None = None,
-    include_sft_targets: bool = True,
-):
-    """
-    Load episodes above efficiency threshold and format for SFT.
-
-    Only 'train' split episodes above the threshold are included.
-    Rows failing quality filters (broken offers, extreme rewards) are skipped.
-    Each agent turn becomes one training example.
-
-    Args:
-        jsonl_path: Path to the JSONL episodes file.
-        threshold: Minimum deal_efficiency to include.
-        filter_cfg: Drop/clip thresholds; default SFTFilterConfig().
-        include_sft_targets: If True, embed eff_log and reward_clip in the example text.
-
-    Returns:
-        HuggingFace Dataset with 'text' column.
-    """
     try:
         from datasets import Dataset
     except ImportError as exc:
         raise ImportError("Install datasets: pip install datasets") from exc
 
-    filter_cfg = filter_cfg or SFTFilterConfig()
-    records = []
-    skipped_filter = 0
-    with open(jsonl_path, encoding="utf-8") as f:
-        for line in f:
-            rec = json.loads(line.strip())
-            ok, _reason = episode_passes_sft_filters(rec, filter_cfg)
-            if not ok:
-                skipped_filter += 1
                 continue
-            if rec.get("deal_efficiency", 0) >= threshold and rec.get("split") == "train":
-                conversation = rec.get("conversation", [])
-                eff_l = efficiency_sft_label(rec.get("deal_efficiency"))
-                r_clip = clip_reward_for_label(rec.get("reward"), filter_cfg)
-                for i, turn in enumerate(conversation[:-1]):
-                    if turn.get("role") == "negotiator":
-                        context = conversation[:i]
-                        records.append({
-                            "text": _format_sft_example(
-                                system=rec["prompt"],
-                                context=context,
-                                response=turn["content"],
-                                efficiency_label=eff_l,
-                                reward_clip=r_clip,
-                                include_sft_targets=include_sft_targets,
-                            )
-                        })
-
-    logger.info(
-        f"SFT dataset: {len(records)} training examples from {jsonl_path} "
-        f"(skipped {skipped_filter} episodes by quality filter)"
-    )
-    return Dataset.from_list(records)
-
-
-def _format_sft_example(
-    system: str,
-    context: list[dict],
-    response: str,
-    efficiency_label: float,
-    reward_clip: float,
-    include_sft_targets: bool = True,
-) -> str:
-    """Format a single SFT training example in chat template format."""
-    history_lines = []
-    for turn in context:
-        role = turn.get("role", "unknown").upper()
-        content = turn.get("content", "")
-        history_lines.append(f"{role}: {content}")
-    history = "\n".join(history_lines)
-
-    targets_block = ""
-    if include_sft_targets:
-        targets_block = (
-            f"<|sft_targets|>eff_log={efficiency_label:.4f} "
-            f"reward_clip={reward_clip:.2f}</s>\n"
-        )
-
-    return (
-        f"<|system|>{system}</s>\n"
-        f"{targets_block}"
-        f"<|negotiation_history|>{history}</s>\n"
-        f"<|assistant|>{response}</s>"
     )
 
 
 def train_sft(
-    data_path: Path,
-    model_id: str,
-    output_dir: Path,
-    threshold: float = 0.30,
-    filter_cfg: SFTFilterConfig | None = None,
-    include_sft_targets: bool = True,
-) -> Path:
-    """
-    Run SFT fine-tuning.
-
-    Args:
-        data_path: Path to episodes JSONL.
-        model_id: HuggingFace model ID or local path.
-        output_dir: Where to save the trained model.
-        threshold: Efficiency filter for training data.
-        filter_cfg: Quality filter / clip config.
-        include_sft_targets: Embed normalized targets in training strings.
-
-    Returns:
-        output_dir path.
-    """
     import torch
-    if not torch.cuda.is_available():
-        logger.warning("No GPU detected SFT will be very slow on CPU")
-
-    try:
-        from peft import LoraConfig
-        from trl import SFTTrainer, SFTConfig
-    except ImportError as exc:
-        raise ImportError("Install: pip install trl peft") from exc
 
-    filter_cfg = filter_cfg or SFTFilterConfig()
-    dataset = load_sft_dataset(
-        data_path, threshold, filter_cfg, include_sft_targets=include_sft_targets
-    )
-    if len(dataset) == 0 and threshold > 0.0:
-        logger.warning(
-            f"No episodes above threshold {threshold}. Lowering to 0.0 (all train rows)."
-        )
-        dataset = load_sft_dataset(
-            data_path, 0.0, filter_cfg, include_sft_targets=include_sft_targets
-        )
-    if len(dataset) == 0:
-        raise RuntimeError(
-            "SFT dataset is empty. "
-            "Run generate_data.py first with --episodes >= 200"
-        )
 
     lora_config = LoraConfig(
         r=16,
@@ -188,16 +130,16 @@ def train_sft(
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
         learning_rate=2e-4,
-        warmup_ratio=0.05,
-        lr_scheduler_type="cosine",
         logging_steps=10,
         save_strategy="epoch",
-        push_to_hub=False,
-        bf16=torch.cuda.is_available(),
-        max_seq_length=2048,
         report_to="none",
     )
 
     trainer = SFTTrainer(
         model=model_id,
         args=training_args,
@@ -205,46 +147,27 @@ def train_sft(
         peft_config=lora_config,
     )
 
-    logger.info(f"Starting SFT training: model={model_id}, examples={len(dataset)}, epochs=3")
     trainer.train()
     trainer.save_model(str(output_dir))
-    logger.info(f"SFT training complete. Model saved to {output_dir}")
-    return output_dir
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Parlay SFT warmup training")
     parser.add_argument("--data", default="data/episodes.jsonl")
-    parser.add_argument("--model", default=BASE_MODEL)
-    parser.add_argument("--output", default="models/parlay-sft")
-    parser.add_argument("--steps", type=int, default=0, help="Notebook compatibility flag")
-    parser.add_argument("--threshold", type=float, default=TOP_PLAYER_THRESHOLD)
-    parser.add_argument("--reward-drop-min", type=float, default=-400.0)
-    parser.add_argument("--reward-drop-max", type=float, default=400.0)
-    parser.add_argument("--clip-reward-min", type=float, default=-200.0)
-    parser.add_argument("--clip-reward-max", type=float, default=200.0)
     parser.add_argument(
-        "--no-sft-targets",
-        action="store_true",
-        help="Omit eff_log / reward_clip block from training text",
     )
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
-    cfg = SFTFilterConfig(
-        reward_drop_min=args.reward_drop_min,
-        reward_drop_max=args.reward_drop_max,
-        clip_reward_min=args.clip_reward_min,
-        clip_reward_max=args.clip_reward_max,
-    )
-    train_sft(
-        Path(args.data),
-        args.model,
-        Path(args.output),
-        args.threshold,
-        filter_cfg=cfg,
-        include_sft_targets=not args.no_sft_targets,
-    )
 
 
 if __name__ == "__main__":
 """
+Run before grpo_train.py for SFT→GRPO pipeline. Pass checkpoint path as
+BASE_MODEL env var to grpo_train.py.
 """
 import argparse
 import json
 import logging
 from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
+DEFAULT_OUTPUT = "checkpoints/sft_1.5b/"
+
+
+def _extract_completions(rec: dict) -> list[str]:
+    """Return candidate completion texts from a record."""
+    completion = rec.get("completion")
+    if isinstance(completion, str) and completion.strip():
+        return [completion.strip()]
 
+    conversation = rec.get("conversation", [])
+    candidates: list[str] = []
+    if isinstance(conversation, list):
+        for turn in conversation:
+            if not isinstance(turn, dict):
+                continue
+            role = str(turn.get("role", "")).lower()
+            content = str(turn.get("content", "")).strip()
+            if role == "negotiator" and content:
+                candidates.append(content)
+    return candidates
 
 
+def _row_total_reward(rec: dict) -> float | None:
+    v = rec.get("reward")
+    if v is not None:
+        return float(v)
+    v2 = rec.get("cumulative_reward")
+    if v2 is not None:
+        return float(v2)
+    return None
 
 
+def load_sft_dataset(data_path: Path, min_reward: float = -50.0):
+    """Build a text dataset from JSONL prompt/completion pairs."""
     try:
         from datasets import Dataset
     except ImportError as exc:
         raise ImportError("Install datasets: pip install datasets") from exc
 
+    rows: list[dict[str, str]] = []
+    skipped = 0
+    reward_filtered = 0
+    remaining_records = 0
+    with data_path.open("r", encoding="utf-8") as f:
+        for line_no, line in enumerate(f, start=1):
+            line = line.strip()
+            if not line:
                 continue
+            try:
+                rec = json.loads(line)
+            except json.JSONDecodeError:
+                logger.warning("Skipping malformed JSONL row %d", line_no)
+                skipped += 1
+                continue
+
+            r = _row_total_reward(rec)
+            if r is not None and r < min_reward:
+                reward_filtered += 1
+                continue
+
+            prompt = str(rec.get("prompt", "")).strip()
+            if not prompt:
+                logger.warning("Skipping row %d: missing prompt", line_no)
+                skipped += 1
+                continue
+
+            completions = _extract_completions(rec)
+            if not completions:
+                logger.warning("Skipping row %d: missing completion and negotiator turns", line_no)
+                skipped += 1
+                continue
+
+            remaining_records += 1
+            for completion in completions:
+                rows.append(
+                    {
+                        "text": (
+                            f"<|system|>{prompt}</s>\n"
+                            f"<|assistant|>{completion}</s>"
+                        )
+                    }
+                )
+
+    print(
+        f"Filtered {reward_filtered} records below min_reward={min_reward}, "
+        f"{remaining_records} remaining for SFT"
     )
+    if skipped:
+        logger.info("Also skipped %d malformed/empty JSONL rows; expanded to %d text rows", skipped, len(rows))
+    if not rows:
+        raise RuntimeError("No valid SFT examples found in dataset.")
+    return Dataset.from_list(rows)
 
 
 def train_sft(
+    data_path: Path, model_id: str, output_dir: Path, min_reward: float = -50.0
+) -> None:
+    """Fine-tune a base model with LoRA via TRL SFTTrainer."""
     import torch
+    from peft import LoraConfig
+    from trl import SFTConfig, SFTTrainer
 
+    dataset = load_sft_dataset(data_path, min_reward=min_reward)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     lora_config = LoraConfig(
         r=16,
 
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
         learning_rate=2e-4,
         logging_steps=10,
         save_strategy="epoch",
+        fp16=True,
         report_to="none",
+        max_seq_length=2048,
     )
 
+    if not torch.cuda.is_available():
+        logger.warning("No CUDA GPU detected; training may be very slow.")
+
     trainer = SFTTrainer(
         model=model_id,
         args=training_args,
 
         peft_config=lora_config,
     )
 
+    logger.info("Starting SFT: model=%s, examples=%d", model_id, len(dataset))
     trainer.train()
     trainer.save_model(str(output_dir))
+    logger.info("Saved SFT checkpoint to %s", output_dir)
 
 
 def main() -> None:
+    parser = argparse.ArgumentParser(description="Parlay SFT training")
     parser.add_argument("--data", default="data/episodes.jsonl")
+    parser.add_argument("--model", default=DEFAULT_MODEL)
+    parser.add_argument("--output", default=DEFAULT_OUTPUT)
     parser.add_argument(
+        "--min-reward",
+        type=float,
+        default=-50.0,
+        help="Skip JSONL records with total reward below this (default: -50.0)",
    )
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
+    train_sft(Path(args.data), args.model, Path(args.output), min_reward=args.min_reward)
 
 
 if __name__ == "__main__":
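
Putting the stages together as the new docstrings describe, a hedged sketch of the SFT -> GRPO hand-off (module paths and flags are taken from the argparse definitions in these diffs; since how grpo_train.py consumes the `BASE_MODEL` env var beyond its module-level default is not visible here, the sketch also passes `--base_model`, which `main()` demonstrably uses):

```python
# Sketch only: SFT warmup into checkpoints/sft_1.5b/, then GRPO on top of it.
import os
import subprocess

subprocess.run(
    ["python", "-m", "training.sft_train",
     "--data", "data/episodes.jsonl",
     "--output", "checkpoints/sft_1.5b/",
     "--min-reward=-50.0"],
    check=True,
)

subprocess.run(
    ["python", "-m", "training.grpo_train",
     "--base_model", "checkpoints/sft_1.5b/",  # main(): args.base_model or args.model
     "--data", "data/episodes.jsonl",
     "--output", "models/parlay-grpo",
     "--min-reward=-50.0"],
    check=True,
    # BASE_MODEL per the module comment in grpo_train.py; harmless if redundant.
    env={**os.environ, "BASE_MODEL": "checkpoints/sft_1.5b/"},
)
```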