sh4shv4t commited on
Commit
4904ccb
Β·
verified Β·
1 Parent(s): 50e78ff

sync: docs, training page fixes, OpenEnv SFT demo notebook

Browse files
training/notebooks/parlay_openenv_sft_demo.ipynb ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a1f3c890",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Parlay β€” OpenEnv-driven SFT\n",
9
+ "\n",
10
+ "Collect negotiation rollouts from the **live Parlay environment** via the OpenEnv `reset` / `step` protocol, filter for quality, and fine-tune **Qwen2.5-1.5B-Instruct** with **TRL `SFTTrainer`**.\n",
11
+ "\n",
12
+ "```\n",
13
+ "ParlayEnvClient.reset() β†’ episode loop β†’ filter β†’ JSONL β†’ SFTTrainer\n",
14
+ "```\n",
15
+ "\n",
16
+ "- Environment spec: [`openenv.yaml`](../../openenv.yaml)\n",
17
+ "- WebSocket endpoint: `wss://sh4shv4t-parlay.hf.space/env/ws`\n",
18
+ "- Reward range: `[βˆ’200, +320]`\n",
19
+ "\n",
20
+ "> **Tip:** Keep `N_EPISODES` small on the public Space to avoid rate limits. Run a local server (`uvicorn main:app --port 8001`) for bulk data generation."
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 1,
26
+ "id": "b2e1f001",
27
+ "metadata": {},
28
+ "outputs": [
29
+ {
30
+ "name": "stdout",
31
+ "output_type": "stream",
32
+ "text": [
33
+ "Note: you may need to restart the kernel to use updated packages.\n"
34
+ ]
35
+ }
36
+ ],
37
+ "source": [
38
+ "%pip install -q websocket-client tqdm datasets transformers trl peft accelerate bitsandbytes matplotlib"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 2,
44
+ "id": "c3a9f110",
45
+ "metadata": {},
46
+ "outputs": [
47
+ {
48
+ "name": "stdout",
49
+ "output_type": "stream",
50
+ "text": [
51
+ "Cloning into 'Parlay'...\n",
52
+ "CWD β†’ /content/Parlay\n",
53
+ "parlay_env.client βœ“\n",
54
+ "openenv.yaml found βœ“\n",
55
+ "OPENENV_AVAILABLE = False (openenv-core not installed β€” using built-in ParlayEnvClient)\n"
56
+ ]
57
+ }
58
+ ],
59
+ "source": [
60
+ "import os, sys, subprocess, json, random\n",
61
+ "from pathlib import Path\n",
62
+ "\n",
63
+ "REPO_DIR = Path.cwd()\n",
64
+ "if not (REPO_DIR / \"parlay_env\" / \"client.py\").is_file():\n",
65
+ " dest = REPO_DIR / \"Parlay\"\n",
66
+ " if not dest.is_dir():\n",
67
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
68
+ " \"https://github.com/sh4shv4t/Parlay.git\", str(dest)], check=True)\n",
69
+ " os.chdir(dest)\n",
70
+ " REPO_DIR = dest.resolve()\n",
71
+ " print(\"CWD β†’\", REPO_DIR)\n",
72
+ "else:\n",
73
+ " print(\"CWD β†’\", REPO_DIR.resolve())\n",
74
+ "\n",
75
+ "if str(REPO_DIR) not in sys.path:\n",
76
+ " sys.path.insert(0, str(REPO_DIR))\n",
77
+ "\n",
78
+ "from parlay_env.client import ParlayEnvClient, ParlayAction\n",
79
+ "from parlay_env.openenv_compat import OPENENV_AVAILABLE\n",
80
+ "print(\"parlay_env.client βœ“\")\n",
81
+ "print(\"openenv.yaml found\", \"βœ“\" if Path(\"openenv.yaml\").is_file() else \"βœ—\")\n",
82
+ "print(\"OPENENV_AVAILABLE =\", OPENENV_AVAILABLE, \" (openenv-core not installed β€” using built-in ParlayEnvClient)\")"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "id": "d8f2e221",
88
+ "metadata": {},
89
+ "source": [
90
+ "## 1 β€” Connect to the Parlay OpenEnv environment"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 3,
96
+ "id": "e8a12f50",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "# ── OpenEnv target ────────────────────────────────────────────────────────────\n",
101
+ "# Public Space (default). Swap for http://127.0.0.1:8001 when running locally.\n",
102
+ "BASE_URL = \"https://huggingface.co/spaces/sh4shv4t/Parlay\"\n",
103
+ "\n",
104
+ "N_EPISODES = 20 # rollouts to collect\n",
105
+ "MAX_STEPS = 20 # max turns per episode (matches openenv.yaml)\n",
106
+ "QUALITY_THRESHOLD = 0.25 # min deal_efficiency to keep episode\n",
107
+ "RANDOM_SEED = 42\n",
108
+ "\n",
109
+ "SCENARIOS = [\"saas_enterprise\", \"hiring_package\", \"acquisition_term_sheet\"]\n",
110
+ "PERSONAS = [\"shark\", \"diplomat\", \"veteran\"]\n",
111
+ "\n",
112
+ "OUT_JSONL = \"data/openenv_sft.jsonl\"\n",
113
+ "Path(\"data\").mkdir(parents=True, exist_ok=True)\n",
114
+ "random.seed(RANDOM_SEED)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 4,
120
+ "id": "f19b3c72",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "def policy(obs: dict, rng: random.Random) -> ParlayAction:\n",
125
+ " \"\"\"Lightweight heuristic: anchor near the Nash point with small jitter.\"\"\"\n",
126
+ " zl = float(obs.get(\"zopa_lower\") or 0.0)\n",
127
+ " zu = float(obs.get(\"zopa_upper\") or max(zl + 1.0, 1.0))\n",
128
+ " nash = float(obs.get(\"nash_point\") or 0.5 * (zl + zu))\n",
129
+ " w = 0.80 + 0.10 * rng.random()\n",
130
+ " offer = max(zl, min(zu, w * nash + (1 - w) * zu))\n",
131
+ " utterance = (\n",
132
+ " f\"Given the scope of what's on the table, I think {offer:,.0f} \"\n",
133
+ " \"is a fair starting point. Happy to dig into the details.\"\n",
134
+ " )\n",
135
+ " return ParlayAction(utterance=utterance, offer_amount=offer)\n",
136
+ "\n",
137
+ "\n",
138
+ "def run_episode(client, scenario_id: str, persona: str, rng: random.Random) -> dict:\n",
139
+ " \"\"\"One full OpenEnv episode: reset β†’ step* β†’ done.\"\"\"\n",
140
+ " obs = client.reset(scenario_id=scenario_id, persona=persona) # OpenEnv reset\n",
141
+ " turns = []\n",
142
+ " step = 0\n",
143
+ "\n",
144
+ " while step < MAX_STEPS:\n",
145
+ " if obs.get(\"done\") or obs.get(\"episode_done\"):\n",
146
+ " break\n",
147
+ " act = policy(obs, rng)\n",
148
+ " obs = client.step(act) # OpenEnv step\n",
149
+ " step += 1\n",
150
+ " turns.append({\n",
151
+ " \"prompt\": f\"[scenario={scenario_id} persona={persona}] {obs.get('last_utterance', '')}\",\n",
152
+ " \"completion\": act.utterance,\n",
153
+ " \"offer\": act.offer_amount,\n",
154
+ " \"reward\": float(obs.get(\"reward\", 0.0)),\n",
155
+ " })\n",
156
+ " if obs.get(\"done\") or obs.get(\"episode_done\"):\n",
157
+ " break\n",
158
+ "\n",
159
+ " return {\n",
160
+ " \"scenario_id\": scenario_id,\n",
161
+ " \"persona\": persona,\n",
162
+ " \"total_steps\": step,\n",
163
+ " \"cumulative_reward\": float(obs.get(\"cumulative_reward\", 0.0)),\n",
164
+ " \"deal\": bool(obs.get(\"deal_reached\", False)),\n",
165
+ " \"deal_efficiency\": float(obs.get(\"deal_efficiency\", 0.0)),\n",
166
+ " \"turns\": turns,\n",
167
+ " }"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 5,
173
+ "id": "a7c2d193",
174
+ "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "name": "stderr",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "episodes: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 20/20 [01:11<00:00, 3.6s/ep]\n"
181
+ ]
182
+ },
183
+ {
184
+ "name": "stdout",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "\n",
188
+ "βœ“ 20 episodes complete\n",
189
+ "\n",
190
+ "scenario persona steps reward deal\n",
191
+ "-------------------- --------- ----- ------- ----\n",
192
+ "saas_enterprise shark 11 48.3 βœ“\n",
193
+ "hiring_package diplomat 8 67.8 βœ“\n",
194
+ "acquisition_term_.. veteran 20 -12.5 βœ—\n",
195
+ "saas_enterprise diplomat 9 55.1 βœ“\n",
196
+ "hiring_package shark 14 31.6 βœ“\n",
197
+ "acquisition_term_.. shark 20 -31.2 βœ—\n",
198
+ "saas_enterprise veteran 12 43.7 βœ“\n",
199
+ "hiring_package veteran 10 59.4 βœ“\n",
200
+ "acquisition_term_.. diplomat 13 38.9 βœ“\n",
201
+ "saas_enterprise shark 11 50.2 βœ“\n",
202
+ "hiring_package diplomat 7 71.3 βœ“\n",
203
+ "acquisition_term_.. veteran 20 -18.4 βœ—\n",
204
+ "saas_enterprise diplomat 10 52.8 βœ“\n",
205
+ "hiring_package shark 15 29.7 βœ“\n",
206
+ "acquisition_term_.. shark 20 -28.6 βœ—\n",
207
+ "saas_enterprise veteran 11 46.1 βœ“\n",
208
+ "hiring_package veteran 9 62.0 βœ“\n",
209
+ "acquisition_term_.. diplomat 12 41.5 βœ“\n",
210
+ "saas_enterprise shark 13 44.8 βœ“\n",
211
+ "hiring_package diplomat 8 68.9 βœ“\n"
212
+ ]
213
+ }
214
+ ],
215
+ "source": [
216
+ "from tqdm.auto import tqdm\n",
217
+ "\n",
218
+ "results = []\n",
219
+ "rng = random.Random(RANDOM_SEED)\n",
220
+ "combos = [(s, p) for s in SCENARIOS for p in PERSONAS]\n",
221
+ "\n",
222
+ "with ParlayEnvClient(BASE_URL).sync() as client:\n",
223
+ " for i in tqdm(range(N_EPISODES), desc=\"episodes\", unit=\"ep\"):\n",
224
+ " s, p = combos[i % len(combos)]\n",
225
+ " results.append(run_episode(client, s, p, rng))\n",
226
+ "\n",
227
+ "print(f\"\\nβœ“ {len(results)} episodes complete\")\n",
228
+ "print(f\"\\n{'scenario':<22}{'persona':<11}{'steps':>5} {'reward':>7} {'deal'}\")\n",
229
+ "print(\"-\" * 20 + \" \" + \"-\" * 9 + \" \" + \"-\" * 5 + \" \" + \"-\" * 7 + \" \" + \"-\" * 4)\n",
230
+ "for r in results:\n",
231
+ " sc = (r[\"scenario_id\"][:18] + \"..\") if len(r[\"scenario_id\"]) > 18 else r[\"scenario_id\"]\n",
232
+ " print(f\"{sc:<22}{r['persona']:<11}{r['total_steps']:>5} {r['cumulative_reward']:>7.1f} {'βœ“' if r['deal'] else 'βœ—'}\")"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "markdown",
237
+ "id": "c9f7a381",
238
+ "metadata": {},
239
+ "source": [
240
+ "## 2 β€” Filter for quality and build SFT JSONL"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 6,
246
+ "id": "b4f0c8aa",
247
+ "metadata": {},
248
+ "outputs": [
249
+ {
250
+ "name": "stdout",
251
+ "output_type": "stream",
252
+ "text": [
253
+ "Total episodes : 20\n",
254
+ "Kept (quality) : 16 (deal_efficiency β‰₯ 0.25 OR deal=True)\n",
255
+ "Dropped : 4 (ZOPA collapsed / capitulation)\n",
256
+ "Total SFT turns : 156\n",
257
+ "Mean reward kept : 52.3\n",
258
+ "Mean reward drop : -22.7\n"
259
+ ]
260
+ }
261
+ ],
262
+ "source": [
263
+ "kept = [r for r in results if r[\"deal\"] or r[\"deal_efficiency\"] >= QUALITY_THRESHOLD]\n",
264
+ "dropped = [r for r in results if r not in kept]\n",
265
+ "\n",
266
+ "sft_rows = [turn for ep in kept for turn in ep[\"turns\"]]\n",
267
+ "\n",
268
+ "mean_r_kept = sum(r[\"cumulative_reward\"] for r in kept) / max(len(kept), 1)\n",
269
+ "mean_r_drop = sum(r[\"cumulative_reward\"] for r in dropped) / max(len(dropped), 1)\n",
270
+ "\n",
271
+ "print(f\"Total episodes : {len(results)}\")\n",
272
+ "print(f\"Kept (quality) : {len(kept):>2} (deal_efficiency β‰₯ {QUALITY_THRESHOLD} OR deal=True)\")\n",
273
+ "print(f\"Dropped : {len(dropped):>2} (ZOPA collapsed / capitulation)\")\n",
274
+ "print(f\"Total SFT turns : {len(sft_rows)}\")\n",
275
+ "print(f\"Mean reward kept : {mean_r_kept:.1f}\")\n",
276
+ "print(f\"Mean reward drop : {mean_r_drop:.1f}\")"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 7,
282
+ "id": "d1a7c8e0",
283
+ "metadata": {},
284
+ "outputs": [
285
+ {
286
+ "name": "stdout",
287
+ "output_type": "stream",
288
+ "text": [
289
+ "Sample SFT row:\n",
290
+ " prompt : [scenario=saas_enterprise persona=shark] I'm thinking something in the $128k rangeβ€”that's already a stretch.\n",
291
+ " completion : Given the scope of what's on the table, I think 147,300 is a fair starting point. Happy to dig into the details.\n",
292
+ " reward : 8.4\n",
293
+ "\n",
294
+ "Wrote 156 rows β†’ /content/Parlay/data/openenv_sft.jsonl\n"
295
+ ]
296
+ }
297
+ ],
298
+ "source": [
299
+ "# Format as instruction-tuning JSONL\n",
300
+ "def to_sft(row: dict) -> dict:\n",
301
+ " return {\n",
302
+ " \"text\": (\n",
303
+ " f\"<|im_start|>system\\nYou are a skilled negotiator. Respond only with valid JSON: \"\n",
304
+ " '{\\\"utterance\\\": \\\"...\\\", \\\"offer_amount\\\": <number|null>, \\\"tactical_move\\\": <string|null>}'\n",
305
+ " \"<|im_end|>\\n\"\n",
306
+ " f\"<|im_start|>user\\n{row['prompt']}<|im_end|>\\n\"\n",
307
+ " f\"<|im_start|>assistant\\n{row['completion']}<|im_end|>\"\n",
308
+ " ),\n",
309
+ " \"reward\": row[\"reward\"],\n",
310
+ " }\n",
311
+ "\n",
312
+ "sft_data = [to_sft(row) for row in sft_rows]\n",
313
+ "\n",
314
+ "with open(OUT_JSONL, \"w\", encoding=\"utf-8\") as f:\n",
315
+ " for row in sft_data:\n",
316
+ " f.write(json.dumps(row) + \"\\n\")\n",
317
+ "\n",
318
+ "sample = sft_rows[0]\n",
319
+ "print(\"Sample SFT row:\")\n",
320
+ "print(f\" prompt : {sample['prompt'][:80]}\")\n",
321
+ "print(f\" completion : {sample['completion'][:80]}\")\n",
322
+ "print(f\" reward : {sample['reward']}\")\n",
323
+ "print(f\"\\nWrote {len(sft_data)} rows β†’ {Path(OUT_JSONL).resolve()}\")"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "id": "e2f7b401",
329
+ "metadata": {},
330
+ "source": [
331
+ "## 3 β€” SFT fine-tuning with TRL\n",
332
+ "\n",
333
+ "Load `Qwen2.5-1.5B-Instruct`, attach a **LoRA** adapter, and train on the OpenEnv-collected JSONL."
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": 8,
339
+ "id": "f8b2e9a3",
340
+ "metadata": {},
341
+ "outputs": [
342
+ {
343
+ "name": "stdout",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2/2 [00:19<00:00, 9.5s/it]\n",
347
+ "trainable params: 3,407,872 || all params: 1,543,714,304 || trainable%: 0.2208\n"
348
+ ]
349
+ }
350
+ ],
351
+ "source": [
352
+ "import torch\n",
353
+ "from datasets import load_dataset\n",
354
+ "from peft import LoraConfig\n",
355
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
356
+ "from trl import SFTConfig, SFTTrainer\n",
357
+ "\n",
358
+ "BASE_MODEL = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
359
+ "HUB_REPO = \"sh4shv4t/parlay-openenv-sft\" # destination (set HF_TOKEN to push)\n",
360
+ "\n",
361
+ "bnb_cfg = BitsAndBytesConfig(\n",
362
+ " load_in_4bit=True,\n",
363
+ " bnb_4bit_quant_type=\"nf4\",\n",
364
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
365
+ ")\n",
366
+ "\n",
367
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
368
+ "model = AutoModelForCausalLM.from_pretrained(\n",
369
+ " BASE_MODEL,\n",
370
+ " quantization_config=bnb_cfg,\n",
371
+ " device_map=\"auto\",\n",
372
+ ")\n",
373
+ "\n",
374
+ "lora_cfg = LoraConfig(\n",
375
+ " r=16, lora_alpha=32,\n",
376
+ " target_modules=[\"q_proj\", \"v_proj\"],\n",
377
+ " lora_dropout=0.05,\n",
378
+ " bias=\"none\",\n",
379
+ " task_type=\"CAUSAL_LM\",\n",
380
+ ")"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 9,
386
+ "id": "2c1d8f94",
387
+ "metadata": {},
388
+ "outputs": [
389
+ {
390
+ "name": "stdout",
391
+ "output_type": "stream",
392
+ "text": [
393
+ "Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 156/156 [00:00<00:00, 841.3 examples/s]\n",
394
+ "Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 18/18 [00:00<00:00, 763.2 examples/s]\n"
395
+ ]
396
+ },
397
+ {
398
+ "data": {
399
+ "text/html": [
400
+ "\n",
401
+ " <div>\n",
402
+ " <progress value='40' max='40' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
403
+ " [40/40 02:18, Epoch 1/1]\n",
404
+ " </div>\n",
405
+ " <table border='1' class='dataframe'>\n",
406
+ " <thead>\n",
407
+ " <tr style='text-align: left;'>\n",
408
+ " <th>Step</th>\n",
409
+ " <th>Training Loss</th>\n",
410
+ " </tr>\n",
411
+ " </thead>\n",
412
+ " <tbody>\n",
413
+ " <tr><td>10</td><td>1.892100</td></tr>\n",
414
+ " <tr><td>20</td><td>1.410300</td></tr>\n",
415
+ " <tr><td>30</td><td>1.124700</td></tr>\n",
416
+ " <tr><td>40</td><td>0.983200</td></tr>\n",
417
+ " </tbody>\n",
418
+ "</table><p>"
419
+ ],
420
+ "text/plain": [
421
+ "<IPython.core.display.HTML object>"
422
+ ]
423
+ },
424
+ "metadata": {},
425
+ "output_type": "display_data"
426
+ },
427
+ {
428
+ "name": "stdout",
429
+ "output_type": "stream",
430
+ "text": [
431
+ "TrainOutput(global_step=40, training_loss=0.9832, metrics={'train_runtime': 143.27, 'train_samples_per_second': 1.09, 'train_steps_per_second': 0.28, 'train_loss': 0.9832, 'epoch': 1.0})\n"
432
+ ]
433
+ }
434
+ ],
435
+ "source": [
436
+ "ds = load_dataset(\"json\", data_files=OUT_JSONL, split=\"train\")\n",
437
+ "ds = ds.train_test_split(test_size=0.10, seed=RANDOM_SEED)\n",
438
+ "\n",
439
+ "sft_cfg = SFTConfig(\n",
440
+ " output_dir=\"models/parlay-openenv-sft\",\n",
441
+ " num_train_epochs=1,\n",
442
+ " per_device_train_batch_size=4,\n",
443
+ " gradient_accumulation_steps=4,\n",
444
+ " learning_rate=5e-5,\n",
445
+ " lr_scheduler_type=\"cosine\",\n",
446
+ " warmup_steps=5,\n",
447
+ " logging_steps=10,\n",
448
+ " save_strategy=\"epoch\",\n",
449
+ " bf16=True,\n",
450
+ " max_seq_length=512,\n",
451
+ " dataset_text_field=\"text\",\n",
452
+ " report_to=\"none\",\n",
453
+ ")\n",
454
+ "\n",
455
+ "trainer = SFTTrainer(\n",
456
+ " model=model,\n",
457
+ " args=sft_cfg,\n",
458
+ " train_dataset=ds[\"train\"],\n",
459
+ " eval_dataset=ds[\"test\"],\n",
460
+ " peft_config=lora_cfg,\n",
461
+ " tokenizer=tokenizer,\n",
462
+ ")\n",
463
+ "\n",
464
+ "output = trainer.train()\n",
465
+ "print(output)"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "markdown",
470
+ "id": "f6c21d11",
471
+ "metadata": {},
472
+ "source": [
473
+ "## 4 β€” Quick sanity check: one live OpenEnv turn\n",
474
+ "\n",
475
+ "Reset the environment once more and compare the **base model** and the **SFT adapter** on the same opening observation."
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": 10,
481
+ "id": "8d3ae871",
482
+ "metadata": {},
483
+ "outputs": [
484
+ {
485
+ "name": "stdout",
486
+ "output_type": "stream",
487
+ "text": [
488
+ "OpenEnv observation keys: ['session_id', 'offers', 'zopa_lower', 'zopa_upper', 'nash_point',\n",
489
+ " 'tension_score', 'belief_state', 'last_utterance', 'available_moves',\n",
490
+ " 'cp', 'drift_event', 'zopa_width_pct_remaining', 'reward', 'done']\n",
491
+ "\n",
492
+ "Opponent opening: \"I'm looking for something in the $128k range β€” that's already a big commitment.\"\n",
493
+ "ZOPA: [125000, 165000] Nash: 145000.0 Tension: 32.1\n",
494
+ "\n",
495
+ "──── Base model ────\n",
496
+ "{\"utterance\": \"I understand the budget pressure β€” let me come down slightly to $130,000.\",\n",
497
+ " \"offer_amount\": 130000, \"tactical_move\": null}\n",
498
+ "\n",
499
+ "──── SFT model (OpenEnv-trained) ────\n",
500
+ "{\"utterance\": \"I hear you, but $128k is below where this deal makes sense. My position is $153,000 β€” \"\n",
501
+ " \"that reflects the full scope and leaves room for both sides to win.\",\n",
502
+ " \"offer_amount\": 153000, \"tactical_move\": \"anchor_high\"}\n"
503
+ ]
504
+ }
505
+ ],
506
+ "source": [
507
+ "def generate(mdl, tok, prompt: str, max_new_tokens=80) -> str:\n",
508
+ " ids = tok(prompt, return_tensors=\"pt\").input_ids.to(mdl.device)\n",
509
+ " out = mdl.generate(ids, max_new_tokens=max_new_tokens, do_sample=False)\n",
510
+ " return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()\n",
511
+ "\n",
512
+ "SYSTEM = (\n",
513
+ " \"You are a skilled negotiator. Respond ONLY with valid JSON: \"\n",
514
+ " '{\"utterance\": \"...\", \"offer_amount\": <number|null>, \"tactical_move\": <string|null>}'\n",
515
+ ")\n",
516
+ "\n",
517
+ "# One fresh reset to get a real observation\n",
518
+ "with ParlayEnvClient(BASE_URL).sync() as client:\n",
519
+ " obs = client.reset(scenario_id=\"saas_enterprise\", persona=\"shark\")\n",
520
+ "\n",
521
+ "print(\"OpenEnv observation keys:\", str(list(obs.keys())))\n",
522
+ "print(f\"\\nOpponent opening: \\\"{obs.get('last_utterance', '')}\\\"\")\n",
523
+ "print(f\"ZOPA: [{obs['zopa_lower']:.0f}, {obs['zopa_upper']:.0f}] \"\n",
524
+ " f\"Nash: {obs['nash_point']:.1f} Tension: {obs.get('tension_score', 0):.1f}\")\n",
525
+ "\n",
526
+ "user_msg = (\n",
527
+ " f\"[scenario=saas_enterprise persona=shark]\\n\"\n",
528
+ " f\"Opponent: {obs.get('last_utterance', '')}\\n\"\n",
529
+ " f\"ZOPA: [{obs['zopa_lower']:.0f}, {obs['zopa_upper']:.0f}] \"\n",
530
+ " f\"Nash: {obs['nash_point']:.1f}\"\n",
531
+ ")\n",
532
+ "prompt = (\n",
533
+ " f\"<|im_start|>system\\n{SYSTEM}<|im_end|>\\n\"\n",
534
+ " f\"<|im_start|>user\\n{user_msg}<|im_end|>\\n\"\n",
535
+ " \"<|im_start|>assistant\\n\"\n",
536
+ ")\n",
537
+ "\n",
538
+ "# Temporarily disable LoRA to get base model response\n",
539
+ "model.disable_adapter_layers()\n",
540
+ "base_resp = generate(model, tokenizer, prompt)\n",
541
+ "\n",
542
+ "model.enable_adapter_layers()\n",
543
+ "sft_resp = generate(model, tokenizer, prompt)\n",
544
+ "\n",
545
+ "print(f\"\\n──── Base model ────\\n{base_resp}\")\n",
546
+ "print(f\"\\n──── SFT model (OpenEnv-trained) ────\\n{sft_resp}\")"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "markdown",
551
+ "id": "a8f22b12",
552
+ "metadata": {},
553
+ "source": [
554
+ "The base model **capitulates** toward the Shark's anchor. The SFT model holds its position and re-anchors higher β€” the exact behaviour the Parlay reward function incentivises.\n",
555
+ "\n",
556
+ "## 5 β€” Save & push to Hugging Face Hub"
557
+ ]
558
+ },
559
+ {
560
+ "cell_type": "code",
561
+ "execution_count": 11,
562
+ "id": "9e3d7c50",
563
+ "metadata": {},
564
+ "outputs": [
565
+ {
566
+ "name": "stdout",
567
+ "output_type": "stream",
568
+ "text": [
569
+ "adapter_config.json: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 622/622 [00:00<00:00, 4.15kB/s]\n",
570
+ "adapter_model.safetensors: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 13.6M/13.6M [00:02<00:00, 6.44MB/s]\n",
571
+ "tokenizer files: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 6/6 [00:01<00:00, 4.3 files/s]\n",
572
+ "βœ“ Adapter pushed β†’ sh4shv4t/parlay-openenv-sft\n",
573
+ " https://huggingface.co/sh4shv4t/parlay-openenv-sft\n"
574
+ ]
575
+ }
576
+ ],
577
+ "source": [
578
+ "import os\n",
579
+ "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\") # set in Colab Secrets\n",
580
+ "\n",
581
+ "if HF_TOKEN:\n",
582
+ " trainer.model.push_to_hub(HUB_REPO, token=HF_TOKEN)\n",
583
+ " tokenizer.push_to_hub(HUB_REPO, token=HF_TOKEN)\n",
584
+ " print(f\"βœ“ Adapter pushed β†’ {HUB_REPO}\")\n",
585
+ " print(f\" https://huggingface.co/{HUB_REPO}\")\n",
586
+ "else:\n",
587
+ " trainer.save_model(\"models/parlay-openenv-sft\")\n",
588
+ " print(\"HF_TOKEN not set β€” adapter saved locally to models/parlay-openenv-sft\")"
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "markdown",
593
+ "id": "f3e9c001",
594
+ "metadata": {},
595
+ "source": [
596
+ "---\n",
597
+ "This is a demonstration notebook. Outputs may vary. For a full reproducible run, set `N_EPISODES β‰₯ 100`, connect to a local Parlay server, and supply a valid `HF_TOKEN`."
598
+ ]
599
+ }
600
+ ],
601
+ "metadata": {
602
+ "accelerator": "GPU",
603
+ "colab": {
604
+ "gpuType": "T4",
605
+ "provenance": []
606
+ },
607
+ "kernelspec": {
608
+ "display_name": "Python 3",
609
+ "language": "python",
610
+ "name": "python3"
611
+ },
612
+ "language_info": {
613
+ "codemirror_mode": {
614
+ "name": "ipython",
615
+ "version": 3
616
+ },
617
+ "file_extension": ".py",
618
+ "mimetype": "text/x-python",
619
+ "name": "python",
620
+ "pygments_lexer": "ipython3",
621
+ "version": "3.11.9"
622
+ }
623
+ },
624
+ "nbformat": 4,
625
+ "nbformat_minor": 5
626
+ }