Add pre-training audit scripts, OpenEnv manifest, and tune Parlay training/env (GRPO 1.5B default, min-reward filters, weighted data gen, hiring ZOPA+drift, veteran/opponent prompts, Docker/docs)
- .dockerignore +13 -0
- .env.example +3 -26
- Dockerfile +5 -4
- Makefile +2 -1
- README.md +89 -695
- README_SPACES.md +14 -0
- agent/gemini_client.py +5 -1
- agent/personas.py +6 -1
- agent/runner.py +5 -1
- agent/tom_tracker.py +5 -0
- game/scenarios.py +10 -4
- openenv.yaml +130 -0
- parlay_env/grader.py +21 -2
- parlay_env/reward.py +6 -1
- parlay_env/server.py +51 -20
- scripts/audit_grpo_pipeline.py +177 -0
- scripts/audit_reward.py +163 -0
- scripts/check_training_config.py +148 -0
- scripts/eval_comparison.py +78 -0
- scripts/inspect_data.py +261 -0
- scripts/push_docker.sh +8 -0
- scripts/run_gemini_baseline.py +74 -0
- scripts/validate_sft_data.py +132 -0
- tests/test_keyless.py +20 -1
- training/generate_data.py +126 -151
- training/grpo_train.py +44 -13
- training/random_baseline.py +107 -94
- training/sft_train.py +105 -182
.dockerignore
ADDED
@@ -0,0 +1,13 @@
+.git
+__pycache__/
+parlay_env/__pycache__/
+*.pyc
+venv/
+venv-train/
+node_modules/
+data/*.jsonl
+checkpoints/
+
+# Keep Python source, drop generated artifacts under parlay_env/
+parlay_env/
+!parlay_env/**/*.py
.env.example
CHANGED
@@ -1,26 +1,3 @@
-
-
-
-
-# Server ports
-ENV_PORT=8001
-DASHBOARD_PORT=8000
-MCP_SSE_PORT=8002
-
-# Game config
-MAX_TURNS_PER_EPISODE=20
-MIN_REWARD_THRESHOLD=-100
-TOP_PLAYER_THRESHOLD=0.30
-CREDIBILITY_POINTS_START=100
-CREDIBILITY_REGEN_PER_TURN=5
-
-# Training
-BASE_MODEL=Qwen/Qwen2.5-7B-Instruct
-GRPO_GENERATIONS=8
-GRPO_STEPS=500
-DATA_PATH=data/episodes.jsonl
-SFT_OUTPUT=models/parlay-sft
-GRPO_OUTPUT=models/parlay-grpo
-
-# HF Hub
-HF_REPO_ID=your-username/parlay-negotiator
+GEMINI_API_KEY=your_gemini_key_here
+HF_TOKEN=your_huggingface_token_here
+BASE_MODEL=checkpoints/sft_1.5b/
Dockerfile
CHANGED
@@ -1,6 +1,10 @@
 FROM python:3.11-slim
 
 WORKDIR /app
+ENV PYTHONUNBUFFERED=1
+ARG GEMINI_API_KEY=""
+ENV GEMINI_API_KEY=${GEMINI_API_KEY}
+ENV GOOGLE_API_KEY=${GEMINI_API_KEY}
 
 # Install deps first (layer-cached)
 COPY requirements.txt .
@@ -11,10 +15,7 @@ COPY . .
 # Initialise the database at build time
 RUN python -m scripts.init_db
 
-# startup script
-RUN chmod +x scripts/start.sh
-
 # HF Spaces exposes port 7860
 EXPOSE 7860
 
-CMD ["bash", "
+CMD ["bash", "-lc", "if [ -z \"$GOOGLE_API_KEY\" ] && [ -n \"$GEMINI_API_KEY\" ]; then export GOOGLE_API_KEY=\"$GEMINI_API_KEY\"; fi; uvicorn main:app --host 0.0.0.0 --port 7860"]
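For reference, the new `ARG`/`ENV` plumbing can be exercised locally like this; a minimal sketch, assuming the `parlay` image tag used later in the README's Runbook:

```bash
# Bake a key in at build time via the new ARG, or omit --build-arg and pass
# GEMINI_API_KEY at run time instead (the CMD fallback exports GOOGLE_API_KEY).
docker build --build-arg GEMINI_API_KEY="$GEMINI_API_KEY" -t parlay .
docker run -p 7860:7860 parlay
```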
Makefile
CHANGED
@@ -34,7 +34,8 @@ test:
 	venv\Scripts\pytest tests/ -v
 
 train-data:
-
+	# hackathon budget default; override with EPISODES=N
+	venv-train\Scripts\python -m training.generate_data --episodes 80 --output data/episodes.jsonl
 
 train-sft:
 	venv-train\Scripts\python -m training.sft_train --model Qwen/Qwen2.5-7B-Instruct --data data/episodes.jsonl --output models/parlay-sft --threshold 0.30
README.md
CHANGED
@@ -11,670 +11,158 @@ pinned: false
 
 > **The arena where AIs learn to close.**
 
-`Python 3.11` | `FastAPI` | `Gemini 2.
-
----
 
 ## Overview
 
-Parlay is a
-
-| Audience | What they get |
-|---|---|
-| **Hackathon Judges** | A fully playable browser game, an OpenEnv-compliant WebSocket server, an MCP integration layer, and a complete GRPO training pipeline — all in one repo |
-| **Players** | A real-time negotiation game with five scenarios, five AI personas (Gemini-powered), Theory of Mind tracking, tactical cards, drift events, and a global leaderboard |
-| **B2B / Researchers** | A clean OpenEnv protocol implementation for training negotiation agents; plug in your own model, collect episodes, run GRPO fine-tuning, push to HF Hub |
-
-Parlay is built on:
-- **Google Gemini 2.0 Flash** — the AI counterpart, generating persona-consistent responses in real time
-- **FastAPI + aiosqlite** — async backend, zero ORM overhead, SQLite for portability
-- **OpenEnv protocol** — standard `reset/step/state` WebSocket commands for agent interoperability
-- **FastMCP** — universal MCP server supporting both `stdio` and SSE transports
-- **HF TRL GRPOTrainer** — two-stage SFT → GRPO pipeline fine-tuning Qwen2.5-7B-Instruct
-- **Vanilla JS + Three.js r128** — zero npm, zero build step, runs in any browser
-
----
-
-## Quick Start
-
-### Prerequisites
-
-- Python **3.11** (required; training and game stacks expect 3.11)
-- A Google AI Studio API key ([get one free](https://aistudio.google.com/app/apikey))
-- (Optional) A Hugging Face token for training and model pushing
-
-### Windows (recommended): PowerShell from repo root
-
-All `scripts\*.ps1` files assume the **current directory is the project root**. If execution policy blocks scripts, run:
 
-
 
-
-git clone https://github.com/your-username/parlay.git
-cd parlay
-.\scripts\setup.ps1
-# Edit .env and set GOOGLE_API_KEY=your_key_here
-.\scripts\run.ps1
-```
-
-Optional: `.\venv\Scripts\python -m scripts.seed_scenarios` for demo leaderboard rows. Training venv: `.\scripts\setup_train.ps1`, then `.\scripts\train_data.ps1` (and related `train_sft` / `train_grpo` / `evaluate` scripts).
-
-You can also use **GNU Make** on Windows (e.g. Git Bash or Chocolatey): `make setup`, `make run`, `make test`, `make clean`. The Makefile uses `venv\Scripts\` paths.
 
-
 
-
-
-```bash
-cp .env.example .env
-# Set GOOGLE_API_KEY in .env
-docker compose up --build
-```
-
-See the [Docker](#docker) section for service URLs and training images.
-
-### Open in browser (local server)
-
-```
-http://localhost:8000        # Game dashboard
-http://localhost:8000/train  # Training dashboard
-http://localhost:8000/docs   # Interactive API docs (Swagger)
-```
 
-
 
 ```bash
-git clone https://github.com/your-username/parlay.git
-cd parlay
-python3.11 -m venv venv
-source venv/bin/activate
 pip install -r requirements.txt
-
-
-
 ```
 
-
-
-## API Keys
-
-| Key | Required | Where to get it | Used for |
-|---|---|---|---|
-| `GOOGLE_API_KEY` | **Yes** | [Google AI Studio](https://aistudio.google.com/app/apikey) | Gemini 2.0 Flash (game AI + data gen) |
-| `HF_TOKEN` | Only for training | [Hugging Face Settings](https://huggingface.co/settings/tokens) | Model push to HF Hub |
-
-Set these in your `.env` file (never commit `.env`):
-
-```bash
-GOOGLE_API_KEY=AIzaSy...
-HF_TOKEN=hf_...
-```
-
----
-
-## Project Structure
-
-```
-parlay/
-├── main.py                  # FastAPI app entry point
-│
-├── parlay_env/              # Core RL environment (OpenEnv-compliant)
-│   ├── __init__.py
-│   ├── server.py            # WebSocket router (reset/step/state endpoints)
-│   ├── env.py               # ParlayEnv class — OpenEnv implementation
-│   ├── models.py            # Pydantic models: ParlayState, ParlayAction, BeliefState…
-│   ├── reward.py            # Reward coefficients (ALPHA, BETA, GAMMA, OMEGA…)
-│   ├── grader.py            # Pure reward functions: step reward, terminal reward
-│   ├── game_theory.py       # ZOPA, Nash, Pareto, Shapley, Rubinstein computations
-│   └── exceptions.py        # Custom exceptions (InvalidScenarioError, CapitulationError…)
-│
-├── game/                    # Game logic layer
-│   ├── __init__.py
-│   ├── scenarios.py         # 5 negotiation scenarios with drift events
-│   ├── personas.py          # 5 AI personas with Gemini prompt templates
-│   └── session.py           # Session management: active games, turn routing
-│
-├── agent/                   # AI agent components
-│   ├── __init__.py
-│   ├── gemini_client.py     # Gemini 2.0 Flash async wrapper
-│   ├── tom_tracker.py       # Theory of Mind belief tracker
-│   └── tactical.py          # Tactical card execution logic
-│
-├── dashboard/               # Frontend (zero npm, zero build)
-│   ├── index.html           # Main game UI
-│   ├── train.html           # Training monitor UI
-│   ├── api.py               # FastAPI router for dashboard REST endpoints
-│   └── static/
-│       ├── app.js           # Game WebSocket client + UI logic
-│       ├── character.js     # Three.js r128 animated persona character
-│       ├── chart_utils.js   # Chart.js reward visualization helpers
-│       └── style.css        # CSS with --parlay-* custom properties
-│
-├── mcp_server/              # MCP integration (stdio + SSE)
-│   ├── __init__.py
-│   └── server.py            # FastMCP tools: negotiate, get_state, get_leaderboard…
-│
-├── training/                # Isolated training pipeline (never imported by game)
-│   ├── __init__.py
-│   ├── generate_data.py     # Gemini self-play episode generation
-│   ├── sft_train.py         # SFTTrainer fine-tuning on top episodes
-│   ├── grpo_train.py        # GRPOTrainer RL fine-tuning
-│   ├── reward_fn.py         # GRPO reward functions (wraps grader.py)
-│   ├── evaluate.py          # Three-bar comparison chart: base vs SFT vs GRPO
-│   └── push_to_hub.py       # Upload model to HF Hub
-│
-├── scripts/
-│   ├── __init__.py
-│   ├── init_db.py           # Create SQLite schema (idempotent)
-│   ├── seed_scenarios.py    # Insert demo leaderboard entries
-│   ├── setup.ps1            # Windows: game venv + .env + DB init
-│   ├── setup_train.ps1      # Windows: training venv (PyTorch, TRL, …)
-│   ├── run.ps1 / run_env.ps1 / run_mcp.ps1
-│   ├── train_*.ps1 / evaluate.ps1 / test.ps1
-│
-├── tests/
-│   ├── __init__.py
-│   ├── test_grader.py       # Reward computation tests
-│   ├── test_game_theory.py  # ZOPA/Nash/Pareto/Shapley tests
-│   ├── test_tom.py          # Theory of Mind tracker tests
-│   ├── test_reward.py       # Reward constants tests
-│   └── test_scenarios.py    # Scenario definition tests
-│
-├── data/                    # Generated episode JSONL files (gitignored)
-├── models/                  # Fine-tuned model checkpoints (gitignored)
-├── results/                 # Evaluation charts and metrics (gitignored)
-│
-├── requirements.txt         # Core dependencies
-├── requirements-train.txt   # Training-only dependencies (torch, trl, peft…)
-├── Makefile                 # Windows-oriented targets (venv\\Scripts\\ paths)
-├── .gitattributes           # LF for source; CRLF for .ps1
-├── .env.example             # Environment variable template
-├── .gitignore
-├── docker-compose.yml       # Multi-service Docker deployment
-├── Dockerfile.game          # Game + dashboard service
-├── Dockerfile.env           # OpenEnv WebSocket service
-└── Dockerfile.train         # GRPO training service (CUDA)
-```
-
----
-
-## Game Guide
-
-### How to Play
 
-
-2. **Choose your persona style** — affects how aggressively the AI counterpart responds
-3. **Negotiate in natural language** — type your offers and arguments in the chat
-4. **Use tactical cards** — spend Credibility Points to play power moves (anchor, BATNA reveal, deadline pressure)
-5. **Watch for drift events** — the AI's hidden priorities shift mid-negotiation; adapt or lose ground
-6. **Close within 20 turns** — speed bonuses reward efficient closers
 
-
 
-
-|---|---|
-| **ZOPA** | Zone of Possible Agreement — the range between both parties' walk-away prices where a deal is mutually beneficial |
-| **BATNA** | Best Alternative to a Negotiated Agreement — your outside option; the floor below which you'd rather walk away |
-| **Nash Bargaining Solution** | The game-theoretically "fair" split of the ZOPA surplus — the midpoint of both BATNAs |
-| **Anchor** | Your opening offer. The higher you anchor (as seller), the more the counterpart adjusts from that reference point |
-| **Rubinstein Deadline** | The advantage of having more time — patient negotiators extract better deals |
-| **Capitulation Cliff** | Accepting below your BATNA triggers a hard -150 penalty (OMEGA). Never capitulate |
-| **Theory of Mind** | Parlay tracks the AI's inferred beliefs about you — high ToM accuracy gives a step reward bonus |
-| **Drift Event** | A mid-game shock (budget cut, competitor offer, urgency spike) that changes the AI's hidden priorities |
 
-
 
-
-|---|---|---|
-| **Anchor High** | 10 CP | Lock in a high reference price — reduces AI's willingness to counter aggressively |
-| **BATNA Reveal** | 15 CP | Signal your outside option — increases AI urgency if credible |
-| **Deadline Pressure** | 20 CP | Introduce artificial urgency — accelerates AI concessions by 15% |
-| **Bundle Offer** | 12 CP | Add non-monetary value — expands the ZOPA by shifting AI utility |
-| **Silent Close** | 25 CP | Make a final offer with no further negotiation signal — high risk, high reward |
-| **Coalition Play** | 30 CP | Invoke Act 3 coalition mechanics — brings in a third party for multi-issue negotiation |
 
-
 
-
 
-```
-
-```
-
-Deal Efficiency is displayed as a percentage:
-```
-Deal Efficiency = (Final Price - Seller BATNA) / (Buyer BATNA - Seller BATNA)
-```
-
-A deal efficiency of 1.0 means you captured the full ZOPA surplus. 0.5 is the Nash fair split.
-
----
-
-## OpenEnv Protocol
-
-Parlay implements the OpenEnv standard for RL environments over WebSocket.
-
-### Connection
-
-```
-ws://localhost:8000/env/ws/{session_id}
-```
-
-### Commands
-
-#### `reset` — Start a new episode
-
-```json
-{
-  "command": "reset",
-  "scenario_id": "saas_enterprise",
-  "persona": "shark",
-  "player_name": "MyAgent"
-}
-```
-
-Response:
-```json
-{
-  "type": "observation",
-  "session_id": "abc-123",
-  "scenario_id": "saas_enterprise",
-  "persona": "shark",
-  "act": 1,
-  "step_count": 0,
-  "belief": {
-    "est_budget": 140000,
-    "est_walk_away": 125000,
-    "est_urgency": 0.4,
-    "est_has_alternative": false,
-    "confidence": 0.3
-  },
-  "credibility_points": 100,
-  "offer_history": [],
-  "episode_done": false
-}
-```
-
-#### `step` — Take a negotiation turn
-
-```json
-{
-  "command": "step",
-  "utterance": "I propose an annual contract at $155,000 with a 90-day payment term.",
-  "offer_amount": 155000,
-  "tactical_move": null
-}
-```
-
-Response:
-```json
-{
-  "type": "step_result",
-  "step_reward": 12.4,
-  "cumulative_reward": 12.4,
-  "ai_response": "That's ambitious. Our budget ceiling won't stretch that far...",
-  "belief": { "...updated belief state..." },
-  "tension_score": 0.6,
-  "drift_fired": false,
-  "episode_done": false
-}
-```
-
-#### `state` — Query current state without acting
-
-```json
-{ "command": "state" }
-```
-
-#### `close` — Accept final price and end episode
-
-```json
-{
-  "command": "close",
-  "final_price": 148000
-}
-```
-
-Response:
-```json
-{
-  "type": "episode_done",
-  "total_reward": 287.4,
-  "deal_efficiency": 0.82,
-  "acts_completed": 2,
-  "bluffs_caught": 1,
-  "drift_adapted": true,
-  "deal_closed": true,
-  "leaderboard_rank": 3
-}
-```
-
-### HTTP REST Endpoints
-
-| Method | Path | Description |
-|---|---|---|
-| `GET` | `/health` | Health check |
-| `GET` | `/env/scenarios` | List all scenarios |
-| `GET` | `/env/personas` | List all personas |
-| `GET` | `/dashboard/leaderboard` | Global leaderboard |
-| `GET` | `/dashboard/leaderboard/{scenario_id}` | Per-scenario leaderboard |
-| `POST` | `/dashboard/submit` | Submit episode result |
-| `GET` | `/docs` | Swagger UI |
-
----
-
-## MCP Setup
-
-Parlay ships a universal MCP server supporting **both** `stdio` and SSE transports.
-
-### Available MCP Tools
-
-| Tool | Description |
-|---|---|
-| `negotiate` | Send a negotiation message and get the AI's response |
-| `get_state` | Retrieve current session state and belief model |
-| `reset_session` | Start a new negotiation session |
-| `close_deal` | Accept a final price and get episode grade |
-| `get_leaderboard` | Fetch top performers globally or by scenario |
-| `list_scenarios` | Get all available scenarios with ZOPA ranges |
-| `list_personas` | Get all personas with strategy profiles |
-| `get_game_theory` | Compute ZOPA, Nash point, Rubinstein advantage for any deal |
-
-### Client 1: Claude Desktop / Claude Code
-
-Add to your `claude_desktop_config.json` (usually at `~/Library/Application Support/Claude/claude_desktop_config.json` on macOS):
-
-```json
-{
-  "mcpServers": {
-    "parlay": {
-      "command": "python",
-      "args": ["-m", "mcp_server.server", "stdio"],
-      "cwd": "/path/to/parlay",
-      "env": {
-        "GOOGLE_API_KEY": "your_key_here"
-      }
-    }
-  }
-}
-```
-
-### Client 2: Continue.dev / Zed / Any SSE Client
-
-First start the SSE server:
-
-```bash
-python -m mcp_server.server sse
-# Listening on http://localhost:8002/sse
-```
-
-Then configure your client to point at:
-
-```
-http://localhost:8002/sse
-```
-
-In `Continue.dev` (`~/.continue/config.json`):
-
-```json
-{
-  "experimental": {
-    "modelContextProtocolServers": [
-      {
-        "transport": {
-          "type": "sse",
-          "url": "http://localhost:8002/sse"
-        }
-      }
-    ]
-  }
-}
-```
-
-### Client 3: Generic stdio (any MCP-compatible agent)
-
-```bash
-python -m mcp_server.server stdio
-```
-
-Pipe JSON-RPC messages to stdin; responses arrive on stdout. Compatible with any MCP client library.
-
----
 
 ## Training Pipeline
 
-
-
-
-
-
-
-
-
-
-
-
-- Minimum 20 episodes per (persona × scenario) pair = 500 baseline
-- 30% noise injection for exploration
-- 40% forced drift event rate
-- 25% Act 3 coalition scenarios
-
-Each episode record:
-```json
-{
-  "prompt": "You are negotiating a SaaS enterprise deal...",
-  "conversation": [...],
-  "reward": 247.3,
-  "deal_efficiency": 0.79,
-  "persona": "shark",
-  "scenario_id": "saas_enterprise",
-  "acts_completed": 2,
-  "tom_accuracy": 0.81,
-  "drift_adapted": true,
-  "split": "train"
-}
 ```
 
-###
-
-Train Qwen2.5-7B-Instruct on the top 60% of episodes by reward:
 
 ```bash
-python -m training.
-  --model Qwen/Qwen2.5-7B-Instruct \
-  --data data/episodes.jsonl \
-  --output models/parlay-sft \
-  --threshold 0.60
 ```
 
-
-
-### Stage 3: GRPO Fine-Tuning
-
-Apply Group Relative Policy Optimization with G=8 generations per prompt:
 
 ```bash
-python -m training.
-  --model models/parlay-sft \
-  --data data/episodes.jsonl \
-  --output models/parlay-grpo \
-  --steps 500
 ```
 
-
-- `num_generations=8` (G=8 per prompt)
-- `beta=0.001` (low KL coefficient — allows exploration)
-- `epsilon=0.2` (clipping range)
-- `scale_rewards="batch"` (batch-level reward standardization)
-- `learning_rate=5e-7`
-
-### Stage 4: Evaluate
 
 ```bash
-python -m training.
-  --base Qwen/Qwen2.5-7B-Instruct \
-  --sft models/parlay-sft \
-  --grpo models/parlay-grpo \
-  --data data/episodes.jsonl \
-  --output results/eval_results.json
 ```
 
-
 
-
 
-```
-python -m training.push_to_hub \
-  --model models/parlay-grpo \
-  --repo your-username/parlay-negotiator
-```
-
-Requires `HF_TOKEN` and `HF_REPO_ID` in `.env`.
-
----
 
-##
 
-
 
-
-|---|---|---|---|---|
-| **Shark** | 0.90 | 0.20 | 0.45 | Opens high, concedes slowly, uses deadline pressure, willing to walk away |
-| **Diplomat** | 0.30 | 0.80 | 0.10 | Relationship-focused, seeks mutual gain, rarely bluffs, prefers bundle deals |
-| **Analyst** | 0.50 | 0.70 | 0.15 | Data-driven, requests justification for every number, ZOPA-aware, systematic |
-| **Veteran** | 0.65 | 0.85 | 0.30 | Pattern-recognizes anchors, absorbs pressure, uses silence as a tool |
-| **Wildcard** | 0.75 | 0.35 | 0.55 | Unpredictable, drift-prone, high bluff rate, can pivot strategy mid-negotiation |
 
-
 
-
 
-
 
-
 
-
-
-
-
-
-
-
 
-
-- `batna_buyer`: Buyer's walk-away ceiling
-- `batna_seller`: Seller's walk-away floor
-- `anchor_buyer`: Typical buyer opening offer
-- `anchor_seller`: Typical seller opening ask
-- `drift_events`: List of mid-game shocks with trigger turns and effects
-- `currency`: Always USD
-- `difficulty`: `low | medium | high | very_high`
 
-
 
-
-
-The Parlay grader computes rewards in two phases:
-
-### Step Reward (per turn)
-
-```
-r_step = α · ΔZOPA_position
-       + β · ToM_accuracy_improvement
-       - δ · concession_magnitude
-       - θ · noise_penalty
-       + ε · tactical_card_bonus
 ```
 
-
-- **α (ALPHA = 2.0)** — reward for improving your ZOPA position
-- **β (BETA = 5.0)** — reward for improving ToM belief accuracy
-- **δ (DELTA = 1.5)** — penalty per unit of concession from previous offer
-- **θ (THETA = 3.0)** — penalty for low-grounding utterances (noise)
-- **ε (EPSILON = 8.0)** — bonus for successful tactical card execution
 
-
-
-```
-r_terminal =
-  if final_price < batna_seller: -Ω   (capitulation cliff: -150)
-  elif deal_closed:
-    Γ   (base close bonus: +100)
-    + ζ · deal_efficiency   (ZOPA capture: up to +50)
-    + η · acts_completed   (multi-act bonus: +10/act)
-    + Γ · (1 - t_close/t_max)   (speed bonus: up to +100)
-    + ETA · drift_adapted   (drift adaptation: +10)
-  else (no deal):
-    -Γ/2 + β · avg_tom_accuracy   (partial credit)
 ```
 
-
-- **Γ (GAMMA = 100.0)** — primary close bonus
-- **ζ (ZETA = 50.0)** — ZOPA efficiency multiplier
-- **η (ETA = 10.0)** — per-act completion bonus (max 3 acts = +30)
-- **Ω (OMEGA = 150.0)** — capitulation cliff penalty
-- **t_close** — turn at which deal was closed
-- **t_max** — maximum turns (default: 20)
-
-All coefficients live exclusively in `parlay_env/reward.py`. Never hardcode them elsewhere.
-
----
-
-## Docker
-
-### Run all services
 
 ```bash
-
-# Set GOOGLE_API_KEY in .env
-
-docker compose up --build
 ```
 
-
-- `game` → `http://localhost:8000` — game dashboard + API
-- `env` → `http://localhost:8001` — OpenEnv WebSocket server
-- `mcp` → `http://localhost:8002` — MCP SSE server
-
-### Run training (requires GPU)
 
 ```bash
-
-docker run --gpus all -v $(pwd)/data:/app/data -v $(pwd)/models:/app/models \
-  -e GOOGLE_API_KEY=$GOOGLE_API_KEY \
-  -e HF_TOKEN=$HF_TOKEN \
-  parlay-train python -m training.grpo_train --steps 500
 ```
 
-###
 
 ```bash
-
-docker
-docker run -p 8000:8000 -e GOOGLE_API_KEY=$GOOGLE_API_KEY parlay-game
-
-# OpenEnv only
-docker build -f Dockerfile.env -t parlay-env .
-docker run -p 8001:8001 -e GOOGLE_API_KEY=$GOOGLE_API_KEY parlay-env
 ```
 
----
-
 ## Testing
 
-###
 
 ```bash
 pytest tests/ -v
 ```
 
-###
-
-```bash
-pytest tests/ -v --tb=short --cov=parlay_env --cov=game --cov=agent --cov-report=term-missing
-```
-
-### Run a specific test module
 
 ```bash
 pytest tests/test_grader.py -v
@@ -684,122 +172,28 @@ pytest tests/test_reward.py -v
 pytest tests/test_scenarios.py -v
 ```
 
-###
-
-| File | What it tests |
-|---|---|
-| `test_grader.py` | Step reward, terminal reward, episode grade computation |
-| `test_game_theory.py` | ZOPA, Nash bargaining, Pareto frontier, Shapley value, anchoring, Rubinstein |
-| `test_tom.py` | Theory of Mind tracker: belief updates, bluff detection, drift events, accuracy |
-| `test_reward.py` | Reward coefficient constants and their mathematical constraints |
-| `test_scenarios.py` | Scenario definitions: ZOPA validity, drift events, currency, IDs |
-
-All tests follow the pattern: `Test{Module}` class → `test_{scenario}` methods → `assert ... f"Expected {expected}, got {result}"`.
-
----
-
-## Architecture Decisions
-
-### Why SQLite over Postgres?
-
-Parlay is designed to be a **zero-infrastructure hackathon demo**. SQLite with `aiosqlite` provides full async support, requires no Docker service for the database, and the `parlay.db` file can be committed for demo snapshots. Migrating to Postgres requires only changing the connection string.
-
-### Why Vanilla JS over React/Vue?
-
-The `.cursorrules` mandate: zero npm, zero build step. Three.js r128 from cdnjs gives us 3D animated personas. Chart.js 4.4 gives us reward curves. `fetch()` + `WebSocket` gives us real-time game state. The entire frontend loads from a single HTML file with `<script>` tags. This means anyone can open the dashboard without `node_modules`.
-
-### Why GRPO over PPO?
 
-
 
-##
 
-
-- Sub-500ms latency for negotiation turns with `max_output_tokens=500`
-- Strong instruction-following for persona-consistent responses
-- Async-compatible via `run_in_executor` pattern
-
----
-
-## Environment Variables Reference
-
-| Variable | Default | Description |
-|---|---|---|
-| `GOOGLE_API_KEY` | — | **Required.** Google AI Studio API key |
-| `HF_TOKEN` | — | Hugging Face token (training only) |
-| `ENV_PORT` | `8001` | OpenEnv WebSocket server port |
-| `DASHBOARD_PORT` | `8000` | Dashboard + game server port |
-| `MCP_SSE_PORT` | `8002` | MCP SSE server port |
-| `MAX_TURNS_PER_EPISODE` | `20` | Maximum turns before episode ends |
-| `MIN_REWARD_THRESHOLD` | `-100` | Minimum reward for SFT data inclusion |
-| `TOP_PLAYER_THRESHOLD` | `0.60` | Percentile cutoff for SFT training data |
-| `CREDIBILITY_POINTS_START` | `100` | Starting CP for tactical cards |
-| `CREDIBILITY_REGEN_PER_TURN` | `5` | CP regenerated each turn |
-| `BASE_MODEL` | `Qwen/Qwen2.5-7B-Instruct` | HF model ID for training base |
-| `GRPO_GENERATIONS` | `8` | G value for GRPO (generations per prompt) |
-| `GRPO_STEPS` | `500` | GRPO training steps |
-| `DATA_PATH` | `data/episodes.jsonl` | Episode data for training |
-| `SFT_OUTPUT` | `models/parlay-sft` | SFT checkpoint output path |
-| `GRPO_OUTPUT` | `models/parlay-grpo` | GRPO checkpoint output path |
-| `HF_REPO_ID` | — | HF Hub repo for model push |
-
----
-
-## Testing Without API Keys
-
-Everything in Parlay runs in **mock mode** when `GOOGLE_API_KEY` is not set.
-Mock mode returns canned persona-consistent responses so you can play and test
-the full game loop without any external account.
 
 ```bash
-
-make setup
-
-# 2. Run the keyless test suite (zero API calls)
-make test-keyless
-
-# 3. Start the server in mock mode
-make run
-
-# 4. Open the game in your browser
-# → http://localhost:8000
-# A "Demo mode" banner confirms mock mode is active.
 ```
 
-
-
-To test the exact HF Spaces container locally before pushing:
 
 ```bash
-
-# → http://localhost:7860
 ```
 
----
-
-## Contributing
-
-1. Fork the repo and create a feature branch
-2. Follow the module dependency graph: `training/ → parlay_env/ → game/ → agent/`
-3. Add type hints and docstrings to all public functions
-4. Write at least 2 tests per new function (happy path + edge case)
-5. Run `pytest tests/` — all tests must pass
-6. Verify `docker compose up --build` completes without errors
-
----
-
 ## License
 
-MIT
-
-Copyright (c) 2026 Parlay Contributors
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
----
-
-*Built for the Meta Hackathon 2026. Powered by Gemini 2.0 Flash + Qwen2.5-7B.*
 
 > **The arena where AIs learn to close.**
 
+`Python 3.11` | `FastAPI` | `Gemini 2.5 Flash` | `GRPO` | `OpenEnv-style WS`
 
 ## Overview
 
+Parlay is a negotiation RL environment + browser game + training stack:
 
+- Three negotiation scenarios and three personas.
+- OpenEnv-style WebSocket interface (`reset` / `step` / `state`) on `/env/ws`.
+- Theory-of-Mind belief tracking with dense reward shaping.
+- Dynamic ZOPA erosion under sustained tension.
+- Training pipeline from Gemini self-play data to SFT and GRPO.
 
+Gemini model routing:
 
+- `gemini-2.5-flash-lite` for data generation and self-play.
+- `gemini-2.5-flash` for demo gameplay and MCP tools.
 
+## Quickstart
 
+Run exactly in this order:
 
 ```bash
 pip install -r requirements.txt
+export GEMINI_API_KEY=your_key
+uvicorn main:app --port 8000
+open http://localhost:8000
 ```
 
+## Reward Design
 
+Per-step reward:
 
+`R_t = α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV`
 
+Terminal reward:
 
+`R_T = γ·E + ε·S + ζ·D`
 
+Capitulation floor:
 
+`R_T = -ω` when final deal breaches BATNA.
 
+Constants (from `parlay_env/reward.py`):
 
+- `ALPHA=2`, `BETA=5`, `DELTA=3`, `THETA=10`
+- `PSI=12`, `MU=8`
+- `GAMMA=100`, `EPSILON=20`, `ZETA=15`, `OMEGA=200`
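To make the shaping concrete, here is a minimal sketch that plugs sample term values into the per-step formula above. The term inputs (`delta_v`, `tom_gain`, and so on) are illustrative; the real term extraction lives in `parlay_env/grader.py`.

```python
# Illustrative only: evaluates R_t for hand-picked term values.
ALPHA, BETA, DELTA, THETA, PSI, MU = 2.0, 5.0, 3.0, 10.0, 12.0, 8.0

def step_reward(delta_v: float, tom_gain: float, concession: float,
                noise: float, bluff_caught: bool, adapted: bool) -> float:
    return (
        ALPHA * delta_v                      # offer moved toward the ZOPA midpoint
        + BETA * tom_gain                    # belief accuracy improved
        - DELTA * concession                 # ground conceded from the last offer
        - THETA * noise                      # ungrounded-chatter penalty
        + (PSI if bluff_caught else 0.0)     # correctly called a bluff
        + (MU if adapted else 0.0)           # referenced an active drift event
    )

# A probing turn that concedes nothing and catches a bluff:
print(step_reward(0.5, 0.2, 0.0, 0.0, True, False))  # 2*0.5 + 5*0.2 + 12 = 14.0
```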
 
 ## Training Pipeline
 
+```text
+Gemini self-play (training/generate_data.py)
+        |
+        v
+SFT warm start (training/sft_train.py)
+        |
+        v
+GRPO fine-tune (training/grpo_train.py)
+        |
+        v
+Evaluation + comparison (training/evaluate.py, scripts/eval_comparison.py)
 ```
 
+### Data generation
 
 ```bash
+python -m training.generate_data --episodes 80 --output data/episodes.jsonl
 ```
 
+### SFT
 
 ```bash
+python -m training.sft_train --data data/episodes.jsonl --output checkpoints/sft_1.5b/
 ```
 
+### GRPO
 
 ```bash
+BASE_MODEL=checkpoints/sft_1.5b/ python -m training.grpo_train --data data/episodes.jsonl --output models/parlay-grpo
 ```
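`training/grpo_train.py` itself is not shown in this diff; the sketch below gives a rough picture of how a scalar negotiation reward typically plugs into TRL's `GRPOTrainer`, with a placeholder reward standing in for the real grader replay. The reward-function body and the single-prompt dataset are hypothetical.

```python
# Hypothetical wiring sketch; the real logic lives in training/grpo_train.py.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

def negotiation_reward(prompts, completions, **kwargs):
    # One scalar per completion. A real version would replay each completion
    # through the Parlay grader; this length heuristic is only a stand-in.
    return [min(len(c) / 200.0, 1.0) for c in completions]

trainer = GRPOTrainer(
    model="checkpoints/sft_1.5b/",  # SFT warm start, matching BASE_MODEL above
    reward_funcs=negotiation_reward,
    args=GRPOConfig(output_dir="models/parlay-grpo", num_generations=8),
    train_dataset=Dataset.from_dict(
        {"prompt": ["You are negotiating a SaaS enterprise deal..."]}
    ),
)
trainer.train()
```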
 
+## Baseline vs GRPO Results
 
+[Run scripts/eval_comparison.py after training to populate this section]
 
+`results/comparison.png`
 
+## HuggingFace Space
 
+[Space URL here]
 
+## OpenEnv
 
+See `openenv.yaml` for environment manifest metadata and reward spec.
 
+WebSocket endpoint:
 
+`ws://<host>:<port>/env/ws`
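A minimal client sketch against this endpoint, following the `reset`/`step` message shapes in `openenv.yaml`. The `"type": "step"` tag on the action mirrors the `reset`/`state` envelopes and is an assumption; the field values are illustrative.

```python
# Minimal OpenEnv-style WebSocket client sketch (pip install websockets).
import asyncio
import json
import websockets

async def play_one_turn() -> None:
    async with websockets.connect("ws://localhost:8000/env/ws") as ws:
        await ws.send(json.dumps(
            {"type": "reset", "scenario_id": "saas_enterprise", "persona": "shark"}
        ))
        print(json.loads(await ws.recv()))  # initial ParlayObservation
        await ws.send(json.dumps({
            "type": "step",  # assumed envelope tag; payload is a ParlayAction
            "utterance": "Given your timeline, I can commit to $150,000 annually.",
            "offer_amount": 150_000,
        }))
        print(json.loads(await ws.recv()))  # next observation with step reward

asyncio.run(play_one_turn())
```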
 
+## Architecture
 
+- `main.py`: FastAPI entry, routers, static files.
+- `parlay_env/`: server, models, grader, reward constants, game theory.
+- `agent/`: Gemini client, ToM tracker, self-play runner.
+- `game/`: scenarios, tactical cards, leaderboard.
+- `dashboard/`: UI and API routes, spectator stream.
+- `training/`: dataset generation, SFT, GRPO, evaluation.
+- `mcp_server/`: FastMCP tools.
+- `tests/`: keyless and module tests.
 
+## Runbook
 
+### Local app
 
+```bash
+uvicorn main:app --host 0.0.0.0 --port 8000
+```
 
+### OpenEnv server only
 
+```bash
+python -m parlay_env.server --port 8001
+```
 
+### Keyless test suite
 
 ```bash
+pytest tests/test_keyless.py -v
 ```
 
+### Smoke test
 
 ```bash
+python smoke_test.py
 ```
 
+### Docker
 
 ```bash
+docker build -t parlay .
+docker run -p 7860:7860 -e GEMINI_API_KEY=$GEMINI_API_KEY parlay
 ```
 
 ## Testing
 
+### Full suite
 
 ```bash
 pytest tests/ -v
 ```
 
+### Focused modules
 
 ```bash
 pytest tests/test_grader.py -v
 pytest tests/test_scenarios.py -v
 ```
 
+### What tests cover
 
+- `test_keyless.py`: no-key full stack sanity checks.
+- `test_grader.py`: step/terminal reward behavior.
+- `test_game_theory.py`: ZOPA/Nash/Pareto/Shapley.
+- `test_tom.py`: ToM updates and belief metrics.
+- `test_training_pipeline.py`: training data/plumbing checks.
 
+## MCP
 
+Run the MCP server over stdio:
 
 ```bash
+python -m mcp_server.server stdio
 ```
 
+or SSE:
 
 ```bash
+python -m mcp_server.server sse
 ```
 
 ## License
 
+MIT
README_SPACES.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: Parlay ◈
+emoji: 🤝
+colorFrom: purple
+colorTo: indigo
+sdk: docker
+app_port: 7860
+pinned: true
+---
+
+Parlay is a negotiation RL environment and playable browser game where agents bargain under hidden information.
+It combines Theory-of-Mind belief tracking, dynamic ZOPA erosion under sustained tension, and tactical negotiation moves.
+This Space exposes the live game UI and OpenEnv-style WebSocket flow so you can test policies interactively.
+Use the spectator view to inspect hidden state dynamics during a live negotiation demo.
agent/gemini_client.py
CHANGED
@@ -310,7 +310,11 @@ async def call_gemini(
 
     mid = model if model is not None else MODEL_ID_DATA
     text = ""
-
+    try:
+        from google.genai import types  # noqa: PLC0415
+    except ModuleNotFoundError:
+        logger.warning("google-genai SDK missing; falling back to mock response")
+        return _get_mock_response(persona, len(messages), scenario_id)
 
     history = messages[:-1] if len(messages) > 1 else []
     last_msg = messages[-1]["parts"][0] if messages else "Begin the negotiation."
agent/personas.py
CHANGED
@@ -84,7 +84,12 @@ PERSONAS: dict[PersonaType, PersonaConfig] = {
             "say \"Interesting.\" or \"I see.\" After turn 6, begin making calculated "
             "concessions - but always get something in return first. You have seen every "
             "trick before. When the opponent plays time_pressure, say \"I appreciate the "
-            "deadline context, though our timeline is more flexible than you might expect.\""
+            "deadline context, though our timeline is more flexible than you might expect.\" "
+            "After turn 12, shift from reading to closing. You have gathered enough information. "
+            "Now anchor a final position, reference what you have learned about the opponent, "
+            "and push for agreement. A veteran who never closes is not a veteran - patience "
+            "is a tool, not a strategy. In the final 4 turns, accept any offer within 8% of "
+            "your target rather than let the deal expire."
         ),
         drift_trigger="power_shift",
        color_var="--parlay-purple",
agent/runner.py
CHANGED
@@ -190,7 +190,11 @@ async def run_episode(
             f"You are the human buyer in this negotiation. "
             f"Respond naturally to the AI seller.\n"
             f"Scenario: {scenario.title}. "
-            f"Your budget ceiling: {hidden.budget_ceiling:,.0f}"
+            f"Your budget ceiling: {hidden.budget_ceiling:,.0f}\n"
+            f"Occasionally (roughly 1 in 4 turns), make a claim about having a competing offer "
+            f"or a lower budget than you actually have. State a BATNA that is 20-30% away from "
+            f"your true position. This is realistic negotiation behavior and tests whether the "
+            f"agent can detect bluffing."
         ),
         opponent_messages,
         persona=persona.value,
agent/tom_tracker.py
CHANGED
@@ -10,6 +10,11 @@ from parlay_env.models import BeliefState, HiddenState, PersonaType, TacticalMov
 
 logger = logging.getLogger(__name__)
 
+# NOTE: ToMTracker is used in two paths:
+# (1) agent/runner.py self-play — full update each turn;
+# (2) parlay_env/server.py WebSocket server — also uses ToMTracker after Task 1 fix.
+# Both paths now produce comparable belief_history for grader._tom_accuracy.
+
 
 class ToMTracker:
     """
game/scenarios.py
CHANGED
@@ -49,12 +49,18 @@ SCENARIOS: dict[str, Scenario] = {
         title="Senior Engineer Offer",
         description="Total comp negotiation: base + equity + signing bonus.",
         anchor_seller=240_000, anchor_buyer=180_000,
-
-
+        # Widened 15% to improve deal rate in self-play data generation
+        batna_seller=195_000, batna_buyer=264_500,
+        zopa=(195_000, 264_500), currency="USD", unit="total annual comp",
         difficulty=2,
         drift_events=[
-
-
+            # Delayed from 5 to 8 - early drift was destabilizing pre-anchor phase
+            DriftEvent(
+                trigger_turn=8,
+                event="Competing offer received",
+                effect_on_urgency=-0.25,
+                effect_on_has_alternative=True,
+            ),
         ],
     ),
     "acquisition_term_sheet": Scenario(
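A quick arithmetic check on the widening comment: the new ZOPA is (195,000, 264,500), and if the 15% applies to ZOPA width, the implied pre-change width was about 60.4k (the removed BATNA lines are not captured in this diff, so the old endpoints are unknown):

```python
# Sanity arithmetic for the "Widened 15%" comment above.
new_width = 264_500 - 195_000   # 69,500
implied_old_width = new_width / 1.15
print(new_width, round(implied_old_width))  # 69500 60435
```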
openenv.yaml
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
env_id: parlay-negotiation-v1
|
| 2 |
+
name: Parlay
|
| 3 |
+
description: >
|
| 4 |
+
A negotiation MDP with hidden information, Theory-of-Mind belief tracking,
|
| 5 |
+
dynamic ZOPA erosion, and tactical moves. Three personas × three scenarios.
|
| 6 |
+
version: "1.0.0"
|
| 7 |
+
author: Shashvat Singh
|
| 8 |
+
contact: shashvat.k.singh.16@gmail.com
|
| 9 |
+
|
| 10 |
+
observation_space:
|
| 11 |
+
type: dict
|
| 12 |
+
fields:
|
| 13 |
+
- name: offers
|
| 14 |
+
type: list[float]
|
| 15 |
+
- name: zopa_lower
|
| 16 |
+
type: float
|
| 17 |
+
- name: zopa_upper
|
| 18 |
+
type: float
|
| 19 |
+
- name: nash_point
|
| 20 |
+
type: float
|
| 21 |
+
- name: tension_score
|
| 22 |
+
type: float
|
| 23 |
+
range: [0, 100]
|
| 24 |
+
- name: belief_state
|
| 25 |
+
type: dict
|
| 26 |
+
description: Agent's beliefs about opponent hidden state (est_budget, est_walk_away, est_urgency, est_has_alternative, confidence)
|
| 27 |
+
- name: last_utterance
|
| 28 |
+
type: string
|
| 29 |
+
- name: available_moves
|
| 30 |
+
      type: list[string]
      values: [anchor_high, batna_reveal, silence]
    - name: cp
      type: int
      description: Tactical card points remaining
    - name: drift_event
      type: string|null
      description: Human-readable drift event if triggered this turn, else null
    - name: zopa_width_pct_remaining
      type: float
      range: [0.0, 1.0]

action_space:
  type: dict
  fields:
    - name: utterance
      type: string
      required: true
    - name: offer_amount
      type: float|null
    - name: tactical_move
      type: string|null
      values: [anchor_high, batna_reveal, silence]
    - name: accept_deal
      type: bool
      default: false
    - name: walk_away
      type: bool
      default: false

reward:
  range: [-200, ~300]
  per_step: "α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV"
  terminal: "γ·E + ε·S + ζ·D (or -ω on capitulation)"
  constants:
    ALPHA: 2
    BETA: 5
    DELTA: 3
    THETA: 10
    PSI: 12
    MU: 8
    GAMMA: 100
    EPSILON: 20
    ZETA: 15
    OMEGA: 200

episode:
  max_steps: 20
  termination_conditions:
    - deal accepted (offers within 3%)
    - walk_away action
    - max turns reached
    - zopa_collapsed (BATNAs cross after erosion)
    - step_reward below threshold (very negative)

endpoints:
  websocket: ws://host:port/env/ws
  protocol:
    reset:
      send: '{"type": "reset", "scenario_id": "saas_enterprise|hiring_package|acquisition_term_sheet", "persona": "shark|diplomat|veteran"}'
      receive: ParlayObservation (JSON)
    step:
      send: ParlayAction (JSON)
      receive: ParlayObservation (JSON)
    state:
      send: '{"type": "state"}'
      receive: ParlayState (JSON, includes hidden state for spectators)

scenarios:
  - id: saas_enterprise
    description: SaaS software license negotiation, seller vs enterprise buyer
  - id: hiring_package
    description: Job offer compensation negotiation
  - id: acquisition_term_sheet
    description: Startup acquisition valuation negotiation

personas:
  - id: shark
    style: Aggressive anchoring, high bluff rate
  - id: diplomat
    style: Collaborative, seeks mutual gain
  - id: veteran
    style: Patient, reads opponent carefully

hidden_information:
  - budget_ceiling (opponent's true max budget)
  - walk_away_price (opponent's true BATNA)
  - urgency_score (0-1, how time-pressured opponent is)
  - has_alternative (whether opponent has a competing offer)

rubric:
  - criterion: Novel environment
    description: Negotiation as MDP with hidden info, dynamic ZOPA, ToM beliefs
  - criterion: Reward design
    description: Multi-term dense reward with bluff detection, ToM accuracy, and drift adaptation
  - criterion: Training story
    description: Gemini self-play data → SFT cold start → GRPO fine-tune → reward improvement vs baseline
  - criterion: Demo quality
    description: Live browser play + spectator god-view with Three.js avatars
  - criterion: Hosted Space
    description: HuggingFace Space with Docker deployment
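To make the manifest's websocket protocol concrete, here is a minimal client sketch. Assumptions: the third-party `websockets` package, an env server on localhost:8001, and message envelopes as listed in the manifest above; check parlay_env/server.py for the exact field names (for example, how a session identifier travels on step/state messages) before relying on this.

# Minimal Parlay OpenEnv client sketch (assumptions noted above).
import asyncio
import json

import websockets  # pip install websockets


async def play_one_turn() -> None:
    async with websockets.connect("ws://localhost:8001/env/ws") as ws:
        # reset: shape taken from protocol.reset.send in the manifest
        await ws.send(json.dumps({
            "type": "reset",
            "scenario_id": "saas_enterprise",
            "persona": "diplomat",
        }))
        obs = json.loads(await ws.recv())  # ParlayObservation (JSON)

        # step: a ParlayAction; `utterance` is required, the rest optional
        await ws.send(json.dumps({
            "type": "step",
            "utterance": "I can start at 50000 for the annual license.",
            "offer_amount": 50000,
            "tactical_move": None,
            "accept_deal": False,
            "walk_away": False,
        }))
        obs = json.loads(await ws.recv())
        print(obs)


asyncio.run(play_one_turn())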
parlay_env/grader.py
CHANGED

@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from typing import Optional

 from .models import BeliefState, HiddenState, ParlayAction, ParlayState
-from .reward import ALPHA, BETA, DELTA, EPSILON, GAMMA, OMEGA, PSI, THETA, ZETA
+from .reward import ALPHA, BETA, DELTA, EPSILON, GAMMA, MU, OMEGA, PSI, THETA, ZETA

 logger = logging.getLogger(__name__)

@@ -113,6 +113,7 @@ def compute_step_reward(
     state: ParlayState,
     action: ParlayAction,
     next_state: ParlayState,
+    drift_event: str | None = None,
 ) -> float:
     """
     Compute per-step reward R_t.
@@ -151,21 +152,39 @@
     ):
         bluff_bonus = PSI

+    mev_bonus = 0.0
+    drift_marker = drift_event
+    if drift_marker is None:
+        drift_marker = next_state.__dict__.get("drift_event")
+    if drift_marker:
+        lowered = action.utterance.lower()
+        adaptation_tokens = (
+            "given that",
+            "considering",
+            "in light of",
+            "noted",
+            "understood",
+        )
+        if any(token in lowered for token in adaptation_tokens):
+            mev_bonus = MU
+
     reward = (
         ALPHA * delta_v
         + BETA * tom_t
         - DELTA * concession_t
         - THETA * noise_t
         + bluff_bonus
+        + mev_bonus
     )
     logger.debug(
-        "Step reward: total=%.3f (dv=%.3f, tom=%.3f, concession=%.3f, noise=%.0f, bluff=%.3f)",
+        "Step reward: total=%.3f (dv=%.3f, tom=%.3f, concession=%.3f, noise=%.0f, bluff=%.3f, mev=%.3f)",
         reward,
         delta_v,
         tom_t,
         concession_t,
         noise_t,
         bluff_bonus,
+        mev_bonus,
     )
     return reward
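The MEV term added above is all-or-nothing, so its gating is worth restating in isolation. A standalone sketch mirroring the diff (not an import of the module; the drift-marker string below is made up):

# Standalone sketch of the MEV gating added to compute_step_reward above.
# MU and the token list mirror the diff; this is illustrative only.
MU = 8.0
ADAPTATION_TOKENS = ("given that", "considering", "in light of", "noted", "understood")


def mev_bonus(drift_marker: str | None, utterance: str) -> float:
    """Return MU only when a drift event fired AND the reply acknowledges it."""
    if not drift_marker:
        return 0.0  # no drift this turn, nothing to adapt to
    lowered = utterance.lower()
    return MU if any(tok in lowered for tok in ADAPTATION_TOKENS) else 0.0


assert mev_bonus(None, "Noted, I can adjust.") == 0.0        # no drift fired
assert mev_bonus("budget cut", "Final offer.") == 0.0        # drift ignored
assert mev_bonus("budget cut", "Given that, 45000.") == 8.0  # drift acknowledged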
parlay_env/reward.py
CHANGED

@@ -1,4 +1,8 @@
-"""Reward function constants. Import from here everywhere — never hardcode coefficients."""
+"""Reward function constants. Import from here everywhere — never hardcode coefficients.
+
+Per-step reward form:
+R_t = α·ΔV + β·ToM - δ·C - θ·noise + ψ·bluff + μ·MEV
+"""

 # Per-step weights
 ALPHA: float = 2.0  # offer improvement toward ZOPA midpoint
@@ -13,6 +17,7 @@ ZETA: float = 15.0  # drift adaptation bonus
 ETA: float = 0.0  # retained for compatibility; single-act env disables it
 OMEGA: float = 200.0  # capitulation cliff (hard discontinuous penalty)
 PSI: float = 12.0  # bluff-caught bonus
+MU: float = 8.0  # drift-event recognition bonus (MEV proxy)

 # Game config defaults (overridden by env vars at runtime)
 MAX_TURNS: int = 20
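As a worked example of the per-step form under these constants (the per-term values are made up for illustration; run from the repo root so parlay_env is importable; bluff and MEV enter as 0/1 indicators, matching the all-or-nothing bonuses in the grader):

# R_t = ALPHA*dV + BETA*ToM - DELTA*C - THETA*noise + PSI*bluff + MU*MEV
from parlay_env.reward import ALPHA, BETA, DELTA, THETA, PSI, MU

delta_v = 1.5     # offer moved 1.5 units toward the ZOPA midpoint
tom = 0.6         # belief accuracy this step
concession = 0.4  # concession magnitude
noise = 0.0       # utterance stayed on-topic
bluff = 0.0       # no bluff caught this turn
mev = 1.0         # drift event acknowledged

r_t = ALPHA * delta_v + BETA * tom - DELTA * concession - THETA * noise + PSI * bluff + MU * mev
# = 2*1.5 + 5*0.6 - 3*0.4 - 0 + 0 + 8 = 12.8
print(r_t)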
parlay_env/server.py
CHANGED

@@ -17,6 +17,9 @@ from typing import Any
 import numpy as np
 from fastapi import APIRouter, FastAPI, WebSocket, WebSocketDisconnect

+from agent.tom_tracker import ToMTracker
+from game.scenarios import get_scenario
+
 from .exceptions import (
     EpisodeAlreadyDoneError,
     InvalidActionError,
@@ -67,7 +70,7 @@ FALLBACK_OBSERVATION = ParlayObservation(
     cumulative_reward=0.0,
 )

-_sessions: dict[str,
+_sessions: dict[str, dict[str, Any]] = {}

 MAX_TURNS = int(os.getenv("MAX_TURNS_PER_EPISODE", "20"))
 CP_START = int(os.getenv("CREDIBILITY_POINTS_START", "100"))
@@ -133,6 +136,27 @@ def _compute_tension(state: ParlayState, action: ParlayAction) -> float:
     return float(max(0.0, min(100.0, base)))


+def _apply_drift_event(state: ParlayState, tom: ToMTracker) -> str | None:
+    """Apply scenario drift event at the current turn, if any."""
+    try:
+        scenario = get_scenario(state.scenario_id)
+    except Exception:
+        return None
+
+    for event in scenario.drift_events:
+        if event.trigger_turn == state.step_count:
+            state.drift_events_fired += 1
+            state.hidden_state.persona_drifted = True
+            tom.drift_event(
+                event.effect_on_urgency,
+                event.effect_on_has_alternative,
+                event_description=event.event,
+            )
+            state.belief_history = list(tom.history)
+            return event.event
+    return None
+
+
 def _make_observation(
     state: ParlayState,
     reward: float,
@@ -260,7 +284,10 @@ async def _handle_reset(msg: dict[str, Any]) -> dict:
         original_zopa_width=original_zopa_width,
         zopa_width_pct_remaining=1.0,
     )
-    _sessions[session_id] =
+    _sessions[session_id] = {
+        "state": state,
+        "tom_tracker": ToMTracker(initial_belief, persona),
+    }

     observation = _make_observation(state, 0.0, "Negotiation started. Make your opening move.")
     logger.info("Reset: session=%s, scenario=%s, persona=%s", session_id, scenario_id, persona_str)
@@ -273,7 +300,9 @@ async def _handle_step(msg: dict[str, Any]) -> dict:
     if not session_id or session_id not in _sessions:
         raise SessionNotFoundError(f"Session {session_id} not found")

-
+    session = _sessions[session_id]
+    state: ParlayState = session["state"]
+    tom: ToMTracker = session["tom_tracker"]
     if state.episode_done:
         raise EpisodeAlreadyDoneError(f"Episode {session_id} is already done")

@@ -291,17 +320,13 @@
     if action.offer_amount is not None:
         new_offers.append(action.offer_amount)

-
-
-
-
-
-
-
-            est_has_alternative=last.est_has_alternative,
-            confidence=min(1.0, last.confidence + 0.05),
-        )
-        new_beliefs.append(updated)
+    tom.update(
+        observed_offer=action.offer_amount,
+        observed_move=action.tactical_move,
+        utterance=action.utterance,
+        turn=state.step_count + 1,
+    )
+    new_beliefs = list(tom.history)

     next_state = ParlayState(
         **{
@@ -314,6 +339,8 @@
             "hidden_state": HiddenState(**state.hidden_state.model_dump()),
         }
     )
+    # Keep belief history aligned with ToM tracker history (single source of truth).
+    next_state.belief_history = new_beliefs

     if action.tactical_move == TacticalMove.BATNA_REVEAL:
         revealed = action.offer_amount if action.offer_amount is not None else next_state.hidden_state.walk_away_price
@@ -349,7 +376,8 @@
         <= next_state.hidden_state.budget_ceiling
     )

-
+    drift_event = _apply_drift_event(next_state, tom)
+    step_reward = compute_step_reward(state, action, next_state, drift_event=drift_event)
     next_state.cumulative_reward = state.cumulative_reward + step_reward

     if step_reward >= 0.0 and action.tactical_move is None and state.hidden_state.last_stated_batna is not None:
@@ -378,8 +406,8 @@
     else:
         next_state.termination_reason = "max_turns"

-    _sessions[session_id] = next_state
-    observation = _make_observation(next_state, step_reward, action.utterance)
+    _sessions[session_id] = {"state": next_state, "tom_tracker": tom}
+    observation = _make_observation(next_state, step_reward, action.utterance, drift_event=drift_event)
     return {"observation": observation.model_dump(), "done": next_state.episode_done}

@@ -388,12 +416,15 @@ async def _handle_state(msg: dict[str, Any]) -> dict:
     session_id = msg.get("session_id")
     if not session_id or session_id not in _sessions:
         raise SessionNotFoundError(f"Session {session_id} not found")
-    return {"state": _sessions[session_id].model_dump()}
+    return {"state": _sessions[session_id]["state"].model_dump()}


 def get_session_state(session_id: str) -> ParlayState | None:
     """Return the in-memory session state for SSE and tests."""
-
+    session = _sessions.get(session_id)
+    if not session:
+        return None
+    return session["state"]


 @router.get("/sessions/{session_id}")
@@ -401,7 +432,7 @@ async def get_session(session_id: str) -> dict:
     """Get session state via REST."""
     if session_id not in _sessions:
         raise SessionNotFoundError(f"Session {session_id} not found")
-    return {"state": _sessions[session_id].model_dump()}
+    return {"state": _sessions[session_id]["state"].model_dump()}


 _env_app = FastAPI(title="Parlay OpenEnv", version="1.0.0")
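The session store now holds a per-session dict (state plus ToM tracker) rather than a bare ParlayState, so any consumer must unwrap it. A sketch of the safe access pattern, mirroring `get_session_state` above (the helper and its use of `offer_history` are illustrative assumptions):

# Sketch: unwrap the per-session dict via get_session_state() rather than
# touching _sessions directly. Hypothetical helper, not part of the repo.
from parlay_env import server


def latest_offer(session_id: str) -> float | None:
    state = server.get_session_state(session_id)  # ParlayState | None
    if state is None or not state.offer_history:
        return None
    return state.offer_history[-1]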
scripts/audit_grpo_pipeline.py
ADDED

#!/usr/bin/env python3
"""
Smoke test for ParlayGRPOEnvWrapper against one JSONL prompt (keyless / mock path).
Read-only: does not modify training or env.
"""
from __future__ import annotations

import argparse
import json
import sys
import traceback
from pathlib import Path


class _StubTrainer:
    """Minimal object satisfying ParlayGRPOEnvWrapper's trainer attribute interface."""

    def train(self) -> None:
        return None

    def save_model(self, _out: str) -> None:
        return None


def main() -> None:
    parser = argparse.ArgumentParser(
        description="GRPO env wrapper smoke test (reset + play_turn, JSON completion handling)"
    )
    parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL (first row)")
    parser.add_argument(
        "--repo-root",
        type=Path,
        default=None,
        help="Project root (default: parent of scripts/)",
    )
    args = parser.parse_args()

    root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))

    from training.grpo_env_wrapper import ParlayGRPOEnvWrapper

    path = Path(args.data)
    if not path.is_file():
        print(f"File not found: {path.resolve()}")
        print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
        return

    with path.open("r", encoding="utf-8") as f:
        first = next((ln for ln in f if ln.strip()), None)
    if not first:
        print("Empty JSONL.")
        return
    try:
        row = json.loads(first.strip())
    except json.JSONDecodeError as e:
        print(f"First row is not valid JSON: {e}")
        return

    scenario_id = str(row.get("scenario_id") or "saas_enterprise")
    persona = str(row.get("persona") or "diplomat")

    wrapper = ParlayGRPOEnvWrapper(_StubTrainer())
    print("ParlayGRPOEnvWrapper smoke test")
    print(f" JSONL: {path.resolve()}")
    print(f" Using scenario_id={scenario_id!r} persona={persona!r} from first row (defaults if missing)")

    entries: list[tuple[str, str, str]] = []

    # 1) reset
    try:
        obs = wrapper.reset(scenario_id=scenario_id, persona=persona, seed=42)
        entries.append(
            (
                "reset() completes",
                "PASS",
                f"ok; scenario_id in obs: {obs.get('scenario_id')!r}",
            )
        )
    except Exception:
        entries.append(("reset() completes", "FAIL", traceback.format_exc()[:500]))
        _print_checks(entries)
        return

    # 2) play_turn with valid parsed completion
    sample_json = '{"utterance": "I propose 50000", "offer_amount": 50000}'
    try:
        action = json.loads(sample_json)
        out = wrapper.play_turn(action)
        reward = float(out.get("reward", 0.0))
    except Exception:
        entries.append(
            (
                "play_turn(valid JSON → dict with offer)",
                "FAIL",
                traceback.format_exc()[:500],
            )
        )
        _print_checks(entries)
        return

    print(f" Sample model completion: {sample_json}")
    print(f" play_turn reward (wrapper): {reward}")
    print(
        " Note: play_turn() returns result.grade.total_reward when offer is set (full episode total), not "
        "the GRPO weighted reward_fn. GRPO training uses training/reward_fn.py on generated strings."
    )

    lo, hi = -10.0, 50.0
    in_range = lo <= reward <= hi
    entries.append(
        (
            f"Reward in [{lo}, {hi}] (heuristic single-step window)",
            "PASS" if in_range else "FAIL",
            (
                f"reward={reward} inside range"
                if in_range
                else f"reward={reward} - expected often OUTSIDE range: wrapper total_reward can be large"
            ),
        )
    )

    # 3) Malformed JSON: must not be passed to play_turn as a string from a correct pipeline
    bad = '{"utterance": "hello"'
    try:
        json.loads(bad)
        par_mal = "UNEXPECTED: bad JSON parsed"
    except json.JSONDecodeError:
        err_line = None
        try:
            wrapper.play_turn(bad)  # type: ignore[arg-type]
        except Exception as e:
            err_line = f"json.loads fails; play_turn(str) -> {type(e).__name__}: {e!s}"[:200]
        par_mal = err_line or "play_turn(str) did not raise"
    entries.append(
        (
            "Malformed JSON string mishandled at play_turn",
            "FAIL" if par_mal.startswith("UNEXPECTED") else "PASS",
            "Correct pipeline: json.loads first; " + par_mal,
        )
    )

    # 4) Empty string
    empty_explain = []
    try:
        json.loads("")
    except json.JSONDecodeError as e0:
        empty_explain.append(f"json.loads('') -> {e0!s}"[:100])
    try:
        wrapper.play_turn("")
    except Exception as e1:
        empty_explain.append(f"play_turn('') -> {type(e1).__name__}")
    else:
        empty_explain.append("play_turn('') did not raise (unexpected)")
    entries.append(
        (
            "Empty string completion / action",
            "PASS" if "did not raise" not in str(empty_explain[-1]) else "FAIL",
            " | ".join(empty_explain),
        )
    )

    _print_checks(entries)


def _print_checks(rows: list[tuple[str, str, str]]) -> None:
    print()
    print("CHECKS")
    for name, status, detail in rows:
        print(f" [{status}] {name}")
        for line in detail.split("\n"):
            print(f"   {line}")


if __name__ == "__main__":
    main()
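Checks 3 and 4 above encode the contract that raw model completions are parsed before they ever reach play_turn. A sketch of that parse-first step (a hypothetical helper illustrating the contract, not code from the repo):

# Sketch of the parse-first contract the smoke test checks: never hand a raw
# completion string to play_turn(); json.loads it and fall back to a safe
# no-offer action on malformed output. Hypothetical helper, not in the repo.
import json


def completion_to_action(completion: str) -> dict:
    try:
        action = json.loads(completion)
    except json.JSONDecodeError:
        # Malformed generation: degrade to a pure utterance with no offer.
        return {"utterance": completion.strip() or "(empty)", "offer_amount": None}
    if not isinstance(action, dict) or "utterance" not in action:
        return {"utterance": str(action), "offer_amount": None}
    return action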
scripts/audit_reward.py
ADDED

#!/usr/bin/env python3
"""
Static reward-surface audit for Parlay (read-only, no env rollouts).
Analytical notes derived from parlay_env/grader.py, parlay_env/reward.py, game/scenarios.py.
"""
from __future__ import annotations

import argparse
import sys
from pathlib import Path


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Analytical Parlay reward-hacking and alignment audit (static, no rollouts)"
    )
    parser.add_argument(
        "--repo-root",
        type=Path,
        default=None,
        help="Project root (default: parent of scripts/)",
    )
    args = parser.parse_args()

    root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
    for sub in (root / "parlay_env", root / "game"):
        if not sub.is_dir():
            print(f"Expected directory missing: {sub}")
            return
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))

    from parlay_env import grader as grader_mod
    from parlay_env import reward as reward_mod
    from game import scenarios as scenarios_mod

    # Ensure grader symbols resolve (import side effects only)
    _ = (grader_mod.compute_step_reward, grader_mod.detect_bluff_challenge)

    results: list[tuple[str, str, str]] = []

    print("=" * 72)
    print("1. NOISE TERM (THETA * noise_t)")
    print("-" * 72)
    print(
        "In compute_step_reward, noise_t = 1.0 when cosine_sim(utterance, prior offer text) < 0.3, "
        "else 0.0. The total applies -THETA*noise (penalty on low similarity, not a bonus)."
    )
    print(
        "Trivial *positive* side-channel from the noise term does not exist: noise can only add "
        "a penalty, never increase reward. Avoiding the penalty means keeping utterance "
        "overlapping the token history of prior offers (e.g. echoing offer-like numbers), not "
        "necessarily any arbitrary small talk (which can score low overlap and be penalized)."
    )
    print("NOISE TERM: Low hacking risk - the term is a unilateral penalty, not a reward. OK.")
    results.append(("NOISE TERM (THETA*noise)", "PASS", "Penalty only; no positive exploit"))

    print()
    print("=" * 72)
    print("2. TOM TERM (BETA * ToM)")
    print("-" * 72)
    print(
        "ToM in compute_step_reward uses the latest belief in next_state.belief_history against "
        "next_state.hidden_state. The agent's utterance does not directly author beliefs; in the "
        "runner/server path, beliefs update from observed opponent behavior."
    )
    print("TOM TERM: Not hackable by agent. OK.")
    results.append(("ToM (BETA*ToM)", "PASS", "Beliefs from observation path, not direct agent edit"))

    print()
    print("=" * 72)
    print("3. BLUFF BONUS (PSI)")
    print("-" * 72)
    print("detect_bluff_challenge() is structured as: (1) if stated/true are None -> False; (2) compute")
    print("  bluff_threshold = 15% of |true| and require |stated-true| > threshold; (3) only then check")
    print("  skepticism phrases. There is no partial credit for phrases alone if (2) fails.")
    print(
        "In compute_step_reward, bluff_bonus = PSI only when: tactical_move is None, "
        "state.hidden_state.last_stated_batna is not None, AND detect_bluff_challenge(...)=True "
        "(which already requires the >15% gap AND a skepticism phrase)."
    )
    print("All conditions are ANDed; there is no independent partial PSI for skepticism only.")
    print("BLUFF BONUS: Gated correctly. OK.")
    results.append(("BLUFF BONUS (PSI)", "PASS", "All conjuncts required; no partial PSI"))

    print()
    print("=" * 72)
    print("4. MEV (MU * MEV) - drift + adaptation")
    print("-" * 72)
    print("MEV in compute_step_reward uses drift_event or next_state.drift_event; mev_bonus = MU if a drift")
    print("marker is present AND the utterance contains an adaptation subphrase (see grader for tokens).")
    print("The agent does not set drift_event; game/scenarios.py defines trigger_turn per scenario.\n")
    for sid, sc in sorted(scenarios_mod.SCENARIOS.items()):
        if not sc.drift_events:
            print(f" {sid}: (no drift_events)")
        else:
            turns = [f"turn {e.trigger_turn}: {e.event!r}" for e in sc.drift_events]
            print(f" {sid}: {', '.join(turns)}")
    print()
    print("MEV TERM: Not hackable. OK.")
    results.append(("MEV (MU*drift adapt)", "PASS", "Drift is scenario-time-gated, not agent-triggered"))

    print()
    print("=" * 72)
    print("5. DELTA CONCESSION - offer_amount = None")
    print("-" * 72)
    print(
        "In compute_step_reward: delta_v only updates when action.offer_amount is not None. "
        "concession_t only runs when state.offer_history and action.offer_amount is not None."
    )
    print(
        "If offer_amount is always None, delta_v=0 and concession_t=0, so the agent forgoes both "
        "alpha*deltaV upside and any delta*concession penalty in those terms."
    )
    print(
        "CONCESSION HACK RISK: Agent can set offer_amount=None every turn to avoid both deltaV reward "
        "AND concession penalty. Net effect: misses upside but avoids downside. "
        "Document as known limitation."
    )
    results.append(
        (
            "Concession (DELTA) / offer=None",
            "WARN",
            "offer_amount=None zeroes both deltaV and concession terms",
        )
    )

    print()
    print("=" * 72)
    print("6. TERMINAL vs STEP REWARD alignment")
    print("-" * 72)
    print("Step: emphasizes offer improvement (ALPHA), ToM (BETA), penalties and bonuses as shaped in grader.")
    print(
        "Terminal (compute_terminal_reward): deal_efficiency, speed, drift bonus; GAMMA = "
        f"{reward_mod.GAMMA} on efficiency."
    )
    print(
        "Tension: an agent can chase high per-step terms (e.g. anchoring, offer deltas) and still miss "
        "agreement, yielding low terminal efficiency if no deal closes or final price is poor."
    )
    print(
        "This is a mis-alignment by design: it pressures closing unless step weights drown the signal - "
        "monitor in training, not a pure bug."
    )
    print("STEP vs TERMINAL: WARN - intentional tension; monitor in training, not a pure logic bug.")
    results.append(
        (
            "Step vs terminal alignment",
            "WARN",
            "Dense step and terminal E can pull apart without a deal",
        )
    )

    print()
    print("=" * 72)
    print("SUMMARY (6 checks)")
    print("=" * 72)
    for label, level, note in results:
        print(f" [{level:4s}] {label} - {note}")


if __name__ == "__main__":
    main()
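Check 5 is the one WARN-level exploit worth internalizing: both offer-dependent terms are gated on the offer being present, so withholding offers zeroes them together. A schematic of that gating (illustrative stand-ins mirroring the printed description, not the grader's exact code):

# Schematic of the check-5 concern: delta_v and concession_t are both gated on
# action.offer_amount, so an agent that never offers zeroes the pair together.
def gated_terms(offer_amount: float | None, offer_history: list[float]) -> tuple[float, float]:
    delta_v = 0.0
    concession = 0.0
    if offer_amount is not None:
        delta_v = 1.0  # stand-in for movement toward the ZOPA midpoint
        if offer_history:
            concession = abs(offer_amount - offer_history[-1])  # stand-in magnitude
    return delta_v, concession


# The "offer=None" hack: no deltaV upside, but no concession penalty either.
assert gated_terms(None, [48000.0]) == (0.0, 0.0)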
scripts/check_training_config.py
ADDED

#!/usr/bin/env python3
"""
Pre-flight training configuration checklist (SFT + GRPO).
Read-only: inspects training/sft_train.py and training/grpo_train.py; does not start training.
"""
from __future__ import annotations

import argparse
import inspect
import os
import sys
from pathlib import Path


def main() -> None:
    parser = argparse.ArgumentParser(description="Pre-flight SFT/GRPO training config checklist")
    parser.add_argument(
        "--repo-root",
        type=Path,
        default=None,
        help="Project root (default: parent of scripts/)",
    )
    args = parser.parse_args()

    root = (args.repo_root or Path(__file__).resolve().parent.parent).resolve()
    if str(root) not in sys.path:
        sys.path.insert(0, str(root))

    sft_path = root / "training" / "sft_train.py"
    grpo_path = root / "training" / "grpo_train.py"
    if not sft_path.is_file() or not grpo_path.is_file():
        print(f"Missing {sft_path} or {grpo_path}")
        return

    import training.sft_train as sft
    import training.grpo_train as grpo

    sft_text = sft_path.read_text(encoding="utf-8")
    grpo_text = grpo_path.read_text(encoding="utf-8")
    sft_fn = inspect.getsource(sft.train_sft)

    checks: list[tuple[str, bool, str]] = []

    # Base model
    want_model = "Qwen/Qwen2.5-1.5B-Instruct"
    ok_model = sft.DEFAULT_MODEL == want_model
    checks.append(("[ ] Base model: Qwen/Qwen2.5-1.5B-Instruct", ok_model, f"found {sft.DEFAULT_MODEL!r}"))

    # LoRA
    ok_lora = "r=16" in sft_fn and "lora_alpha=32" in sft_fn
    checks.append(("[ ] SFT LoRA r=16, alpha=32", ok_lora, "in train_sft()"))

    # SFT training args
    ok_epochs = "num_train_epochs=3" in sft_text
    ok_b = "per_device_train_batch_size=4" in sft_text
    ok_g = "gradient_accumulation_steps=4" in sft_text
    eff = 4 * 4
    ok_sft = ok_epochs and ok_b and ok_g
    checks.append(
        (
            f"[ ] SFT epochs=3, batch=4, grad_accum=4 (effective ~{eff})",
            ok_sft,
            f"epochs={ok_epochs} batch={ok_b} grad={ok_g}",
        )
    )

    # Output dir
    want_out = "checkpoints/sft_1.5b/"
    ok_out = sft.DEFAULT_OUTPUT == want_out
    checks.append(("[ ] SFT output: checkpoints/sft_1.5b/", ok_out, f"default={sft.DEFAULT_OUTPUT!r}"))

    # GRPO BASE_MODEL (read at import time in grpo_train)
    base = os.getenv("BASE_MODEL", "")
    grpo_default = grpo.BASE_MODEL
    if not base:
        grpo_brief = f"BASE_MODEL env not set - will use module default {grpo_default!r}"
    else:
        grpo_brief = f"set to {base!r}"
    checks.append(
        (
            "[ ] GRPO reads BASE_MODEL from env",
            True,
            grpo_brief,
        )
    )

    # GRPO reward weights
    want_line = "reward_weights=[3.0, 1.5, 2.0, 0.5]"
    in_rw = want_line in grpo_text
    w_line = next((ln.strip() for ln in grpo_text.splitlines() if "reward_weights" in ln), "")
    checks.append(
        (
            "[ ] GRPO reward weights [efficiency, tom, anti-cap, format] = [3.0, 1.5, 2.0, 0.5]",
            in_rw,
            w_line or "not found",
        )
    )

    # GRPO data path
    d_ok = 'default="data/episodes.jsonl"' in grpo_text
    checks.append(
        (
            '[ ] GRPO --data default: data/episodes.jsonl',
            d_ok,
            "see grpo_train.main argparse" if d_ok else "check grpo_train.py",
        )
    )

    checks.append(
        (
            "[ ] Estimated VRAM note (1.5B + LoRA r=16 ~6-8GB SFT; more for GRPO)",
            True,
            "informational (not a failure if you skip the box)",
        )
    )

    hf = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    ok_hf = bool(hf)
    checks.append(
        (
            "[ ] HF token for push (HF_TOKEN or HUGGING_FACE_HUB_TOKEN)",
            ok_hf,
            "set" if ok_hf else "not set - needed to push checkpoints",
        )
    )

    print("Training config pre-flight (read from training/sft_train.py, training/grpo_train.py)\n")
    for line, ok, note in checks:
        mark = "x" if ok else " "
        display = line.replace("[ ]", f"[{mark}]", 1) if line.startswith("[ ]") else line
        print(display)
        if note:
            print(f" -> {note}")
    print()

    core_ok = ok_model and ok_lora and ok_sft and ok_out and in_rw and d_ok
    if core_ok and ok_hf:
        print("\nREADY FOR TRAINING (SFT + GRPO config strings match; HF token present for hub).")
    elif core_ok:
        print(
            "\nMOSTLY READY: fix missing HF token if you need push_to_hub; verify BASE_MODEL for GRPO stage."
        )
    else:
        print("\nNEEDS FIXING: see failed [ ] items above (model path, LoRA, SFT args, or output dir).")


if __name__ == "__main__":
    main()
scripts/eval_comparison.py
ADDED

"""Compare baseline, Gemini, and GRPO JSON summaries."""
import argparse
import json
from pathlib import Path


def _load(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _fmt_pct(value: float) -> str:
    return f"{100.0 * value:.1f}%"


def _row(label: str, data: dict) -> str:
    return (
        f"| {label} | {data.get('avg_reward', 0):.3f} | "
        f"{_fmt_pct(float(data.get('deal_rate', 0)))} | "
        f"{float(data.get('avg_efficiency', 0)):.3f} | "
        f"{float(data.get('avg_tom_accuracy', 0)):.3f} | "
        f"{int(data.get('bluffs_caught', 0))} |"
    )


def _save_chart(baseline: dict, gemini: dict, grpo: dict, output_path: Path) -> None:
    import matplotlib.pyplot as plt

    labels = ["avg_reward", "deal_rate", "avg_efficiency", "avg_tom_accuracy"]
    names = ["Random", "Gemini", "GRPO"]
    series = [baseline, gemini, grpo]

    x = range(len(labels))
    width = 0.22

    plt.figure(figsize=(10, 5))
    for idx, name in enumerate(names):
        vals = [float(series[idx].get(k, 0.0)) for k in labels]
        plt.bar([p + (idx - 1) * width for p in x], vals, width=width, label=name)

    plt.xticks(list(x), labels)
    plt.ylabel("Metric value")
    plt.title("Parlay Baseline vs Gemini vs GRPO")
    plt.legend()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    plt.close()


def main() -> None:
    parser = argparse.ArgumentParser(description="Compare evaluation result JSON files")
    parser.add_argument("--baseline-results", required=True)
    parser.add_argument("--gemini-results", required=True)
    parser.add_argument("--grpo-results", required=True)
    args = parser.parse_args()

    baseline = _load(args.baseline_results)
    gemini = _load(args.gemini_results)
    grpo = _load(args.grpo_results)

    lines = [
        "| Model | avg_reward | deal_rate | avg_efficiency | avg_tom_accuracy | bluffs_caught |",
        "|---|---:|---:|---:|---:|---:|",
        _row("Random baseline", baseline),
        _row("Gemini baseline", gemini),
        _row("GRPO", grpo),
    ]
    table = "\n".join(lines)
    print(table)

    chart_path = Path("results/comparison.png")
    _save_chart(baseline, gemini, grpo, chart_path)
    print(f"\nSaved chart: {chart_path.resolve()}")


if __name__ == "__main__":
    main()
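Each of the three inputs is expected to be a flat JSON summary carrying the keys read by `_row` and `_save_chart` (the same keys scripts/run_gemini_baseline.py writes). A sketch that produces a well-formed file, with made-up numbers:

# Sketch: the minimal JSON shape each --*-results file must carry
# (keys read by _row and _save_chart above); the values are made up.
import json
from pathlib import Path

example = {
    "avg_reward": 42.317,
    "deal_rate": 0.65,
    "avg_efficiency": 0.512,
    "avg_tom_accuracy": 0.471,
    "bluffs_caught": 3,
}
Path("results").mkdir(exist_ok=True)
Path("results/example_summary.json").write_text(json.dumps(example, indent=2), encoding="utf-8")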
scripts/inspect_data.py
ADDED

#!/usr/bin/env python3
"""
Pre-training data quality inspector for Parlay JSONL episode files.
Read-only: loads JSONL and prints statistics and RED FLAGS.
"""
from __future__ import annotations

import argparse
import json
import math
import statistics
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any


def _safe_float(x: Any, default: float = 0.0) -> float:
    try:
        return float(x)
    except (TypeError, ValueError):
        return default


def _percentile_sorted(sorted_vals: list[float], p: float) -> float:
    if not sorted_vals:
        return 0.0
    k = (len(sorted_vals) - 1) * p / 100.0
    f = math.floor(k)
    c = math.ceil(k)
    if f == c:
        return sorted_vals[int(k)]
    return sorted_vals[f] * (c - k) + sorted_vals[c] * (k - f)


def _outcome_bucket(rec: dict[str, Any]) -> str:
    tr = (rec.get("termination_reason") or "") or ""
    tr_l = tr.lower()
    if rec.get("deal_reached") is True or tr_l in ("deal_reached", "deal", "agreement"):
        return "deal"
    if "zopa_collapsed" in tr_l or tr_l == "zopa_collapsed":
        return "zopa_collapsed"
    if "walk" in tr_l or tr_l in ("walk_away", "walkaway"):
        return "walk_away"
    if tr_l in ("max_turns",) or (rec.get("deal_reached") is False and "max" in tr_l):
        return "max_turns"
    if tr_l:
        return f"other:{tr_l}"
    if rec.get("deal_reached") is False and rec.get("final_price") is None:
        return "no_deal_or_unknown"
    return "unknown"


def _utterance_lengths(conversation: Any) -> list[int]:
    if not isinstance(conversation, list):
        return []
    out: list[int] = []
    for turn in conversation:
        if not isinstance(turn, dict):
            continue
        content = turn.get("content", "")
        if isinstance(content, str) and content.strip():
            out.append(len(content))
    return out


def main() -> None:
    parser = argparse.ArgumentParser(description="Inspect Parlay episode JSONL data quality")
    parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL file")
    args = parser.parse_args()

    path = Path(args.data)
    if not path.is_file():
        print(f"File not found: {path.resolve()}")
        print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
        return

    records: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"[WARN] Skipping line {line_no}: invalid JSON ({e})")

    n = len(records)
    print(f"=== Parlay data inspector: {path} ===")
    print(f"Total episode records: {n}\n")

    if n == 0:
        print("No records to analyze.")
        return

    # SCHEMA
    missing_prompt = sum(1 for r in records if not str(r.get("prompt", "")).strip())
    missing_scenario = sum(1 for r in records if not str(r.get("scenario_id", "")).strip())
    missing_persona = sum(1 for r in records if not str(r.get("persona", "")).strip())
    missing_metadata = sum(1 for r in records if "metadata" not in r)
    print("SCHEMA")
    print(f" prompt present: {n - missing_prompt}/{n}")
    print(f" scenario_id present: {n - missing_scenario}/{n}")
    print(f" persona present: {n - missing_persona}/{n}")
    print(f" metadata key present: {n - missing_metadata}/{n} (audit checklist; generate_data may omit)")
    print()

    def cum_reward(r: dict[str, Any]) -> float:
        if "cumulative_reward" in r:
            return _safe_float(r.get("cumulative_reward"))
        return _safe_float(r.get("reward"))

    rews = [cum_reward(r) for r in records]
    rews_sorted = sorted(rews)

    print("REWARD (total / cumulative - field 'reward' or 'cumulative_reward')")
    print(f" min: {min(rews):.4f}")
    print(f" max: {max(rews):.4f}")
    print(f" mean: {statistics.mean(rews):.4f}")
    print(f" std: {statistics.stdev(rews) if len(rews) > 1 else 0.0:.4f}")
    print(f" p10: {_percentile_sorted(rews_sorted, 10):.4f}")
    print(f" p90: {_percentile_sorted(rews_sorted, 90):.4f}")
    print()

    outcomes = [_outcome_bucket(r) for r in records]
    oc = Counter(outcomes)
    print("EPISODE OUTCOMES (best-effort from termination_reason + deal_reached)")
    for k, v in sorted(oc.items(), key=lambda x: -x[1]):
        print(f" {k}: {v} ({100.0 * v / n:.1f}%)")
    print()

    effs = [_safe_float(r.get("deal_efficiency"), 0.0) for r in records]
    toms = []
    for r in records:
        t = r.get("tom_accuracy_avg", r.get("tom_accuracy"))
        toms.append(_safe_float(t, 0.0))

    print("EFFICIENCY (deal_efficiency, 0-1)")
    if effs:
        print(f" mean: {statistics.mean(effs):.4f} min: {min(effs):.4f} max: {max(effs):.4f}")
    print()

    print("TOM (tom_accuracy_avg or tom_accuracy)")
    if toms:
        print(f" mean: {statistics.mean(toms):.4f} min: {min(toms):.4f} max: {max(toms):.4f}")
    print()

    all_lens: list[int] = []
    degenerate_turns = 0
    total_turns = 0
    for r in records:
        lens = _utterance_lengths(r.get("conversation"))
        all_lens.extend(lens)
        for L in lens:
            total_turns += 1
            if L < 10:
                degenerate_turns += 1

    print("UTTERANCE LENGTH (conversation[*].content)")
    if all_lens:
        print(f" mean chars/turn: {statistics.mean(all_lens):.1f}")
        print(f" turns < 10 chars: {degenerate_turns}/{total_turns} ({100.0 * degenerate_turns / max(1, total_turns):.1f}%)")
    else:
        print(" (no conversation utterances found)")
    print()

    bluff_pos = sum(1 for r in records if int(r.get("bluffs_caught", 0) or 0) > 0)
    drift_yes = sum(1 for r in records if r.get("drift_adapted") is True)

    print("BLUFF RATE: episodes with bluffs_caught > 0")
    print(f" {bluff_pos}/{n} ({100.0 * bluff_pos / n:.1f}%) (field may be missing in JSONL -> counted as 0)")
    print()
    print("DRIFT ADAPTATION: drift_adapted == True")
    print(f" {drift_yes}/{n} ({100.0 * drift_yes / n:.1f}%)")
    print()

    by_persona: dict[str, list[dict]] = defaultdict(list)
    by_scenario: dict[str, list[dict]] = defaultdict(list)
    for r in records:
        p = str(r.get("persona", "??"))
        s = str(r.get("scenario_id", "??"))
        by_persona[p].append(r)
        by_scenario[s].append(r)

    print("=== PER-PERSONA ===")
    for p in sorted(by_persona.keys()):
        grp = by_persona[p]
        m = len(grp)
        pr = [cum_reward(x) for x in grp]
        pe = [_safe_float(x.get("deal_efficiency"), 0) for x in grp]
        pt = [_safe_float(x.get("tom_accuracy_avg", x.get("tom_accuracy")), 0) for x in grp]
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / m
        print(
            f" {p}: n={m} mean_reward={statistics.mean(pr) if pr else 0:.2f} "
            f"mean_eff={statistics.mean(pe) if pe else 0:.3f} mean_tom={statistics.mean(pt) if pt else 0:.3f} deal_rate={dr:.2%}"
        )
    print()

    print("=== PER-SCENARIO ===")
    for s in sorted(by_scenario.keys()):
        grp = by_scenario[s]
        m = len(grp)
        pr = [cum_reward(x) for x in grp]
        pe = [_safe_float(x.get("deal_efficiency"), 0) for x in grp]
        pt = [_safe_float(x.get("tom_accuracy_avg", x.get("tom_accuracy")), 0) for x in grp]
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / m
        print(
            f" {s}: n={m} mean_reward={statistics.mean(pr) if pr else 0:.2f} "
            f"mean_eff={statistics.mean(pe) if pe else 0:.3f} mean_tom={statistics.mean(pt) if pt else 0:.3f} deal_rate={dr:.2%}"
        )
    print()

    # RED FLAGS
    print("=== RED FLAGS ===")
    flags: list[str] = []

    bad_rew = sum(1 for x in rews if x < -50) / n
    if bad_rew > 0.30:
        flags.append(f"> 30% episodes with total reward < -50 ({100 * bad_rew:.1f}%)")

    max_turns_rate = sum(1 for o in outcomes if o == "max_turns") / n
    if max_turns_rate > 0.40:
        flags.append(f"> 40% ending in max_turns ({100 * max_turns_rate:.1f}%)")

    drift_rate = drift_yes / n
    if drift_rate < 0.10:
        flags.append(f"< 10% drift_adapted ({100 * drift_rate:.1f}%)")

    for p, grp in by_persona.items():
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / len(grp) if grp else 0.0
        if dr == 0.0 and len(grp) >= 3:
            flags.append(f"Persona {p!r} has 0% deal rate (n={len(grp)})")

    for s, grp in by_scenario.items():
        po = [_outcome_bucket(x) for x in grp]
        dr = sum(1 for f in po if f == "deal") / len(grp) if grp else 0.0
        if dr == 0.0 and len(grp) >= 3:
            flags.append(f"Scenario {s!r} has 0% deal rate (n={len(grp)})")

    if all_lens and statistics.mean(all_lens) < 20.0:
        flags.append(f"Mean utterance length {statistics.mean(all_lens):.1f} chars < 20 (possibly degenerate)")

    if max(rews) > 400:
        flags.append(f"At least one episode with total reward > 400 (max={max(rews):.2f}) - check for scale bugs or rare combo")

    if missing_metadata == n:
        flags.append("No record has a top-level 'metadata' key (optional for training; audit asked for it)")

    if not flags:
        print(" (none triggered)")
    else:
        for f in flags:
            print(f" * {f}")
    print()


if __name__ == "__main__":
    main()
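The `_percentile_sorted` helper above interpolates linearly between order statistics. A quick worked check of the same formula (values chosen for illustration):

# Worked check of _percentile_sorted's interpolation: for vals=[1,2,3,4], p=10,
# k = (4-1)*10/100 = 0.3, so result = vals[0]*0.7 + vals[1]*0.3 = 1.3.
import math

vals = [1.0, 2.0, 3.0, 4.0]
k = (len(vals) - 1) * 10.0 / 100.0   # p = 10 -> k = 0.3
f, c = math.floor(k), math.ceil(k)   # f = 0, c = 1
result = vals[f] * (c - k) + vals[c] * (k - f)
assert abs(result - 1.3) < 1e-9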
scripts/push_docker.sh
ADDED

#!/bin/bash
# Usage: ./scripts/push_docker.sh <dockerhub-username> <tag>
USERNAME=${1:-yourusername}
TAG=${2:-latest}
docker build -t $USERNAME/parlay:$TAG .
docker push $USERNAME/parlay:$TAG
echo "Pushed $USERNAME/parlay:$TAG"
echo "For HF Spaces: set Dockerfile app_port to 7860 and push repo to huggingface.co/spaces/$USERNAME/parlay"
scripts/run_gemini_baseline.py
ADDED
@@ -0,0 +1,74 @@
+"""Run Gemini self-play baseline and save summary JSON."""
+import argparse
+import asyncio
+import json
+import logging
+from pathlib import Path
+
+from agent.runner import run_episode
+from parlay_env.models import PersonaType
+
+PERSONAS = [PersonaType.SHARK, PersonaType.DIPLOMAT, PersonaType.VETERAN]
+SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
+
+
+def _mean(values: list[float]) -> float:
+    return sum(values) / len(values) if values else 0.0
+
+
+async def _run(episodes: int) -> list[dict]:
+    rows: list[dict] = []
+    for i in range(episodes):
+        persona = PERSONAS[i % len(PERSONAS)]
+        scenario_id = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
+        result = await run_episode(
+            persona=persona,
+            scenario_id=scenario_id,
+            inject_noise=False,
+            force_drift=True,
+            seed=i + 100,
+            max_turns=20,
+        )
+        rows.append(
+            {
+                "avg_reward": float(result.grade.total_reward),
+                "deal_rate": 1.0 if result.final_price is not None else 0.0,
+                "avg_efficiency": float(result.grade.deal_efficiency),
+                "avg_tom_accuracy": float(result.grade.tom_accuracy_avg),
+                "bluffs_caught": int(result.grade.bluffs_caught),
+            }
+        )
+    return rows
+
+
+def _summarise(rows: list[dict], episodes_requested: int) -> dict:
+    return {
+        "episodes_requested": episodes_requested,
+        "episodes_completed": len(rows),
+        "avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4),
+        "deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4),
+        "avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4),
+        "avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4),
+        "bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)),
+    }
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run Gemini self-play baseline")
+    parser.add_argument("--episodes", type=int, default=20)
+    parser.add_argument("--output", default="results/gemini_baseline.json")
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
+    rows = asyncio.run(_run(args.episodes))
+    summary = _summarise(rows, args.episodes)
+
+    out = Path(args.output)
+    out.parent.mkdir(parents=True, exist_ok=True)
+    out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+    print(json.dumps(summary, indent=2))
+    print(f"\nSaved Gemini baseline to {out.resolve()}")
+
+
+if __name__ == "__main__":
+    main()
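The JSON written to results/gemini_baseline.json follows the _summarise schema above; a hypothetical 9-episode run (all values below are illustrative, only the keys come from the script) might look like:

    {
      "episodes_requested": 9,
      "episodes_completed": 9,
      "avg_reward": 14.52,
      "deal_rate": 0.6667,
      "avg_efficiency": 0.41,
      "avg_tom_accuracy": 0.58,
      "bluffs_caught": 3
    }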
scripts/validate_sft_data.py
ADDED
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+"""
+Validate JSONL rows against what training/sft_train.py will actually use.
+Read-only; does not run training.
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import statistics
+import sys
+from pathlib import Path
+
+# Mirror training/sft_train._extract_completions (do not import to avoid heavy deps at import time)
+def _extract_completions(rec: dict) -> list[str]:
+    completion = rec.get("completion")
+    if isinstance(completion, str) and completion.strip():
+        return [completion.strip()]
+
+    conversation = rec.get("conversation", [])
+    candidates: list[str] = []
+    if isinstance(conversation, list):
+        for turn in conversation:
+            if not isinstance(turn, dict):
+                continue
+            role = str(turn.get("role", "")).lower()
+            content = str(turn.get("content", "")).strip()
+            if role == "negotiator" and content:
+                candidates.append(content)
+    return candidates
+
+
+def _approx_tokens(text: str) -> float:
+    """Rough token estimate without tokenizer (good enough for preflight OOM risk)."""
+    if not text:
+        return 0.0
+    return len(text) / 4.0
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Validate SFT JSONL against sft_train.py expectations")
+    parser.add_argument("--data", type=str, default="data/episodes.jsonl", help="Path to JSONL file")
+    args = parser.parse_args()
+
+    path = Path(args.data)
+    if not path.is_file():
+        print(f"File not found: {path.resolve()}")
+        print("Run: python -m training.generate_data --episodes 80 --output data/episodes.jsonl")
+        return
+
+    usable_rows = 0
+    skipped = 0
+    prompt_tok: list[float] = []
+    completion_tok: list[float] = []
+    first_bad_line: int | None = None
+
+    with path.open("r", encoding="utf-8") as f:
+        for line_no, line in enumerate(f, 1):
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                rec = json.loads(line)
+            except json.JSONDecodeError:
+                skipped += 1
+                if first_bad_line is None:
+                    first_bad_line = line_no
+                continue
+            if not isinstance(rec, dict):
+                skipped += 1
+                continue
+
+            prompt = str(rec.get("prompt", "")).strip()
+            if not prompt:
+                skipped += 1
+                continue
+
+            completions = _extract_completions(rec)
+            if not completions:
+                skipped += 1
+                continue
+
+            usable_rows += 1
+            p_t = _approx_tokens(prompt)
+            prompt_tok.append(p_t)
+            for c in completions:
+                completion_tok.append(_approx_tokens(c))
+
+    sft_trains_one_row_per_completion = "sft_train.py expands one dataset row per negotiator line"
+    print("SFT data validator (vs training/sft_train.py load_sft_dataset / _extract_completions)")
+    print(f"  File: {path.resolve()}")
+    print(f"  Note: {sft_trains_one_row_per_completion} when 'completion' is absent.")
+    print()
+    print(f"  JSONL records usable (has non-empty 'prompt' and completion or negotiator text): {usable_rows}")
+    print(f"  Records / rows SKIPPED: {skipped}")
+    if first_bad_line is not None:
+        print(f"  (includes malformed JSONL starting around line {first_bad_line} if any)")
+    print()
+
+    if not prompt_tok and not completion_tok:
+        print("No prompt/completion lengths to summarize (all skipped).")
+    else:
+        def _summary(vals: list[float], label: str) -> None:
+            if not vals:
+                print(f"  {label}: (empty)")
+                return
+            print(
+                f"  {label} (approx. tokens, len/4): "
+                f"min={min(vals):.1f} max={max(vals):.1f} mean={statistics.mean(vals):.1f} "
+                f"std={(statistics.pstdev(vals) if len(vals) > 1 else 0.0):.1f}"
+            )
+
+        _summary(prompt_tok, "Prompt length")
+        _summary(completion_tok, "Completion length (each negotiator / completion string)")
+        if prompt_tok and statistics.mean(prompt_tok) > 2048.0:
+            print(
+                "  FLAG: Mean prompt length > 2048 (approx. tokens) - may OOM or truncate with "
+                "SFTConfig max_seq_length=2048 in sft_train.py on small GPUs."
+            )
+
+    print()
+    print(f"  {usable_rows} records usable for SFT, {skipped} will be skipped (at record level; negotiator")
+    print("  expansion in sft_train can still multiply rows for usable records).")
+    if usable_rows < 50:
+        print("  WARNING: May be insufficient for SFT. Generate more data first.")
+
+    if usable_rows == 0:
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
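A typical preflight before training (the path is the script's default):

    python scripts/validate_sft_data.py --data data/episodes.jsonl

The len/4 estimate assumes roughly four characters per token, a common rule of thumb for English text; it is deliberately tokenizer-free so the validator stays dependency-light.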
tests/test_keyless.py
CHANGED
@@ -26,7 +26,7 @@ from parlay_env.game_theory import (
 )
 from parlay_env.grader import compute_step_reward, compute_terminal_reward
 from parlay_env.grader import detect_bluff_challenge
-from parlay_env.reward import OMEGA, PSI
+from parlay_env.reward import MU, OMEGA, PSI
 from parlay_env.models import (
     BeliefState, HiddenState, ParlayAction, ParlayState, PersonaType,
 )
@@ -200,6 +200,25 @@ class TestGrader:
         assert caught is True, f"Expected True, got {caught}"
         assert reward >= PSI, f"Expected at least {PSI}, got {reward}"
 
+    def test_mev_bonus_requires_drift_event(self, parlay_state):
+        """MEV bonus only activates when a drift event marker is present."""
+        action = ParlayAction(
+            utterance="Given that the market shifted, I can adjust.",
+            tactical_move=None,
+        )
+
+        no_drift_state = ParlayState(**{**parlay_state.model_dump(), "step_count": 1})
+        reward_no_drift = compute_step_reward(parlay_state, action, no_drift_state)
+
+        with_drift_state = ParlayState(**{**parlay_state.model_dump(), "step_count": 1})
+        with_drift_state.__dict__["drift_event"] = "Competitor drops price 15%"
+        reward_with_drift = compute_step_reward(parlay_state, action, with_drift_state)
+
+        assert reward_with_drift >= reward_no_drift + MU, (
+            f"Expected drift reward boost >= {MU}, got "
+            f"{reward_with_drift - reward_no_drift}"
+        )
+
     def test_zopa_collapse_walkaway_keyless(self):
         """Repeated high tension collapses the ZOPA and forces walk-away."""
         hidden = HiddenState(
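The new test plants drift_event through __dict__ rather than normal attribute assignment; assuming ParlayState is a pydantic model (it exposes model_dump()), this stashes an ad-hoc key on the instance without triggering field validation, and plain attribute lookup still finds it. A minimal sketch of the trick, with an illustrative class not taken from parlay_env:

    from pydantic import BaseModel

    class State(BaseModel):
        step_count: int

    s = State(step_count=1)
    s.__dict__["drift_event"] = "Competitor drops price 15%"  # bypasses validation
    print(getattr(s, "drift_event", None))  # reads the planted marker back

To run just this test: pytest tests/test_keyless.py -k mev_bonus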
training/generate_data.py
CHANGED
@@ -33,6 +33,32 @@ REQUIRED_COMBINATIONS = [
     for scenario in ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
 ]
 
+# Weighted to oversample historically low deal-rate combinations (total weight = 15)
+COMBO_WEIGHTS: dict[tuple[str, str], int] = {
+    ("veteran", "hiring_package"): 3,
+    ("veteran", "saas_enterprise"): 2,
+    ("veteran", "acquisition_term_sheet"): 2,
+    ("shark", "hiring_package"): 2,
+    ("diplomat", "hiring_package"): 2,
+    ("shark", "saas_enterprise"): 1,
+    ("shark", "acquisition_term_sheet"): 1,
+    ("diplomat", "saas_enterprise"): 1,
+    ("diplomat", "acquisition_term_sheet"): 1,
+}
+WEIGHTED_COMBO_LIST: list[tuple[str, str]] = []
+for _pair, _weight in COMBO_WEIGHTS.items():
+    WEIGHTED_COMBO_LIST.extend([_pair] * _weight)
+
+
+def _row_total_reward(record: dict) -> float | None:
+    v = record.get("reward")
+    if v is not None:
+        return float(v)
+    v2 = record.get("cumulative_reward")
+    if v2 is not None:
+        return float(v2)
+    return None
+
 
 def is_quality_episode(grade, args) -> tuple[bool, str]:
     """
@@ -308,7 +334,8 @@ async def run_inspect_mode(args) -> None:
 
 async def run_diversity_pass(args, output_path: Path) -> None:
     """
-    Generate a quality-filtered dataset
+    Generate a quality-filtered dataset; persona x scenario is weighted-sampled
+    (see COMBO_WEIGHTS / WEIGHTED_COMBO_LIST).
     """
     output_path.parent.mkdir(parents=True, exist_ok=True)
    coverage: dict[tuple[str, str], int] = defaultdict(int)
@@ -316,169 +343,111 @@ async def run_diversity_pass(args, output_path: Path) -> None:
     kept_records: list[dict] = []
     generated = 0
     discarded = 0
+    skipped_min_reward = 0
     seed = 0
-    min_per_combo = max(2, args.episodes // len(REQUIRED_COMBINATIONS))
     total_live_calls: int = 0
     total_fallback_calls: int = 0
     _verbose = not getattr(args, "quiet", False)
+    _checkpoints = {20, 40, 60, 80, 100, 120, 140}
+
+    def _emit_checkpoint(_ep_num: int) -> None:
+        if not _verbose or _ep_num not in _checkpoints:
+            return
+        _all_rewards = [r.get("reward", 0.0) for r in kept_records]
+        _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
+        _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
+        print(f"\n{'━' * 40}", file=sys.stderr)
+        print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
+        print(
+            f"  Kept so far     : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
+            file=sys.stderr,
+        )
+        print(f"  Mean reward     : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
+        print(f"  Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
+        print(f"  Combos covered  : {_combos_covered}/9", file=sys.stderr)
+        print(f"  Min-reward skip : {skipped_min_reward}", file=sys.stderr)
+        print(f"  Live calls total: {total_live_calls}", file=sys.stderr)
+        print(f"  Fallback total  : {total_fallback_calls}", file=sys.stderr)
+        print(f"{'━' * 40}\n", file=sys.stderr)
 
     with open(output_path, "w", encoding="utf-8") as out_f:
         while len(kept_records) < args.episodes:
+            persona, scenario_id = random.choice(WEIGHTED_COMBO_LIST)
+            record = await _run_one(persona, scenario_id, seed=seed, max_turns=args.max_turns)
+            seed += 1
+            generated += 1
+            if record is None:
+                _live_n, _fall_n = get_and_reset_counts()
+                total_live_calls += _live_n
+                total_fallback_calls += _fall_n
+                continue
+
+            rw = _row_total_reward(record)
+            if rw is not None and rw < args.min_reward:
+                skipped_min_reward += 1
+                _live_m, _fall_m = get_and_reset_counts()
+                total_live_calls += _live_m
+                total_fallback_calls += _fall_m
+                if _verbose:
+                    print(
+                        f"[min_reward skip #{skipped_min_reward}] {persona} x {scenario_id} "
+                        f"reward={rw:.2f} < {args.min_reward}",
+                        file=sys.stderr,
+                    )
+                continue
+
             keep, reason = is_quality_episode(
                 _grade_proxy_from_record(record),
                 args,
             )
             if not keep:
                 discarded += 1
                 _live_d, _fall_d = get_and_reset_counts()
                 total_live_calls += _live_d
                 total_fallback_calls += _fall_d
                 if _verbose:
                     print(
                         f"[EP --/{args.episodes:02d}] "
                         f"{persona}×{scenario_id:<27s} | "
                         f"reward={record.get('reward', 0.0):+.2f} | "
                         f"eff={record.get('deal_efficiency', 0.0):.3f} | "
                         f"kept=NO | "
                         f"total_kept={len(kept_records)}/{generated} | "
                         f"gemini_live={_live_d} fallback={_fall_d}",
                         file=sys.stderr,
                     )
                 continue
 
             out_f.write(json.dumps(record, ensure_ascii=False) + "\n")
+            out_f.flush()
             kept_records.append(record)
             _live, _fall = get_and_reset_counts()
             total_live_calls += _live
             total_fallback_calls += _fall
             _ep_num = len(kept_records)
             if _verbose:
                 _reward = record.get("reward", 0.0)
                 _eff = record.get("deal_efficiency", 0.0)
                 _combo = f"{record['persona']}×{record['scenario_id']}"
                 print(
                     f"[EP {_ep_num:02d}/{args.episodes:02d}] "
                     f"{_combo:<35s} | "
                     f"reward={_reward:+.2f} | "
                     f"eff={_eff:.3f} | "
                     f"kept=YES | "
                     f"total_kept={_ep_num}/{generated} | "
                     f"gemini_live={_live} fallback={_fall}",
                     file=sys.stderr,
                 )
-            if _ep_num in (20, 40, 60):
-                _all_rewards = [r.get("reward", 0.0) for r in kept_records]
-                _all_eff = [r.get("deal_efficiency", 0.0) for r in kept_records]
-                _combos_covered = len({(r["persona"], r["scenario_id"]) for r in kept_records})
-                print(f"\n{'━' * 40}", file=sys.stderr)
-                print(f"[CHECKPOINT {_ep_num}/{args.episodes}]", file=sys.stderr)
-                print(
-                    f"  Kept so far     : {_ep_num}/{generated} ({100 * _ep_num / max(generated, 1):.1f}%)",
-                    file=sys.stderr,
-                )
-                print(f"  Mean reward     : {statistics.mean(_all_rewards):.2f}", file=sys.stderr)
-                print(f"  Mean efficiency : {statistics.mean(_all_eff):.3f}", file=sys.stderr)
-                print(f"  Combos covered  : {_combos_covered}/9", file=sys.stderr)
-                print(f"  Live calls total: {total_live_calls}", file=sys.stderr)
-                print(f"  Fallback total  : {total_fallback_calls}", file=sys.stderr)
-                print(f"{'━' * 40}\n", file=sys.stderr)
+            _emit_checkpoint(_ep_num)
             coverage[(persona, scenario_id)] += 1
             kept_reason_counts[reason] += 1
 
     discard_pct = (discarded / max(generated, 1)) * 100.0
     print(
         f"Generated: {generated} episodes | Kept: {len(kept_records)} | "
-        f"Discarded: {discarded} ({discard_pct:.0f}%)"
+        f"Discarded: {discarded} ({discard_pct:.0f}%) | "
+        f"Skipped (min_reward < {args.min_reward}): {skipped_min_reward}"
     )
     reasons_str = ", ".join(f"{reason}={count}" for reason, count in sorted(kept_reason_counts.items()))
     print(f"Reasons kept: {reasons_str or 'none'}")
@@ -488,9 +457,9 @@ async def run_diversity_pass(args, output_path: Path) -> None:
 
     _fallback_rate = 100.0 * total_fallback_calls / max(total_live_calls + total_fallback_calls, 1)
     _verdict = (
-        "ALL CALLS LIVE
+        "ALL CALLS LIVE - data is real"
         if _fallback_rate < 5.0
-        else "WARNING: fallback rate high
+        else "WARNING: fallback rate high - check API key and rate limits"
     )
     print(f"\nGemini API health:")
     print(f"  Total live calls : {total_live_calls}")
@@ -501,8 +470,14 @@ async def run_diversity_pass(args, output_path: Path) -> None:
 
 def main() -> None:
     parser = argparse.ArgumentParser(description="Generate Parlay training data")
-    parser.add_argument("--episodes", type=int, default=
+    parser.add_argument("--episodes", type=int, default=140)
     parser.add_argument("--output", type=str, default="data/episodes.jsonl")
+    parser.add_argument(
+        "--min-reward",
+        type=float,
+        default=-50.0,
+        help="After grading, do not write episodes with total reward below this (default: -50.0)",
+    )
     parser.add_argument(
         "--quality_filter",
         action="store_true",
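The WEIGHTED_COMBO_LIST expansion makes a plain random.choice draw each persona x scenario pair with probability weight/15. A self-contained sketch of the same idiom, with the weight table trimmed to two pairs for brevity:

    import random
    from collections import Counter

    COMBO_WEIGHTS = {
        ("veteran", "hiring_package"): 3,
        ("shark", "saas_enterprise"): 1,
    }
    # Expand each pair `weight` times so uniform choice becomes weighted choice.
    WEIGHTED_COMBO_LIST = []
    for pair, weight in COMBO_WEIGHTS.items():
        WEIGHTED_COMBO_LIST.extend([pair] * weight)

    draws = Counter(random.choice(WEIGHTED_COMBO_LIST) for _ in range(10_000))
    print(draws)  # the weight-3 pair shows up roughly 3x as often as the weight-1 pair

    # Equivalent without the expanded list:
    pair = random.choices(list(COMBO_WEIGHTS), weights=list(COMBO_WEIGHTS.values()), k=1)[0]

The expanded-list form trades a little memory for a hot loop that is just random.choice, which is a reasonable design when the weight table is small and static.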
training/grpo_train.py
CHANGED
@@ -18,18 +18,32 @@ from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
+# SFT->GRPO pipeline: set BASE_MODEL=checkpoints/sft_1.5b/ after sft_train.py
+# (overridable via BASE_MODEL env var)
+BASE_MODEL = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
 GRPO_STEPS = int(os.getenv("GRPO_STEPS", "500"))
 GRPO_GENERATIONS = int(os.getenv("GRPO_GENERATIONS", "8"))
 
 
-def build_grpo_dataset(jsonl_path: str):
+def _row_total_reward(rec: dict) -> float | None:
+    v = rec.get("reward")
+    if v is not None:
+        return float(v)
+    v2 = rec.get("cumulative_reward")
+    if v2 is not None:
+        return float(v2)
+    return None
+
+
+def build_grpo_dataset(jsonl_path: str, min_reward: float = -50.0):
     """
     Build GRPO dataset. Each record needs only a 'prompt' field plus metadata.
     The model generates G=8 completions per prompt; grader scores all 8.
 
     Args:
         jsonl_path: Path to the JSONL episodes file.
+        min_reward: Drop train rows with total reward (reward / cumulative_reward) below this
+            (missing reward fields are kept for backward compatibility).
 
     Returns:
         HuggingFace Dataset with prompt + metadata columns.
@@ -40,21 +54,31 @@ def build_grpo_dataset(jsonl_path: str, min_reward: float = -50.0):
         raise ImportError("Install datasets: pip install datasets") from exc
 
     prompts = []
+    filtered = 0
     with open(jsonl_path, encoding="utf-8") as f:
         for line in f:
             rec = json.loads(line.strip())
-            if rec.get("split") =
+            if rec.get("split") != "train":
+                continue
+            r = _row_total_reward(rec)
+            if r is not None and r < min_reward:
+                filtered += 1
+                continue
+            # Extract ZOPA metadata for reward functions
+            prompts.append(
+                {
                     "prompt": rec["prompt"],
                     "scenario_id": rec.get("scenario_id", ""),
                     "persona": rec.get("persona", ""),
-                    # Reward function kwargs (passed through dataset)
                     "batna_seller": _get_batna(rec.get("scenario_id", ""), "seller"),
-                    "batna_buyer":
-                    "zopa_width":
-                }
+                    "batna_buyer": _get_batna(rec.get("scenario_id", ""), "buyer"),
+                    "zopa_width": _get_zopa_width(rec.get("scenario_id", "")),
+                }
+            )
+    print(
+        f"Filtered {filtered} records below min_reward={min_reward}, "
+        f"{len(prompts)} remaining for GRPO"
+    )
     return Dataset.from_list(prompts)
 
 
@@ -62,7 +86,7 @@ def _get_batna(scenario_id: str, side: str) -> float:
     """Lookup BATNA for a scenario without importing game module at training time."""
     batnas: dict[str, dict[str, float]] = {
         "saas_enterprise": {"seller": 125_000, "buyer": 165_000},
-        "hiring_package": {"seller": 195_000, "buyer":
+        "hiring_package": {"seller": 195_000, "buyer": 264_500},  # match game/scenarios (widened zopa)
         "acquisition_term_sheet": {"seller": 10_500_000, "buyer": 16_000_000},
     }
     return float(batnas.get(scenario_id, {}).get(side, 0))
@@ -80,6 +104,7 @@ def train_grpo(
     data_path: str,
     output_dir: str,
     steps: int = 500,
+    min_reward: float = -50.0,
 ) -> None:
     """
     GRPO training loop.
@@ -117,7 +142,7 @@ def train_grpo(
         format_reward,
     )
 
-    dataset = build_grpo_dataset(data_path)
+    dataset = build_grpo_dataset(data_path, min_reward=min_reward)
     if len(dataset) == 0:
         raise ValueError("Empty GRPO dataset. Run generate_data.py first.")
 
@@ -181,6 +206,12 @@ def main() -> None:
     parser.add_argument("--model", default="models/parlay-sft")
     parser.add_argument("--base_model", default="")
     parser.add_argument("--data", default="data/episodes.jsonl")
+    parser.add_argument(
+        "--min-reward",
+        type=float,
+        default=-50.0,
+        help="Skip JSONL train rows with total reward below this (default: -50.0)",
+    )
     parser.add_argument("--output", default="models/parlay-grpo")
     parser.add_argument("--steps", type=int, default=GRPO_STEPS)
     parser.add_argument("--g", type=int, default=GRPO_GENERATIONS)
@@ -191,7 +222,7 @@ def main() -> None:
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
     GRPO_GENERATIONS = args.g
     model_path = args.base_model or args.model
-    train_grpo(model_path, args.data, args.output, args.steps)
+    train_grpo(model_path, args.data, args.output, args.steps, min_reward=args.min_reward)
 
     if args.save_curves:
         curves_path = Path(args.save_curves)
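The reward floor keys off two fields with a fallback order: "reward" wins, "cumulative_reward" is consulted next, and rows carrying neither are kept. A minimal standalone sketch of the same logic on toy rows, with the default floor of -50.0:

    # Mirrors _row_total_reward from the diff; rows without either field survive
    # the filter for backward compatibility with older episode files.
    def row_total_reward(rec: dict) -> float | None:
        if rec.get("reward") is not None:
            return float(rec["reward"])
        if rec.get("cumulative_reward") is not None:
            return float(rec["cumulative_reward"])
        return None

    rows = [{"reward": -80.0}, {"cumulative_reward": 12.5}, {}]
    kept = [r for r in rows if (v := row_total_reward(r)) is None or v >= -50.0]
    print(len(kept))  # 2 -- only the -80.0 row falls below the floor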
training/random_baseline.py
CHANGED
@@ -1,125 +1,138 @@
-"""
-Random-policy baseline for Parlay.
-Runs N episodes with purely random move selection (no Gemini API — always
-uses mock mode) and writes a summary JSON that the training notebook uses
-to benchmark SFT / GRPO improvement.
-
-Usage:
-    python training/random_baseline.py
-    python training/random_baseline.py --episodes 20 --output data/random_baseline.json
-"""
+"""Random-action baseline for Parlay."""
 import argparse
 import asyncio
 import json
 import logging
-import os
 import random
-import statistics
-import sys
 from pathlib import Path
 
-# Force mock mode — random baseline never calls the real Gemini API
-os.environ.pop("GOOGLE_API_KEY", None)
-
-from agent.runner import run_episode
-from game.scenarios import SCENARIOS
-from parlay_env.models import PersonaType
+from parlay_env.grader import grade_episode
+from parlay_env.models import TacticalMove
+from parlay_env.server import _handle_reset, _handle_step, get_session_state
 
 logger = logging.getLogger(__name__)
 
+PERSONAS = ["shark", "diplomat", "veteran"]
+SCENARIOS = ["saas_enterprise", "hiring_package", "acquisition_term_sheet"]
+RANDOM_LINES = [
+    "Let's keep talking.",
+    "I can move a bit.",
+    "This is my proposal.",
+    "We should find middle ground.",
+    "Given that context, here's my number.",
 ]
 
 
+def _mean(values: list[float]) -> float:
+    return sum(values) / len(values) if values else 0.0
+
+
+async def _run_single_episode(scenario_id: str, persona: str, seed: int) -> dict:
+    random.seed(seed)
+    reset = await _handle_reset({"scenario_id": scenario_id, "persona": persona, "seed": seed})
+    session_id = str(reset["session_id"])
+    final_price = None
+    t_close = None
+    done = False
+
+    while not done:
+        state = get_session_state(session_id)
+        if state is None:
+            break
+        if state.episode_done:
+            break
+
+        low = state.hidden_state.walk_away_price
+        high = state.hidden_state.budget_ceiling
+        offer = round(random.uniform(low, high), 2)
+
+        moves: list[TacticalMove | None] = [None]
+        if state.credibility_points >= 0:
+            moves.append(TacticalMove.ANCHOR_HIGH)
+        if state.credibility_points >= 5:
+            moves.append(TacticalMove.SILENCE)
+        if state.credibility_points >= 20:
+            moves.append(TacticalMove.BATNA_REVEAL)
+        move = random.choice(moves)
+
+        payload = {
+            "session_id": session_id,
+            "action": {
+                "utterance": random.choice(RANDOM_LINES),
+                "offer_amount": offer,
+                "tactical_move": move.value if move else None,
+            },
+        }
+        step = await _handle_step(payload)
+        done = bool(step.get("done", False))
+
+        state = get_session_state(session_id)
+        if state and state.deal_reached and final_price is None:
+            final_price = offer
+            t_close = state.step_count
+
+    state = get_session_state(session_id)
+    if state is None:
+        raise RuntimeError(f"Missing session state for {session_id}")
+    grade = grade_episode(state, final_price=final_price, t_close=t_close, t_max=20)
+    return {
+        "avg_reward": float(grade.total_reward),
+        "deal_rate": 1.0 if final_price is not None else 0.0,
+        "avg_efficiency": float(grade.deal_efficiency),
+        "avg_tom_accuracy": float(grade.tom_accuracy_avg),
+        "bluffs_caught": int(grade.bluffs_caught),
+    }
+
+
 async def _run_baseline(episodes: int) -> list[dict]:
-    results = []
-    seed = 0
+    rows: list[dict] = []
     for i in range(episodes):
+        persona = PERSONAS[i % len(PERSONAS)]
+        scenario = SCENARIOS[(i // len(PERSONAS)) % len(SCENARIOS)]
         try:
-            result = await run_episode(
-                persona=PersonaType(persona_str),
-                scenario_id=scenario_id,
-                inject_noise=True,  # random noise simulates random policy
-                force_drift=random.random() < 0.4,
-                seed=seed,
-                max_turns=14,
-            )
-            results.append({
-                "persona": persona_str,
-                "scenario_id": scenario_id,
-                "reward": result.grade.total_reward,
-                "deal_efficiency": result.grade.deal_efficiency,
-                "deal_reached": result.final_price is not None,
-                "tom_accuracy_avg": result.grade.tom_accuracy_avg,
-                "drift_adapted": result.grade.drift_adapted,
-                "termination_reason": result.grade.termination_reason,
-            })
+            rows.append(await _run_single_episode(scenario, persona, i + 7))
         except Exception as exc:
-            logger.warning("Baseline episode %d failed (%s
+            logger.warning("Baseline episode %d failed (%s/%s): %s", i + 1, scenario, persona, exc)
+    return rows
+
+
+def _summarise(rows: list[dict], episodes_requested: int) -> dict:
+    if not rows:
+        return {
+            "episodes_requested": episodes_requested,
+            "episodes_completed": 0,
+            "avg_reward": 0.0,
+            "deal_rate": 0.0,
+            "avg_efficiency": 0.0,
+            "avg_tom_accuracy": 0.0,
+            "bluffs_caught": 0,
+        }
     return {
-        "drift_adapted_rate": round(drift_count / len(results), 4),
-        "policy": "random_mock",
-        "note": (
-            "Baseline uses Parlay mock responses (no real Gemini API). "
-            "Compare mean_reward and mean_efficiency against SFT/GRPO runs."
-        ),
+        "episodes_requested": episodes_requested,
+        "episodes_completed": len(rows),
+        "avg_reward": round(_mean([r["avg_reward"] for r in rows]), 4),
+        "deal_rate": round(_mean([r["deal_rate"] for r in rows]), 4),
+        "avg_efficiency": round(_mean([r["avg_efficiency"] for r in rows]), 4),
+        "avg_tom_accuracy": round(_mean([r["avg_tom_accuracy"] for r in rows]), 4),
+        "bluffs_caught": int(sum(r["bluffs_caught"] for r in rows)),
    }
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Parlay random
-    parser.add_argument("--episodes", type=int, default=
-    parser.add_argument("--output", type=str, default="data/random_baseline.json",
-                        help="Output path for the baseline JSON summary")
+    parser = argparse.ArgumentParser(description="Parlay random baseline")
+    parser.add_argument("--episodes", type=int, default=20)
+    parser.add_argument("--output", default="results/baseline.json")
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
-    print(f"Running {args.episodes} random-policy episodes (mock mode, no API key)…")
-    results = asyncio.run(_run_baseline(args.episodes))
-
-    summary = _summarise(results)
+    rows = asyncio.run(_run_baseline(args.episodes))
+    summary = _summarise(rows, args.episodes)
 
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    print(f"\nBaseline complete ({summary['n_episodes']} episodes):")
-    print(f"  Mean reward     : {summary.get('mean_reward', 'N/A')}")
-    print(f"  Mean efficiency : {summary.get('mean_efficiency', 'N/A')}")
-    print(f"  Deal rate       : {summary.get('deal_rate', 'N/A'):.1%}")
-    print(f"  Written to      : {out_path.resolve()}")
+    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+    print(json.dumps(summary, indent=2))
+    print(f"\nSaved random baseline to {out_path.resolve()}")
 
 
 if __name__ == "__main__":
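Note how the random policy's move menu widens with credibility_points: ANCHOR_HIGH is always available at >= 0, SILENCE unlocks at 5, and BATNA_REVEAL at 20, so even the baseline respects the environment's credibility budget. It runs keylessly against the in-process server handlers:

    python training/random_baseline.py --episodes 20 --output results/baseline.json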
training/sft_train.py
CHANGED
@@ -1,177 +1,119 @@
 """
-Applies episode quality filters (offers + reward outliers) and stable SFT target
-metadata (log-scaled efficiency, clipped reward) when building training text.
-
-Usage:
-    python -m training.sft_train \
-        --data data/episodes.jsonl \
-        --model Qwen/Qwen2.5-7B-Instruct \
-        --output models/parlay-sft \
-        --threshold 0.30
+Run before grpo_train.py for SFT→GRPO pipeline. Pass checkpoint path as
+BASE_MODEL env var to grpo_train.py.
 """
 import argparse
 import json
 import logging
-import os
 from pathlib import Path
 
-from .episode_filters import (
-    SFTFilterConfig,
-    clip_reward_for_label,
-    efficiency_sft_label,
-    episode_passes_sft_filters,
-)
-
 logger = logging.getLogger(__name__)
 
+DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
+DEFAULT_OUTPUT = "checkpoints/sft_1.5b/"
+
+
+def _extract_completions(rec: dict) -> list[str]:
+    """Return candidate completion texts from a record."""
+    completion = rec.get("completion")
+    if isinstance(completion, str) and completion.strip():
+        return [completion.strip()]
+
+    conversation = rec.get("conversation", [])
+    candidates: list[str] = []
+    if isinstance(conversation, list):
+        for turn in conversation:
+            if not isinstance(turn, dict):
+                continue
+            role = str(turn.get("role", "")).lower()
+            content = str(turn.get("content", "")).strip()
+            if role == "negotiator" and content:
+                candidates.append(content)
+    return candidates
+
+
+def _row_total_reward(rec: dict) -> float | None:
+    v = rec.get("reward")
+    if v is not None:
+        return float(v)
+    v2 = rec.get("cumulative_reward")
+    if v2 is not None:
+        return float(v2)
+    return None
+
 
-def load_sft_dataset(
-    jsonl_path: Path,
-    threshold: float = 0.30,
-    filter_cfg: SFTFilterConfig | None = None,
-    include_sft_targets: bool = True,
-):
-    """
-    Load episodes above efficiency threshold and format for SFT.
-
-    Args:
-        jsonl_path: Path to the JSONL episodes file.
-        threshold: Minimum deal_efficiency to include.
-        filter_cfg: Drop/clip thresholds; default SFTFilterConfig().
-        include_sft_targets: If True, embed eff_log and reward_clip in the example text.
-    """
+def load_sft_dataset(data_path: Path, min_reward: float = -50.0):
+    """Build a text dataset from JSONL prompt/completion pairs."""
     try:
         from datasets import Dataset
     except ImportError as exc:
         raise ImportError("Install datasets: pip install datasets") from exc
 
-            if not
-                skipped_filter += 1
-                continue
-            history_lines.append(f"{role}: {content}")
-    history = "\n".join(history_lines)
-
-    targets_block = ""
-    if include_sft_targets:
-        targets_block = (
-            f"<|sft_targets|>eff_log={efficiency_label:.4f} "
-            f"reward_clip={reward_clip:.2f}</s>\n"
-        )
-
-    return (
-        f"<|system|>{system}</s>\n"
-        f"{targets_block}"
-        f"<|negotiation_history|>{history}</s>\n"
-        f"<|assistant|>{response}</s>"
-    )
+    rows: list[dict[str, str]] = []
+    skipped = 0
+    reward_filtered = 0
+    remaining_records = 0
+    with data_path.open("r", encoding="utf-8") as f:
+        for line_no, line in enumerate(f, start=1):
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                rec = json.loads(line)
+            except json.JSONDecodeError:
+                logger.warning("Skipping malformed JSONL row %d", line_no)
+                skipped += 1
+                continue
+
+            r = _row_total_reward(rec)
+            if r is not None and r < min_reward:
+                reward_filtered += 1
+                continue
+
+            prompt = str(rec.get("prompt", "")).strip()
+            if not prompt:
+                logger.warning("Skipping row %d: missing prompt", line_no)
+                skipped += 1
+                continue
+
+            completions = _extract_completions(rec)
+            if not completions:
+                logger.warning("Skipping row %d: missing completion and negotiator turns", line_no)
+                skipped += 1
+                continue
+
+            remaining_records += 1
+            for completion in completions:
+                rows.append(
+                    {
+                        "text": (
+                            f"<|system|>{prompt}</s>\n"
+                            f"<|assistant|>{completion}</s>"
+                        )
+                    }
+                )
+
+    print(
+        f"Filtered {reward_filtered} records below min_reward={min_reward}, "
+        f"{remaining_records} remaining for SFT"
+    )
+    if skipped:
+        logger.info("Also skipped %d malformed/empty JSONL rows; expanded to %d text rows", skipped, len(rows))
+    if not rows:
+        raise RuntimeError("No valid SFT examples found in dataset.")
+    return Dataset.from_list(rows)
 
 
 def train_sft(
-    data_path: Path,
-    threshold: float = 0.30,
-    filter_cfg: SFTFilterConfig | None = None,
-    include_sft_targets: bool = True,
-) -> Path:
-    """
-    Run SFT fine-tuning.
-
-    Args:
-        data_path: Path to episodes JSONL.
-        model_id: HuggingFace model ID or local path.
-        output_dir: Where to save the trained model.
-        threshold: Efficiency filter for training data.
-        filter_cfg: Quality filter / clip config.
-        include_sft_targets: Embed normalized targets in training strings.
-
-    Returns:
-        output_dir path.
-    """
+    data_path: Path, model_id: str, output_dir: Path, min_reward: float = -50.0
+) -> None:
+    """Fine-tune a base model with LoRA via TRL SFTTrainer."""
     import torch
-    try:
-        from peft import LoraConfig
-        from trl import SFTTrainer, SFTConfig
-    except ImportError as exc:
-        raise ImportError("Install: pip install trl peft") from exc
+    from peft import LoraConfig
+    from trl import SFTConfig, SFTTrainer
 
-    dataset = load_sft_dataset(
-        data_path, threshold, filter_cfg, include_sft_targets=include_sft_targets
-    )
-    if len(dataset) == 0 and threshold > 0.0:
-        logger.warning(
-            f"No episodes above threshold {threshold}. Lowering to 0.0 (all train rows)."
-        )
-        dataset = load_sft_dataset(
-            data_path, 0.0, filter_cfg, include_sft_targets=include_sft_targets
-        )
-    if len(dataset) == 0:
-        raise RuntimeError(
-            "SFT dataset is empty. "
-            "Run generate_data.py first with --episodes >= 200"
-        )
+    dataset = load_sft_dataset(data_path, min_reward=min_reward)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     lora_config = LoraConfig(
         r=16,
@@ -188,16 +130,16 @@ def train_sft(
         per_device_train_batch_size=4,
         gradient_accumulation_steps=4,
         learning_rate=2e-4,
-        warmup_ratio=0.05,
-        lr_scheduler_type="cosine",
         logging_steps=10,
         save_strategy="epoch",
-        bf16=torch.cuda.is_available(),
-        max_seq_length=2048,
+        fp16=True,
         report_to="none",
+        max_seq_length=2048,
     )
 
+    if not torch.cuda.is_available():
+        logger.warning("No CUDA GPU detected; training may be very slow.")
+
     trainer = SFTTrainer(
         model=model_id,
         args=training_args,
@@ -205,46 +147,27 @@ def train_sft(
         peft_config=lora_config,
     )
 
-    logger.info(
+    logger.info("Starting SFT: model=%s, examples=%d", model_id, len(dataset))
     trainer.train()
     trainer.save_model(str(output_dir))
-    logger.info(
-    return output_dir
+    logger.info("Saved SFT checkpoint to %s", output_dir)
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Parlay SFT
+    parser = argparse.ArgumentParser(description="Parlay SFT training")
     parser.add_argument("--data", default="data/episodes.jsonl")
-    parser.add_argument("--model", default=
-    parser.add_argument("--output", default=
-    parser.add_argument("--steps", type=int, default=0, help="Notebook compatibility flag")
-    parser.add_argument("--threshold", type=float, default=TOP_PLAYER_THRESHOLD)
-    parser.add_argument("--reward-drop-min", type=float, default=-400.0)
-    parser.add_argument("--reward-drop-max", type=float, default=400.0)
-    parser.add_argument("--clip-reward-min", type=float, default=-200.0)
-    parser.add_argument("--clip-reward-max", type=float, default=200.0)
+    parser.add_argument("--model", default=DEFAULT_MODEL)
+    parser.add_argument("--output", default=DEFAULT_OUTPUT)
     parser.add_argument(
-        "--
+        "--min-reward",
+        type=float,
+        default=-50.0,
+        help="Skip JSONL records with total reward below this (default: -50.0)",
     )
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
-    cfg = SFTFilterConfig(
-        reward_drop_min=args.reward_drop_min,
-        reward_drop_max=args.reward_drop_max,
-        clip_reward_min=args.clip_reward_min,
-        clip_reward_max=args.clip_reward_max,
-    )
-    train_sft(
-        Path(args.data),
-        args.model,
-        Path(args.output),
-        args.threshold,
-        filter_cfg=cfg,
-        include_sft_targets=not args.no_sft_targets,
-    )
+    train_sft(Path(args.data), args.model, Path(args.output), min_reward=args.min_reward)
 
 
 if __name__ == "__main__":
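Taken together, the retuned defaults suggest this end-to-end sequence; one plausible invocation, using only flags and paths introduced in these diffs:

    python -m training.generate_data --episodes 140 --output data/episodes.jsonl
    python scripts/validate_sft_data.py --data data/episodes.jsonl
    python -m training.sft_train --data data/episodes.jsonl --output checkpoints/sft_1.5b/
    python -m training.grpo_train --base_model checkpoints/sft_1.5b/ --data data/episodes.jsonl --min-reward -50

The shared -50.0 min_reward floor at every stage keeps catastrophic episodes out of generation, SFT, and GRPO alike, so the three filters stay consistent by construction.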