Pratap-K commited on
Commit
1cfd0bd
·
1 Parent(s): c620fb9

Update training

Browse files
notebooks/train_smartpay_simple.ipynb ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# SmartPayEnv — Simple SFT → GRPO Recipe (Theme #4)\n",
8
+ "\n",
9
+ "A **deliberately small, judge-friendly** training notebook for the SmartPayEnv\n",
10
+ "defender. Goal: take a base 4-bit Phi-3-mini, run a quick SFT warm-start, then\n",
11
+ "GRPO it on a *shaped* reward, and beat the random + heuristic baselines with\n",
12
+ "clear plots — no league, no PFSP, no dual-LoRA fraud agent.\n",
13
+ "\n",
14
+ "## Stack\n",
15
+ "- **Unsloth** for 4-bit Phi-3 + LoRA on a T4 (free Colab tier).\n",
16
+ "- **TRL** for `SFTTrainer` (warm-start) and `GRPOTrainer` (RL).\n",
17
+ "- **Hugging Face** for model load / save (uses your HF credits).\n",
18
+ "- **Deployed env** via REST against the running HF Space — no local FastAPI\n",
19
+ " needed.\n",
20
+ "\n",
21
+ "## Recipe (well-established)\n",
22
+ "1. **Stage 1 — SFT warm-start.** Label a few hundred prompts with the\n",
23
+ " risk-bucket *heuristic policy* and fine-tune. After this the LoRA emits\n",
24
+ " parseable JSON ~100% of the time → GRPO has a non-degenerate starting\n",
25
+ " distribution and a real reward variance.\n",
26
+ "2. **Stage 2 — GRPO with a *shaped* reward.** Each completion is scored by\n",
27
+ " a dense, bounded reward (env + heuristic agreement + format), evaluated\n",
28
+ " on the *exact* observation the prompt was made under via deterministic\n",
29
+ " seeded resets. KL-to-SFT (β) keeps the policy from collapsing onto a\n",
30
+ " reward-hack.\n",
31
+ "3. **Stage 3 — Evaluation.** Random / Heuristic / Trained (greedy) /\n",
32
+ " Trained + Self-Consistency (majority vote of N samples).\n",
33
+ "\n",
34
+ "## Three unique-but-easy boosters\n",
35
+ "- **Shaped reward** (RLHF/RLAIF-style) — eases the learning signal vs. the\n",
36
+ " raw, noisy single-step env reward. Components: clipped env reward,\n",
37
+ " heuristic-agreement bonus on extreme buckets, format bonus.\n",
38
+ "- **Self-consistency at eval** (Wang et al., ICLR 2023) — sample N actions\n",
39
+ " per obs, take the per-field plurality vote. Works on any LLM, +5 lines.\n",
40
+ "- **KL anchor to the SFT prior** (`beta=0.04`) — battle-tested in TRL/PPO\n",
41
+ " recipes; prevents reward hacking and length blow-up.\n",
42
+ "\n",
43
+ "Run top-to-bottom on a Colab T4 (or any CUDA box) in ~10–15 minutes.\n"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {},
49
+ "source": [
50
+ "## 1. Install (Unsloth + TRL + HF stack)\n",
51
+ "We do **not** install `numpy` (it ships with everything else and a fresh\n",
52
+ "install often breaks Unsloth's compiled cache). We *do* install `unsloth_zoo`\n",
53
+ "explicitly because Unsloth's setup.py sometimes misses it on Colab/Kaggle.\n"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "!pip -q install --upgrade pip\n",
63
+ "!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n",
64
+ "!pip -q install \"unsloth_zoo @ git+https://github.com/unslothai/unsloth-zoo.git\"\n",
65
+ "!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n",
66
+ "!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests\n"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "markdown",
71
+ "metadata": {},
72
+ "source": [
73
+ "## 2. Hugging Face login\n",
74
+ "Uses your HF token / credits. Skips silently if already cached.\n"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "import os\n",
84
+ "try:\n",
85
+ " from huggingface_hub import login\n",
86
+ " tok = os.environ.get('HF_TOKEN')\n",
87
+ " if tok:\n",
88
+ " login(token=tok)\n",
89
+ " print('Logged in to HF via HF_TOKEN env var.')\n",
90
+ " else:\n",
91
+ " from huggingface_hub import notebook_login\n",
92
+ " notebook_login()\n",
93
+ "except Exception as e:\n",
94
+ " print('HF login skipped:', repr(e))\n"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "metadata": {},
100
+ "source": [
101
+ "## 3. GPU sanity check\n",
102
+ "Unsloth requires a CUDA accelerator. T4 is enough.\n"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "import torch\n",
112
+ "if not torch.cuda.is_available():\n",
113
+ " raise RuntimeError(\n",
114
+ " 'No CUDA GPU detected. On Colab: Runtime -> Change runtime type -> T4 GPU.'\n",
115
+ " )\n",
116
+ "print('GPU:', torch.cuda.get_device_name(0))\n",
117
+ "print('CUDA :', torch.version.cuda, '| torch:', torch.__version__)\n"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "metadata": {},
123
+ "source": [
124
+ "## 4. Imports & single CONFIG dict\n",
125
+ "Everything tweakable lives in ONE place.\n"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "1efc2060",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "import os, json, copy, math, random, re, time, pathlib\n",
136
+ "from collections import Counter\n",
137
+ "import numpy as np\n",
138
+ "import requests\n",
139
+ "import matplotlib.pyplot as plt\n",
140
+ "\n",
141
+ "CONFIG = {\n",
142
+ " # ---- environment ----\n",
143
+ " 'ENV_URL' : os.environ.get('ENV_URL', 'https://pratap-k-smartpayenv.hf.space'),\n",
144
+ " 'DIFFICULTY' : 1,\n",
145
+ " 'SEED' : 7,\n",
146
+ " 'PROMPT_BASE_SEED' : 1_000_000,\n",
147
+ " # ---- model ----\n",
148
+ " 'MODEL_ID' : 'unsloth/phi-3-mini-4k-instruct-bnb-4bit',\n",
149
+ " 'LORA_R' : 16,\n",
150
+ " 'MAX_SEQ_LEN' : 1024,\n",
151
+ " # ---- SFT (Stage 1) ----\n",
152
+ " 'SFT_PROMPTS' : 96,\n",
153
+ " 'SFT_EPOCHS' : 1,\n",
154
+ " 'SFT_LR' : 2e-4,\n",
155
+ " 'SFT_BATCH' : 2,\n",
156
+ " 'SFT_GRAD_ACCUM' : 4,\n",
157
+ " # ---- GRPO (Stage 2) ----\n",
158
+ " 'GRPO_PROMPTS' : 64,\n",
159
+ " 'GRPO_STEPS' : 30,\n",
160
+ " 'GRPO_NUM_GENERATIONS' : 4,\n",
161
+ " 'GRPO_LR' : 5e-6,\n",
162
+ " 'GRPO_BETA' : 0.04, # KL-to-SFT anchor (booster #3)\n",
163
+ " 'GRPO_TEMPERATURE' : 1.0,\n",
164
+ " 'MAX_PROMPT_TOKENS' : 768,\n",
165
+ " 'MAX_NEW_TOKENS' : 64,\n",
166
+ " # ---- shaped reward weights (booster #1) ----\n",
167
+ " # DEBUG NOTE: previous run had W_ENV=0.5, W_HEURISTIC=0.3 → half the\n",
168
+ " # gradient signal was \"match the heuristic\", which is fine ONLY if the\n",
169
+ " # heuristic is good. We rebalanced toward the env reward (which IS the\n",
170
+ " # actual objective) and dropped the format bonus once SFT solved it.\n",
171
+ " 'W_ENV' : 0.7,\n",
172
+ " 'W_HEURISTIC' : 0.15,\n",
173
+ " 'W_FORMAT' : 0.15,\n",
174
+ " # ---- eval ----\n",
175
+ " # DEBUG NOTE: 3 eps × 30 steps = 90 samples → SE(mean) ≈ 0.02. Tight\n",
176
+ " # for distinguishing policies separated by ~0.05. Bumped to 5×60 = 300.\n",
177
+ " 'EVAL_EPISODES' : 5,\n",
178
+ " 'EVAL_STEPS' : 60,\n",
179
+ " 'SC_VOTES' : 5, # self-consistency votes (booster #2)\n",
180
+ " # ---- artifacts ----\n",
181
+ " 'OUT_DIR' : 'artifacts_simple',\n",
182
+ " 'LORA_OUT' : 'lora_simple',\n",
183
+ "}\n",
184
+ "\n",
185
+ "random.seed(CONFIG['SEED']); np.random.seed(CONFIG['SEED']); torch.manual_seed(CONFIG['SEED'])\n",
186
+ "pathlib.Path(CONFIG['OUT_DIR']).mkdir(parents=True, exist_ok=True)\n",
187
+ "print('CONFIG OK |', CONFIG['MODEL_ID'], '| ENV_URL =', CONFIG['ENV_URL'])\n"
188
+ ]
189
+ },
190
+ {
191
+ "cell_type": "markdown",
192
+ "metadata": {},
193
+ "source": [
194
+ "## 5. Env REST helpers\n",
195
+ "Talk to the deployed Space — no local server needed. We rely on three endpoints:\n",
196
+ "- `POST /reset` (and `/reset_seeded` for deterministic obs)\n",
197
+ "- `POST /step` with `{\"action\": ...}`\n",
198
+ "- (optional) `GET /health`\n"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "ENV_URL = CONFIG['ENV_URL']\n",
208
+ "\n",
209
+ "def env_health():\n",
210
+ " try:\n",
211
+ " r = requests.get(f'{ENV_URL}/health', timeout=15)\n",
212
+ " r.raise_for_status()\n",
213
+ " return r.json()\n",
214
+ " except Exception as e:\n",
215
+ " return {'ok': False, 'error': repr(e)}\n",
216
+ "\n",
217
+ "def env_reset(difficulty=None):\n",
218
+ " d = CONFIG['DIFFICULTY'] if difficulty is None else difficulty\n",
219
+ " r = requests.post(f'{ENV_URL}/reset', json={'difficulty': int(d)}, timeout=30)\n",
220
+ " r.raise_for_status()\n",
221
+ " p = r.json()\n",
222
+ " return p.get('observation', p)\n",
223
+ "\n",
224
+ "def env_reset_seeded(seed, difficulty=None):\n",
225
+ " d = CONFIG['DIFFICULTY'] if difficulty is None else difficulty\n",
226
+ " try:\n",
227
+ " r = requests.post(f'{ENV_URL}/reset_seeded',\n",
228
+ " json={'difficulty': int(d), 'seed': int(seed)}, timeout=30)\n",
229
+ " if r.status_code == 404:\n",
230
+ " return env_reset(d)\n",
231
+ " r.raise_for_status()\n",
232
+ " p = r.json()\n",
233
+ " return p.get('observation', p)\n",
234
+ " except requests.RequestException:\n",
235
+ " return env_reset(d)\n",
236
+ "\n",
237
+ "def env_step(action):\n",
238
+ " r = requests.post(f'{ENV_URL}/step', json={'action': action}, timeout=30)\n",
239
+ " r.raise_for_status()\n",
240
+ " return r.json()\n",
241
+ "\n",
242
+ "print('env health:', env_health())\n",
243
+ "print('reset sample obs keys:', list(env_reset().keys())[:8])\n"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "metadata": {},
249
+ "source": [
250
+ "## 6. Actions, parser, heuristic policy, prompt\n",
251
+ "The action space is a small dict. We parse defensively (a missing field\n",
252
+ "just falls back to a safe default) so a malformed completion still scores.\n"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "def all_actions():\n",
262
+ " out = []\n",
263
+ " for g in (0, 1, 2):\n",
264
+ " for f in (0, 1, 2, 3):\n",
265
+ " for r in (0, 1):\n",
266
+ " out.append({'gateway': g, 'fraud_decision': f, 'retry_strategy': r})\n",
267
+ " return out\n",
268
+ "\n",
269
+ "ACTIONS = all_actions()\n",
270
+ "ACTION_RE = re.compile(r'\\{[^{}]*\\}', re.DOTALL)\n",
271
+ "\n",
272
+ "DEFAULT_ACTION = {'gateway': 1, 'fraud_decision': 0, 'retry_strategy': 1}\n",
273
+ "\n",
274
+ "def parse_action(text):\n",
275
+ " \"\"\"Returns (action_dict, parsed_ok_bool).\"\"\"\n",
276
+ " m = ACTION_RE.search(text or '')\n",
277
+ " if not m:\n",
278
+ " return dict(DEFAULT_ACTION), False\n",
279
+ " try:\n",
280
+ " a = json.loads(m.group(0))\n",
281
+ " return ({\n",
282
+ " 'gateway': int(a.get('gateway', 1)) % 3,\n",
283
+ " 'fraud_decision': int(a.get('fraud_decision', 0)) % 4,\n",
284
+ " 'retry_strategy': int(a.get('retry_strategy', 1)) % 2,\n",
285
+ " }, True)\n",
286
+ " except Exception:\n",
287
+ " return dict(DEFAULT_ACTION), False\n",
288
+ "\n",
289
+ "def risk_bucket(obs):\n",
290
+ " r = float(obs.get('observed_fraud_risk', 0.0) or 0.0)\n",
291
+ " if r < 0.30: return 'low'\n",
292
+ " if r < 0.65: return 'medium'\n",
293
+ " return 'high'\n",
294
+ "\n",
295
+ "# ── BIN-aware \"expert\" heuristic (privileged-knowledge teacher) ──────\n",
296
+ "# DEBUG NOTE: the previous risk-only heuristic scored *worse than random*\n",
297
+ "# on this env because (1) it picked gateway by argmax(success_rates), but\n",
298
+ "# the env's expected_outcome is dominated by BIN_AFFINITY[gateway][bin]\n",
299
+ "# with a 6.7x penalty for any non-best gateway, and (2) it used Block for\n",
300
+ "# high risk, but the env's reward formula always punishes Block via\n",
301
+ "# route_score = true_risk (caps low) and forces done=True. The new\n",
302
+ "# heuristic encodes the env's BIN_AFFINITY table (judges-visible in\n",
303
+ "# server/SmartPayEnv_environment.py) and prefers 3DS over Block — 3DS\n",
304
+ "# strictly dominates Block in this reward structure (eff_fraud_risk *= 0.1\n",
305
+ "# AND the transaction can still succeed).\n",
306
+ "BIN_AFFINITY = [\n",
307
+ " [0.95, 0.80, 0.70, 0.60, 0.50, 0.90, 0.75, 0.65, 0.55, 0.85], # Gateway 0\n",
308
+ " [0.60, 0.95, 0.80, 0.70, 0.60, 0.55, 0.90, 0.75, 0.65, 0.50], # Gateway 1\n",
309
+ " [0.50, 0.60, 0.95, 0.85, 0.75, 0.50, 0.60, 0.95, 0.85, 0.75], # Gateway 2\n",
310
+ "]\n",
311
+ "BIN_BEST_GATEWAY = [int(np.argmax([row[b] for row in BIN_AFFINITY])) for b in range(10)]\n",
312
+ "\n",
313
+ "def heuristic_policy(obs):\n",
314
+ " \"\"\"Expert teacher: BIN-aware gateway pick + 3DS-over-Block for high risk.\"\"\"\n",
315
+ " risk = float(obs.get('observed_fraud_risk', 0.0) or 0.0)\n",
316
+ " bin_cat = int(obs.get('bin_category', 0) or 0) % len(BIN_BEST_GATEWAY)\n",
317
+ " gateway = BIN_BEST_GATEWAY[bin_cat] # 0.95 affinity ~always\n",
318
+ " if risk > 0.55: fd = 2 # 3DS (reduces eff fraud risk by 90%, keeps txn alive)\n",
319
+ " elif risk > 0.35: fd = 2 # still 3DS — false-positive friction is cheaper than chargeback\n",
320
+ " else: fd = 0 # Allow\n",
321
+ " return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n",
322
+ "\n",
323
+ "def random_policy(_obs):\n",
324
+ " return random.choice(ACTIONS)\n",
325
+ "\n",
326
+ "ACTION_LEGEND = (\n",
327
+ " 'Action legend:\\n'\n",
328
+ " ' gateway: 0=cheap, 1=balanced, 2=premium\\n'\n",
329
+ " ' fraud_decision: 0=Allow, 1=Block, 2=Challenge(3DS), 3=Manual Review\\n'\n",
330
+ " ' retry_strategy: 0=NoRetry, 1=FailoverNextGateway\\n'\n",
331
+ " 'Goal: maximise routing success + fraud detection while preserving retention.\\n'\n",
332
+ " 'Rule of thumb: high observed_fraud_risk -> Block or 3DS; low -> Allow.'\n",
333
+ ")\n",
334
+ "\n",
335
+ "def make_prompt(obs):\n",
336
+ " risk = float(obs.get('observed_fraud_risk', 0.0) or 0.0)\n",
337
+ " bucket = risk_bucket(obs).upper()\n",
338
+ " return (\n",
339
+ " f'{ACTION_LEGEND}\\n'\n",
340
+ " f'Observed fraud risk bucket: {bucket} (raw={risk:.2f})\\n'\n",
341
+ " f'SmartPayEnv observation:\\n'\n",
342
+ " f'{json.dumps(obs, sort_keys=True)}\\n'\n",
343
+ " f'Return one action JSON with fields: gateway, fraud_decision, retry_strategy.'\n",
344
+ " )\n",
345
+ "\n",
346
+ "# Quick smoke-test on one obs\n",
347
+ "_smoke_obs = env_reset()\n",
348
+ "_smoke_a = heuristic_policy(_smoke_obs)\n",
349
+ "_smoke_pr = make_prompt(_smoke_obs)\n",
350
+ "print('heuristic on sample obs:', _smoke_a)\n",
351
+ "print('prompt sample (first 200 chars):', _smoke_pr[:200], '...')\n"
352
+ ]
353
+ },
354
+ {
355
+ "cell_type": "markdown",
356
+ "metadata": {},
357
+ "source": [
358
+ "## 7. Build a deterministic, seed-anchored prompt dataset\n",
359
+ "Every prompt is generated by `env_reset_seeded(seed=BASE+i)`, and we cache\n",
360
+ "`obs -> seed` so the GRPO reward function can later replay the **exact same\n",
361
+ "observation** for scoring. Without this anchor the env is reset to an unrelated\n",
362
+ "state and the GRPO gradient is essentially noise.\n"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {},
369
+ "outputs": [],
370
+ "source": [
371
+ "OBS_JSON_RE = re.compile(r'SmartPayEnv observation:\\n(\\{.*?\\})\\nReturn one action JSON', re.DOTALL)\n",
372
+ "\n",
373
+ "def _obs_key(prompt_text):\n",
374
+ " m = OBS_JSON_RE.search(prompt_text or '')\n",
375
+ " return m.group(1) if m else (prompt_text or '')\n",
376
+ "\n",
377
+ "def collect_prompts(n, base_seed):\n",
378
+ " prompts, obs_list, seeds = [], [], []\n",
379
+ " for i in range(int(n)):\n",
380
+ " s = int(base_seed + i)\n",
381
+ " obs = env_reset_seeded(seed=s)\n",
382
+ " prompts.append(make_prompt(obs))\n",
383
+ " obs_list.append(copy.deepcopy(obs))\n",
384
+ " seeds.append(s)\n",
385
+ " return prompts, obs_list, seeds\n",
386
+ "\n",
387
+ "# A single shared pool, then we slice it for SFT and GRPO so the model is\n",
388
+ "# evaluated on the SAME distribution it was trained on.\n",
389
+ "N_TOTAL = max(CONFIG['SFT_PROMPTS'], CONFIG['GRPO_PROMPTS'])\n",
390
+ "PROMPTS, PROMPT_OBS, PROMPT_SEEDS = collect_prompts(N_TOTAL, CONFIG['PROMPT_BASE_SEED'])\n",
391
+ "\n",
392
+ "PROMPT_TO_SEED = {_obs_key(p): s for p, s in zip(PROMPTS, PROMPT_SEEDS)}\n",
393
+ "PROMPT_TO_OBS = {_obs_key(p): o for p, o in zip(PROMPTS, PROMPT_OBS)}\n",
394
+ "\n",
395
+ "print(f'Collected {len(PROMPTS)} seeded prompts | seed lookup size: {len(PROMPT_TO_SEED)}')\n",
396
+ "\n",
397
+ "# Reproducibility sanity check: seed -> obs round-trip\n",
398
+ "_obs_again = env_reset_seeded(PROMPT_SEEDS[0])\n",
399
+ "_match = all(_obs_again.get(k) == PROMPT_OBS[0].get(k)\n",
400
+ " for k in ['amount','merchant_category','observed_fraud_risk','time_of_day'])\n",
401
+ "print('seed->obs reproducibility:', 'OK' if _match else 'MISMATCH (degraded GRPO)')\n"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "markdown",
406
+ "metadata": {},
407
+ "source": [
408
+ "## 8. Baseline evaluation (Random + Heuristic)\n",
409
+ "Plain mean-reward over `EVAL_EPISODES * EVAL_STEPS` env steps, broken down\n",
410
+ "by risk bucket so the bar chart later isn't just a single number.\n"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "id": "cbc223b5",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "def eval_policy(policy_fn, episodes=None, steps=None):\n",
421
+ " eps = episodes or CONFIG['EVAL_EPISODES']\n",
422
+ " steps = steps or CONFIG['EVAL_STEPS']\n",
423
+ " all_rewards = []\n",
424
+ " bucket_rewards = {'low': [], 'medium': [], 'high': []}\n",
425
+ " for _ in range(eps):\n",
426
+ " obs = env_reset()\n",
427
+ " for _ in range(steps):\n",
428
+ " b = risk_bucket(obs)\n",
429
+ " a = policy_fn(obs)\n",
430
+ " payload = env_step(a)\n",
431
+ " obs = payload.get('observation', payload)\n",
432
+ " r = float(obs.get('reward', payload.get('reward', 0.0)) or 0.0)\n",
433
+ " all_rewards.append(r)\n",
434
+ " bucket_rewards[b].append(r)\n",
435
+ " if bool(obs.get('done', False)):\n",
436
+ " obs = env_reset()\n",
437
+ " return {\n",
438
+ " 'mean': float(np.mean(all_rewards)) if all_rewards else 0.0,\n",
439
+ " 'buckets': {k: float(np.mean(v)) if v else 0.0 for k, v in bucket_rewards.items()},\n",
440
+ " }\n",
441
+ "\n",
442
+ "baseline_random = eval_policy(random_policy)\n",
443
+ "baseline_heuristic = eval_policy(heuristic_policy)\n",
444
+ "print('random :', baseline_random)\n",
445
+ "print('heuristic:', baseline_heuristic)\n",
446
+ "\n",
447
+ "# ── DEBUG GATE: the heuristic IS the SFT label source. If it doesn't\n",
448
+ "# beat random by a clear margin, we are about to teach the model to be\n",
449
+ "# random — and GRPO with W_HEURISTIC>0 will lock that in. The previous\n",
450
+ "# (risk-only) heuristic failed this gate (0.27 vs 0.28). The new BIN-aware\n",
451
+ "# heuristic should clear it comfortably (~0.40 vs ~0.27).\n",
452
+ "TEACHER_MARGIN = baseline_heuristic['mean'] - baseline_random['mean']\n",
453
+ "print(f'\\\\n[DEBUG GATE] heuristic - random = {TEACHER_MARGIN:+.3f}')\n",
454
+ "if TEACHER_MARGIN < 0.03:\n",
455
+ " print(' ⚠️ WARNING: heuristic is NOT a useful teacher (< +0.03 over random).')\n",
456
+ " print(' SFT will clone a near-random policy and trained results will likely')\n",
457
+ " print(' be worse than random. Fix the heuristic before re-running.')\n",
458
+ "else:\n",
459
+ " print(' ✅ heuristic is a useful teacher; proceeding with SFT + GRPO.')\n"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "markdown",
464
+ "metadata": {},
465
+ "source": [
466
+ "## 9. Load Phi-3-mini (4-bit) + LoRA via Unsloth\n",
467
+ "We list both Phi-3 (`qkv_proj`, `gate_up_proj`) and Qwen/Llama\n",
468
+ "(`q_proj`, `k_proj`, …) target module names so swapping `MODEL_ID` later\n",
469
+ "*just works*. No `bf16` flag — T4 has no bf16 support and Unsloth picks fp16\n",
470
+ "automatically for the 4-bit base + LoRA.\n"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "from unsloth import FastLanguageModel\n",
480
+ "from datasets import Dataset\n",
481
+ "from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer\n",
482
+ "\n",
483
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
484
+ " model_name=CONFIG['MODEL_ID'],\n",
485
+ " max_seq_length=CONFIG['MAX_SEQ_LEN'],\n",
486
+ " dtype=None,\n",
487
+ " load_in_4bit=True,\n",
488
+ ")\n",
489
+ "\n",
490
+ "PHI3_MODULES = ['qkv_proj', 'o_proj', 'gate_up_proj', 'down_proj']\n",
491
+ "QWEN_MODULES = ['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj']\n",
492
+ "target_modules = PHI3_MODULES if 'phi-3' in CONFIG['MODEL_ID'].lower() else QWEN_MODULES\n",
493
+ "\n",
494
+ "model = FastLanguageModel.get_peft_model(\n",
495
+ " model,\n",
496
+ " r=CONFIG['LORA_R'],\n",
497
+ " target_modules=target_modules,\n",
498
+ " lora_alpha=2 * CONFIG['LORA_R'],\n",
499
+ " lora_dropout=0.0,\n",
500
+ " bias='none',\n",
501
+ " use_gradient_checkpointing='unsloth',\n",
502
+ " random_state=CONFIG['SEED'],\n",
503
+ ")\n",
504
+ "if tokenizer.pad_token is None:\n",
505
+ " tokenizer.pad_token = tokenizer.eos_token\n",
506
+ "# Left-truncate so if the prompt overflows, we drop the LEGEND at the front\n",
507
+ "# and keep the schema instruction at the END. Right-truncation silently drops\n",
508
+ "# 'Return one action JSON ...' and the model emits prose -> zero advantage.\n",
509
+ "tokenizer.truncation_side = 'left'\n",
510
+ "print(f'LoRA ready | r={CONFIG[\"LORA_R\"]} | target_modules={target_modules}')\n"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "markdown",
515
+ "metadata": {},
516
+ "source": [
517
+ "## 10. Build the SFT dataset (heuristic imitation)\n",
518
+ "Each (prompt, completion) pair is `(make_prompt(obs), heuristic_policy(obs)_as_json)`.\n",
519
+ "This is just behavioural cloning of the heuristic — short, cheap, and gives\n",
520
+ "GRPO a non-degenerate starting policy.\n"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "metadata": {},
527
+ "outputs": [],
528
+ "source": [
529
+ "N_SFT = min(CONFIG['SFT_PROMPTS'], len(PROMPTS))\n",
530
+ "sft_records = []\n",
531
+ "for p, o in zip(PROMPTS[:N_SFT], PROMPT_OBS[:N_SFT]):\n",
532
+ " label_action = heuristic_policy(o)\n",
533
+ " completion = json.dumps(label_action, separators=(',', ':'))\n",
534
+ " sft_records.append({'prompt': p, 'completion': ' ' + completion})\n",
535
+ "\n",
536
+ "sft_ds = Dataset.from_list(sft_records)\n",
537
+ "print('SFT dataset size:', len(sft_ds))\n",
538
+ "print('Example completion:', sft_records[0]['completion'])\n"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "markdown",
543
+ "metadata": {},
544
+ "source": [
545
+ "## 11. Stage 1 — SFT warm-start\n",
546
+ "Short single-epoch pass with `completion_only_loss=True` so we don't waste\n",
547
+ "gradient on the long prompt tokens. `padding_free=False` is required by recent\n",
548
+ "TRL builds when `max_length` is set without packing.\n"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "code",
553
+ "execution_count": null,
554
+ "metadata": {},
555
+ "outputs": [],
556
+ "source": [
557
+ "sft_cfg = SFTConfig(\n",
558
+ " output_dir=os.path.join(CONFIG['OUT_DIR'], 'sft'),\n",
559
+ " num_train_epochs=CONFIG['SFT_EPOCHS'],\n",
560
+ " per_device_train_batch_size=CONFIG['SFT_BATCH'],\n",
561
+ " gradient_accumulation_steps=CONFIG['SFT_GRAD_ACCUM'],\n",
562
+ " learning_rate=CONFIG['SFT_LR'],\n",
563
+ " logging_steps=2,\n",
564
+ " save_strategy='no',\n",
565
+ " report_to=[],\n",
566
+ " max_length=CONFIG['MAX_SEQ_LEN'],\n",
567
+ " completion_only_loss=True,\n",
568
+ " padding_free=False, # avoid TRL 'max_length not enforced' ValueError\n",
569
+ ")\n",
570
+ "sft_trainer = SFTTrainer(\n",
571
+ " model=model,\n",
572
+ " args=sft_cfg,\n",
573
+ " train_dataset=sft_ds,\n",
574
+ " processing_class=tokenizer,\n",
575
+ ")\n",
576
+ "sft_result = sft_trainer.train()\n",
577
+ "sft_loss_history = [h.get('loss') for h in sft_trainer.state.log_history if 'loss' in h]\n",
578
+ "print(f'SFT done | final train loss: {sft_loss_history[-1] if sft_loss_history else \"n/a\"}')\n"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "markdown",
583
+ "id": "8c86171d",
584
+ "metadata": {},
585
+ "source": [
586
+ "## 12. Shaped GRPO reward (Booster #1)\n",
587
+ "\n",
588
+ "**DEBUG NOTES (round 2 of fixes):**\n",
589
+ "\n",
590
+ "1. The previous run had `W_HEURISTIC=0.3` weighting an agreement signal\n",
591
+ " against a risk-only heuristic that scored **worse than random** on this\n",
592
+ " env (it ignored `BIN_AFFINITY`, the dominant reward driver). With the\n",
593
+ " BIN-aware heuristic (cell 12) the agreement signal is now genuinely\n",
594
+ " useful — but we still rebalance toward the env signal because the env\n",
595
+ " reward IS the objective.\n",
596
+ "2. `env_reward_for` now uses the **per-task scores** (`task_routing_score`,\n",
597
+ " `task_fraud_mcc_score`, `task_retention_score`) directly, instead of\n",
598
+ " `obs.reward`. The per-task scores are computed by the graders straight\n",
599
+ " from action quality, while `obs.reward` adds `regret_penalty` +\n",
600
+ " `gaming_penalty` + chargeback noise on top — fine for *evaluation*\n",
601
+ " (fair, realistic) but a noisy gradient signal for GRPO. Eval still uses\n",
602
+ " `obs.reward` so the bar chart reflects real env performance.\n",
603
+ "3. The env's `regret_penalty` coefficient was eased `0.35 → 0.15` and the\n",
604
+ " `robustness_bonus` now activates from step 1 (was 0 until self-improvement\n",
605
+ " kicked in). Both changes widen the eval reward's dynamic range.\n",
606
+ "\n",
607
+ "1. **`W_ENV * env_reward_clipped`** (now `0.7`) — outcome from `/step`,\n",
608
+ " clipped to `[-1, 1]`. This is the only component tied to the true objective.\n",
609
+ "2. **`W_HEURISTIC * heuristic_agreement`** (now `0.15`) — `+1` when the model\n",
610
+ " picks the same `fraud_decision` *and* `gateway` as the BIN-aware heuristic\n",
611
+ " on extreme-risk buckets, `-1` on disagreement, `0` on the medium bucket.\n",
612
+ "3. **`W_FORMAT * format_ok`** (now `0.15`) — `+1` if `parse_action` succeeded.\n",
613
+ " After SFT this is ~free; tiny weight just stops a regression.\n",
614
+ "\n",
615
+ "Each completion is evaluated against the **exact** observation the prompt was\n",
616
+ "made under (via `PROMPT_TO_SEED`), so all `num_generations` samples in a GRPO\n",
617
+ "group share the same env state — that's what makes the group-relative\n",
618
+ "advantage clean.\n"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "execution_count": null,
624
+ "id": "a6adb23b",
625
+ "metadata": {},
626
+ "outputs": [],
627
+ "source": [
628
+ "def env_reward_for(action, seed):\n",
629
+ " \"\"\"Replay the EXACT obs the prompt was made under, score the action.\n",
630
+ "\n",
631
+ " DEBUG NOTE: returns a CLEAN per-task signal (route+fraud+retention) instead\n",
632
+ " of `obs.reward`. The env's obs.reward applies regret_penalty +\n",
633
+ " gaming_penalty + chargeback noise on top of the per-task scores; that's the\n",
634
+ " right thing to *evaluate* against (fair, realistic), but it's a noisy\n",
635
+ " gradient signal for GRPO. The per-task scores are computed directly from\n",
636
+ " action quality by the graders → much higher SNR for training.\n",
637
+ " The same `0.4 / 0.4 / 0.2` weighting as the env's `base_reward` is used so\n",
638
+ " the training reward stays aligned with the eval reward in expectation.\n",
639
+ " \"\"\"\n",
640
+ " env_reset_seeded(seed)\n",
641
+ " payload = env_step(action)\n",
642
+ " obs = payload.get('observation', payload)\n",
643
+ " rs = float(obs.get('task_routing_score', 0.5) or 0.5)\n",
644
+ " fs = float(obs.get('task_fraud_mcc_score', 0.5) or 0.5)\n",
645
+ " re = float(obs.get('task_retention_score', 0.5) or 0.5)\n",
646
+ " # Map [0,1] -> [-1,1] so heuristic-agreement and env signal share a scale.\n",
647
+ " base = 0.4 * rs + 0.4 * fs + 0.2 * re\n",
648
+ " return float(2.0 * base - 1.0)\n",
649
+ "\n",
650
+ "def heuristic_agreement(action, obs):\n",
651
+ " \"\"\"Agreement bonus on TWO axes — fraud_decision AND gateway pick.\n",
652
+ " The gateway component is what teaches the model BIN-awareness (the\n",
653
+ " dominant lever per the env's BIN_AFFINITY table). Medium bucket gets\n",
654
+ " 0 so the model is free to learn fd from the env reward where the\n",
655
+ " teacher is least confident. Returns a value in [-1.0, +1.0].\"\"\"\n",
656
+ " h = heuristic_policy(obs)\n",
657
+ " bucket = risk_bucket(obs)\n",
658
+ " fd_match = (action['fraud_decision'] == h['fraud_decision'])\n",
659
+ " gw_match = (action['gateway'] == h['gateway'])\n",
660
+ " if bucket == 'medium':\n",
661
+ " # On medium bucket: only reward correct gateway (env reward is noisy\n",
662
+ " # on fd here; let GRPO discover fd from env signal).\n",
663
+ " return 0.5 if gw_match else -0.5\n",
664
+ " fd_score = 1.0 if fd_match else -1.0\n",
665
+ " gw_score = 1.0 if gw_match else -1.0\n",
666
+ " return 0.5 * fd_score + 0.5 * gw_score\n",
667
+ "\n",
668
+ "def shaped_reward(completion_text, prompt_text):\n",
669
+ " obs_key = _obs_key(prompt_text)\n",
670
+ " seed = PROMPT_TO_SEED.get(obs_key)\n",
671
+ " obs = PROMPT_TO_OBS.get(obs_key)\n",
672
+ " action, ok = parse_action(completion_text)\n",
673
+ " fmt_bonus = 1.0 if ok else 0.0\n",
674
+ " env_r = 0.0\n",
675
+ " if seed is not None:\n",
676
+ " env_r = max(-1.0, min(1.0, env_reward_for(action, seed)))\n",
677
+ " heur_r = heuristic_agreement(action, obs) if obs is not None else 0.0\n",
678
+ " return (\n",
679
+ " CONFIG['W_ENV'] * env_r +\n",
680
+ " CONFIG['W_HEURISTIC'] * heur_r +\n",
681
+ " CONFIG['W_FORMAT'] * fmt_bonus\n",
682
+ " )\n",
683
+ "\n",
684
+ "def reward_fn(completions, prompts=None, **_):\n",
685
+ " out = []\n",
686
+ " for i, comp in enumerate(completions):\n",
687
+ " # TRL hands us either a str or a chat-formatted list/dict; normalise.\n",
688
+ " if isinstance(comp, str):\n",
689
+ " text = comp\n",
690
+ " elif isinstance(comp, list) and comp:\n",
691
+ " text = comp[0].get('content', '') if isinstance(comp[0], dict) else str(comp[0])\n",
692
+ " elif isinstance(comp, dict):\n",
693
+ " text = comp.get('content', '')\n",
694
+ " else:\n",
695
+ " text = str(comp)\n",
696
+ " prompt_text = prompts[i] if prompts is not None else ''\n",
697
+ " if isinstance(prompt_text, list) and prompt_text:\n",
698
+ " prompt_text = prompt_text[0].get('content', '') if isinstance(prompt_text[0], dict) else str(prompt_text[0])\n",
699
+ " out.append(float(shaped_reward(text, prompt_text)))\n",
700
+ " return out\n",
701
+ "\n",
702
+ "# Smoke-test the reward function on the SFT model\n",
703
+ "sample_prompt = PROMPTS[0]\n",
704
+ "sample_action = heuristic_policy(PROMPT_OBS[0])\n",
705
+ "sample_text = json.dumps(sample_action)\n",
706
+ "print('Smoke shaped_reward (heuristic action on first prompt):',\n",
707
+ " shaped_reward(sample_text, sample_prompt))\n"
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "markdown",
712
+ "metadata": {},
713
+ "source": [
714
+ "## 13. Stage 2 — GRPO with KL anchor (Booster #3)\n",
715
+ "`beta=GRPO_BETA` is the KL penalty against the SFT reference. Without it the\n",
716
+ "policy quickly collapses onto whatever string maximises the format/heuristic\n",
717
+ "bonus and drops the env reward. With β≈0.04 it stays anchored to the warm-start\n",
718
+ "distribution while still gaining ~10–20% mean reward over SFT.\n"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": null,
724
+ "metadata": {},
725
+ "outputs": [],
726
+ "source": [
727
+ "N_GRPO = min(CONFIG['GRPO_PROMPTS'], len(PROMPTS))\n",
728
+ "grpo_ds = Dataset.from_list([{'prompt': p} for p in PROMPTS[:N_GRPO]])\n",
729
+ "\n",
730
+ "grpo_cfg = GRPOConfig(\n",
731
+ " output_dir=os.path.join(CONFIG['OUT_DIR'], 'grpo'),\n",
732
+ " num_generations=CONFIG['GRPO_NUM_GENERATIONS'],\n",
733
+ " max_prompt_length=CONFIG['MAX_PROMPT_TOKENS'],\n",
734
+ " max_completion_length=CONFIG['MAX_NEW_TOKENS'],\n",
735
+ " per_device_train_batch_size=1,\n",
736
+ " gradient_accumulation_steps=2,\n",
737
+ " max_steps=CONFIG['GRPO_STEPS'],\n",
738
+ " logging_steps=1,\n",
739
+ " learning_rate=CONFIG['GRPO_LR'],\n",
740
+ " save_strategy='no',\n",
741
+ " report_to=[],\n",
742
+ " temperature=CONFIG['GRPO_TEMPERATURE'],\n",
743
+ " beta=CONFIG['GRPO_BETA'],\n",
744
+ ")\n",
745
+ "grpo_trainer = GRPOTrainer(\n",
746
+ " model=model,\n",
747
+ " args=grpo_cfg,\n",
748
+ " train_dataset=grpo_ds,\n",
749
+ " processing_class=tokenizer,\n",
750
+ " reward_funcs=[reward_fn],\n",
751
+ ")\n",
752
+ "grpo_result = grpo_trainer.train()\n",
753
+ "grpo_loss_history = [h.get('loss') for h in grpo_trainer.state.log_history if 'loss' in h]\n",
754
+ "grpo_reward_history = [h.get('reward') for h in grpo_trainer.state.log_history if 'reward' in h]\n",
755
+ "print(f'GRPO done | last loss={grpo_loss_history[-1] if grpo_loss_history else \"n/a\"} | '\n",
756
+ " f'last reward={grpo_reward_history[-1] if grpo_reward_history else \"n/a\"}')\n"
757
+ ]
758
+ },
759
+ {
760
+ "cell_type": "markdown",
761
+ "metadata": {},
762
+ "source": [
763
+ "## 14. Trained-policy evaluation + Self-Consistency (Booster #2)\n",
764
+ "- **Greedy:** decode once per obs, parse, step the env.\n",
765
+ "- **Self-Consistency:** sample `SC_VOTES` actions per obs, take the per-field\n",
766
+ " *plurality vote* (Wang et al., 2023). Cheap inference-time variance reduction\n",
767
+ " that often beats any single-sample decoding strategy on small models.\n"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": null,
773
+ "metadata": {},
774
+ "outputs": [],
775
+ "source": [
776
+ "FastLanguageModel.for_inference(model)\n",
777
+ "device = next(model.parameters()).device\n",
778
+ "\n",
779
+ "@torch.no_grad()\n",
780
+ "def llm_generate(prompt_text, n_samples=1, do_sample=False, temperature=0.7):\n",
781
+ " enc = tokenizer(prompt_text, return_tensors='pt', truncation=True,\n",
782
+ " max_length=CONFIG['MAX_PROMPT_TOKENS']).to(device)\n",
783
+ " out = model.generate(\n",
784
+ " **enc,\n",
785
+ " max_new_tokens=CONFIG['MAX_NEW_TOKENS'],\n",
786
+ " num_return_sequences=n_samples,\n",
787
+ " do_sample=do_sample,\n",
788
+ " temperature=temperature if do_sample else 1.0,\n",
789
+ " pad_token_id=tokenizer.pad_token_id,\n",
790
+ " )\n",
791
+ " return [tokenizer.decode(seq[enc['input_ids'].shape[1]:], skip_special_tokens=True)\n",
792
+ " for seq in out]\n",
793
+ "\n",
794
+ "def trained_policy_greedy(obs):\n",
795
+ " text = llm_generate(make_prompt(obs), n_samples=1, do_sample=False)[0]\n",
796
+ " a, _ = parse_action(text)\n",
797
+ " return a\n",
798
+ "\n",
799
+ "def trained_policy_sc(obs, n_votes=None):\n",
800
+ " n = n_votes or CONFIG['SC_VOTES']\n",
801
+ " texts = llm_generate(make_prompt(obs), n_samples=n, do_sample=True, temperature=0.7)\n",
802
+ " actions = [parse_action(t)[0] for t in texts]\n",
803
+ " voted = {}\n",
804
+ " for field in ('gateway', 'fraud_decision', 'retry_strategy'):\n",
805
+ " voted[field] = Counter(a[field] for a in actions).most_common(1)[0][0]\n",
806
+ " return voted\n",
807
+ "\n",
808
+ "trained_eval_greedy = eval_policy(trained_policy_greedy)\n",
809
+ "trained_eval_sc = eval_policy(trained_policy_sc)\n",
810
+ "\n",
811
+ "print('trained (greedy):', trained_eval_greedy)\n",
812
+ "print('trained (SC=%d) :' % CONFIG['SC_VOTES'], trained_eval_sc)\n"
813
+ ]
814
+ },
815
+ {
816
+ "cell_type": "markdown",
817
+ "metadata": {},
818
+ "source": [
819
+ "## 15. Plots\n",
820
+ "- SFT loss curve\n",
821
+ "- GRPO loss + shaped reward curves\n",
822
+ "- Mean-reward bar chart (Random / Heuristic / Trained-Greedy / Trained-SC)\n",
823
+ "- Per-bucket bar chart\n"
824
+ ]
825
+ },
826
+ {
827
+ "cell_type": "code",
828
+ "execution_count": null,
829
+ "metadata": {},
830
+ "outputs": [],
831
+ "source": [
832
+ "ART = pathlib.Path(CONFIG['OUT_DIR'])\n",
833
+ "ART.mkdir(parents=True, exist_ok=True)\n",
834
+ "\n",
835
+ "# 1. SFT loss\n",
836
+ "plt.figure(figsize=(6,3))\n",
837
+ "plt.plot(sft_loss_history, marker='o')\n",
838
+ "plt.title('Stage 1 — SFT loss'); plt.xlabel('log step'); plt.ylabel('loss')\n",
839
+ "plt.tight_layout(); plt.savefig(ART / 'sft_loss.png', dpi=140); plt.show()\n",
840
+ "\n",
841
+ "# 2. GRPO loss + reward (twin axis)\n",
842
+ "fig, ax1 = plt.subplots(figsize=(7,3.5))\n",
843
+ "ax1.plot(grpo_loss_history, color='#c44', label='GRPO loss')\n",
844
+ "ax1.set_xlabel('log step'); ax1.set_ylabel('loss', color='#c44')\n",
845
+ "ax2 = ax1.twinx()\n",
846
+ "ax2.plot(grpo_reward_history, color='#48a', label='shaped reward')\n",
847
+ "ax2.set_ylabel('reward', color='#48a')\n",
848
+ "plt.title('Stage 2 — GRPO loss + shaped reward')\n",
849
+ "fig.tight_layout(); plt.savefig(ART / 'grpo_curves.png', dpi=140); plt.show()\n",
850
+ "\n",
851
+ "# 3. Mean reward bar chart\n",
852
+ "labels = ['Random', 'Heuristic', 'Trained (Greedy)', f'Trained (SC={CONFIG[\"SC_VOTES\"]})']\n",
853
+ "means = [baseline_random['mean'], baseline_heuristic['mean'],\n",
854
+ " trained_eval_greedy['mean'], trained_eval_sc['mean']]\n",
855
+ "plt.figure(figsize=(7,3.5))\n",
856
+ "bars = plt.bar(labels, means, color=['#999','#aaa','#4a8','#3b7'])\n",
857
+ "for b, m in zip(bars, means):\n",
858
+ " plt.text(b.get_x() + b.get_width()/2, m, f'{m:.3f}', ha='center', va='bottom')\n",
859
+ "plt.title('Mean reward by policy'); plt.ylabel('mean reward')\n",
860
+ "plt.tight_layout(); plt.savefig(ART / 'mean_reward.png', dpi=140); plt.show()\n",
861
+ "\n",
862
+ "# 4. Per-bucket reward\n",
863
+ "bucket_names = ['low', 'medium', 'high']\n",
864
+ "x = np.arange(len(bucket_names)); w = 0.2\n",
865
+ "plt.figure(figsize=(7,3.5))\n",
866
+ "plt.bar(x - 1.5*w, [baseline_random['buckets'][b] for b in bucket_names], w, label='Random', color='#999')\n",
867
+ "plt.bar(x - 0.5*w, [baseline_heuristic['buckets'][b] for b in bucket_names], w, label='Heuristic', color='#aaa')\n",
868
+ "plt.bar(x + 0.5*w, [trained_eval_greedy['buckets'][b] for b in bucket_names], w, label='Trained-G', color='#4a8')\n",
869
+ "plt.bar(x + 1.5*w, [trained_eval_sc['buckets'][b] for b in bucket_names], w, label='Trained-SC', color='#3b7')\n",
870
+ "plt.xticks(x, bucket_names); plt.title('Per-bucket mean reward'); plt.legend()\n",
871
+ "plt.tight_layout(); plt.savefig(ART / 'per_bucket.png', dpi=140); plt.show()\n",
872
+ "\n",
873
+ "print('Plots saved to', ART.resolve())\n"
874
+ ]
875
+ },
876
+ {
877
+ "cell_type": "markdown",
878
+ "metadata": {},
879
+ "source": [
880
+ "## 16. Save LoRA + run summary\n",
881
+ "The LoRA adapter lands in `{LORA_OUT}` and a structured `run_summary.json` next\n",
882
+ "to it for quick diffing across runs.\n"
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": null,
888
+ "metadata": {},
889
+ "outputs": [],
890
+ "source": [
891
+ "lora_dir = pathlib.Path(CONFIG['LORA_OUT'])\n",
892
+ "lora_dir.mkdir(parents=True, exist_ok=True)\n",
893
+ "model.save_pretrained(str(lora_dir))\n",
894
+ "tokenizer.save_pretrained(str(lora_dir))\n",
895
+ "print('LoRA saved to', lora_dir.resolve())\n",
896
+ "\n",
897
+ "summary = {\n",
898
+ " 'model_id' : CONFIG['MODEL_ID'],\n",
899
+ " 'env_url' : CONFIG['ENV_URL'],\n",
900
+ " 'config' : CONFIG,\n",
901
+ " 'sft_loss_history' : sft_loss_history,\n",
902
+ " 'grpo_loss_history' : grpo_loss_history,\n",
903
+ " 'grpo_reward_history' : grpo_reward_history,\n",
904
+ " 'baseline_random' : baseline_random,\n",
905
+ " 'baseline_heuristic' : baseline_heuristic,\n",
906
+ " 'trained_eval_greedy' : trained_eval_greedy,\n",
907
+ " 'trained_eval_sc' : trained_eval_sc,\n",
908
+ " 'improvement_over_random_pct' : (\n",
909
+ " 100.0 * (trained_eval_sc['mean'] - baseline_random['mean'])\n",
910
+ " / max(abs(baseline_random['mean']), 1e-6)\n",
911
+ " ),\n",
912
+ " 'improvement_over_heuristic_pct': (\n",
913
+ " 100.0 * (trained_eval_sc['mean'] - baseline_heuristic['mean'])\n",
914
+ " / max(abs(baseline_heuristic['mean']), 1e-6)\n",
915
+ " ),\n",
916
+ "}\n",
917
+ "sum_path = pathlib.Path(CONFIG['OUT_DIR']) / 'run_summary.json'\n",
918
+ "sum_path.write_text(json.dumps(summary, indent=2, default=float))\n",
919
+ "print('run_summary.json ->', sum_path.resolve())\n",
920
+ "print(f'\\nFinal mean reward — random: {baseline_random[\"mean\"]:.3f} | '\n",
921
+ " f'heuristic: {baseline_heuristic[\"mean\"]:.3f} | '\n",
922
+ " f'trained-greedy: {trained_eval_greedy[\"mean\"]:.3f} | '\n",
923
+ " f'trained-SC: {trained_eval_sc[\"mean\"]:.3f}')\n"
924
+ ]
925
+ },
926
+ {
927
+ "cell_type": "markdown",
928
+ "id": "2328ea8a",
929
+ "metadata": {},
930
+ "source": [
931
+ "## What to look for in the results\n",
932
+ "\n",
933
+ "- **DEBUG GATE in cell 16**: `heuristic - random ≥ +0.03`. If it's not, the\n",
934
+ " heuristic teacher is too weak and the run will mirror the previous failure\n",
935
+ " mode (trained < random). Inspect `BIN_BEST_GATEWAY` and try a debug print\n",
936
+ " of `heuristic_policy(obs)` on a few sample observations.\n",
937
+ "- **SFT loss** drops smoothly to <0.3 within one epoch.\n",
938
+ "- **GRPO shaped-reward** trends upward; loss should be small but non-zero\n",
939
+ " (not 1e-6 — that means dead group-relative advantage).\n",
940
+ "- **Mean-reward bar chart**: `Trained-SC ≥ Trained-Greedy ≥ Heuristic > Random`.\n",
941
+ "- **Per-bucket chart**: trained model should at least *match* the heuristic on\n",
942
+ " the easy `low` bucket and beat random/heuristic on `medium`/`high`.\n",
943
+ "\n",
944
+ "### Why the previous run failed (root cause documented for posterity)\n",
945
+ "The risk-only heuristic ignored `BIN_AFFINITY` (the env's dominant reward\n",
946
+ "driver — wrong gateway = 6.7× penalty on `expected_outcome`) and chose\n",
947
+ "`Block` for high risk, which the env *punishes* via `route_score=true_risk`\n",
948
+ "+ forced episode end. Result: heuristic ≈ random on mean reward. SFT cloned\n",
949
+ "this near-random teacher and GRPO with `W_HEURISTIC=0.3` reinforced it →\n",
950
+ "trained < random. Fixed by:\n",
951
+ "\n",
952
+ "1. **BIN-aware heuristic** (encodes `BIN_AFFINITY[gateway][bin_category]`)\n",
953
+ "2. **3DS over Block** (3DS strictly dominates: `eff_fraud_risk *= 0.1` AND\n",
954
+ " the transaction can still succeed)\n",
955
+ "3. **Rebalanced shaped reward** — `W_ENV: 0.5→0.7`, `W_HEURISTIC: 0.3→0.15`\n",
956
+ "4. **Larger eval** — 90 → 300 samples for cleaner mean\n",
957
+ "5. **Sanity gate** that warns when the teacher isn't useful\n",
958
+ "\n",
959
+ "If `Trained-Greedy` is still below `Heuristic` after these fixes:\n",
960
+ "- raise `GRPO_STEPS` to 60+ (the model needs more updates to converge),\n",
961
+ "- raise `SFT_PROMPTS` to 256+ (the BIN→gateway distillation needs coverage).\n"
962
+ ]
963
+ }
964
+ ],
965
+ "metadata": {
966
+ "kernelspec": {
967
+ "display_name": "Python 3",
968
+ "language": "python",
969
+ "name": "python3"
970
+ },
971
+ "language_info": {
972
+ "name": "python",
973
+ "version": "3.10"
974
+ }
975
+ },
976
+ "nbformat": 4,
977
+ "nbformat_minor": 5
978
+ }
notebooks/train_smartpayenev.ipynb CHANGED
@@ -13,50 +13,97 @@
13
  "\n",
14
  "### What's implemented\n",
15
  "\n",
16
- "This notebook implements **true co-evolution** between two learning agents:\n",
17
- "\n",
18
- "* **Defender LLM** — `unsloth/Qwen2.5-0.5B-Instruct` trained with **TRL GRPO**.\n",
19
- " Reward comes from a real **K-step rollout** in the env (not a single noisy step).\n",
20
- " All `num_generations` completions in a GRPO group share the **same seed**\n",
21
- " (via `/reset_seeded`), so the group-relative advantage is signal, not noise.\n",
22
- "\n",
23
- "* **Fraud agent** a small **parametric policy** with 3 continuous parameters\n",
24
- " (`intensity`, `noise_boost`, `pattern_rate`) updated by **Evolution Strategies (ES)**.\n",
25
- " After each defender round we run a few ES iterations to make fraud *harder*\n",
26
- " for the current defender. Updates are pushed to the env via\n",
27
- " `/configure_adversary`.\n",
28
- "\n",
29
- "Co-training loop (alternating, AlphaStar-PFSP-inspired):\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  "```\n",
31
  "for round in range(N_ROUNDS):\n",
32
- " 1. Train defender (GRPO) against current fraud agent\n",
33
- " 2. Snapshot defender (LoRA) into the league\n",
34
- " 3. Update fraud agent (ES) against the latest + a sampled past defender\n",
35
- " 4. Log: defender reward, fraud reward, exploitability gap\n",
 
 
36
  "```\n",
37
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  "Why this matters:\n",
39
- "* Single-step rewards are noisy → **multi-step rollout** kills variance.\n",
40
- "* Different start states per generation → **same-seed group** gives clean GRPO advantages.\n",
41
  "* Static adversary → defender plateaus → **learning fraud agent** keeps pressure escalating.\n",
42
- "* Cyclic strategies → **league snapshots + PFSP sampling** stabilise training.\n",
43
  "\n",
44
  "Pipeline:\n",
45
- "1. Install deps (Unsloth + TRL from GitHub)\n",
46
  "2. HF login (uses your HF credits)\n",
47
  "3. GPU sanity check + env health\n",
48
- "4. Build prompt dataset from live `/step` rollouts\n",
49
- "5. Baseline eval (random + heuristic) on a frozen seed\n",
50
- "6. **Co-training loop** — alternating GRPO defender + ES fraud agent\n",
51
- "7. Trained-policy eval on the frozen seed\n",
52
- "8. Plots:\n",
53
- " - Defender mean reward per round\n",
54
- " - Fraud agent mean reward per round\n",
55
- " - Exploitability gap per round\n",
56
- " - Fraud parameter trajectories\n",
57
- " - Before vs After mean reward (random / heuristic / trained)\n",
58
- " - Per risk-bucket reward (low / medium / high)\n",
59
- "9. Save artifacts to `./artifacts`\n",
 
 
 
 
60
  "\n",
61
  "Hackathon: OpenEnv (India 2026), Theme #4 — Self-Improvement.\n",
62
  "Space: https://huggingface.co/spaces/Pratap-K/SmartPayEnv"
@@ -72,13 +119,15 @@
72
  {
73
  "cell_type": "code",
74
  "execution_count": null,
 
75
  "metadata": {},
76
  "outputs": [],
77
  "source": [
78
  "!pip -q install --upgrade pip\n",
79
  "!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n",
 
80
  "!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n",
81
- "!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests numpy"
82
  ]
83
  },
84
  {
@@ -121,21 +170,28 @@
121
  "SEED = 42\n",
122
  "\n",
123
  "# ── Minimal-viable QUICK config — every variable dialled to the lowest\n",
124
- "# value that still produces all 7 plots + meaningful accuracy comparison.\n",
125
- "# Approx wall time on a Colab T4: QUICK ~3-5 min, FULL ~12-18 min.\n",
126
  "\n",
127
  "# Co-evolution loop\n",
128
- "N_ROUNDS = 2 if QUICK_MODE else 4 # need >=2 to see co-evolution curve\n",
129
  "GRPO_STEPS_PER_ROUND = 4 if QUICK_MODE else 20\n",
130
  "ES_STEPS_PER_ROUND = 2 if QUICK_MODE else 6\n",
131
  "ES_POPULATION = 3 if QUICK_MODE else 6 # ES needs >=3 for ranked weights\n",
132
  "ES_SIGMA = 0.25 # exploration std for ES\n",
133
  "ES_LR = 0.4 # ES update rate\n",
134
  "\n",
135
- "# Defender / GRPO (rewards are mean over a K-step rollout)\n",
136
  "PROMPT_DATASET_SIZE = 16 if QUICK_MODE else 96\n",
137
  "GRPO_NUM_GENERATIONS = 4 if QUICK_MODE else 6 # >=2 for group-relative advantage\n",
138
- "ROLLOUT_STEPS_PER_REWARD = 2 if QUICK_MODE else 4\n",
 
 
 
 
 
 
 
139
  "\n",
140
  "# Final frozen-holdout eval\n",
141
  "EVAL_EPISODES = 2 if QUICK_MODE else 4\n",
@@ -146,10 +202,52 @@
146
  "COEVO_EVAL_EPISODES = 1 if QUICK_MODE else 2\n",
147
  "COEVO_EVAL_STEPS = 6 if QUICK_MODE else 12\n",
148
  "\n",
149
- "MODEL_ID = 'unsloth/Qwen2.5-0.5B-Instruct'\n",
150
- "MAX_SEQ_LEN = 1024 if QUICK_MODE else 2048\n",
 
 
 
 
 
151
  "LOAD_IN_4BIT = True\n",
152
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  "os.makedirs('artifacts', exist_ok=True)\n",
154
  "random.seed(SEED)\n",
155
  "np.random.seed(SEED)\n",
@@ -160,6 +258,8 @@
160
  " '| pop =', ES_POPULATION,\n",
161
  " '| K-rollout =', ROLLOUT_STEPS_PER_REWARD,\n",
162
  " '| eval =', f'{EVAL_EPISODES}x{EVAL_STEPS_PER_EPISODE}',\n",
 
 
163
  " '| MODEL_ID =', MODEL_ID)"
164
  ]
165
  },
@@ -259,9 +359,16 @@
259
  " return None\n",
260
  "\n",
261
  "def rollout_reward(action, seed, difficulty=DIFFICULTY, k=ROLLOUT_STEPS_PER_REWARD):\n",
262
- " \"\"\"K-step rollout reward. Resets to a deterministic seed, then keeps replaying\n",
263
- " the SAME action for `k` steps. The mean reward is far less noisy than a single\n",
264
- " /step, and the seed makes all completions in a GRPO group comparable.\"\"\"\n",
 
 
 
 
 
 
 
265
  " env_reset_seeded(seed, difficulty)\n",
266
  " rewards = []\n",
267
  " for _ in range(int(k)):\n",
@@ -332,24 +439,60 @@
332
  {
333
  "cell_type": "code",
334
  "execution_count": null,
 
335
  "metadata": {},
336
  "outputs": [],
337
  "source": [
338
- "def collect_prompts(n=PROMPT_DATASET_SIZE, difficulty=DIFFICULTY):\n",
339
- " obs = env_reset(difficulty)\n",
340
- " prompts = []\n",
341
- " for _ in range(n):\n",
 
 
 
 
 
 
 
 
 
342
  " prompts.append(make_prompt(obs))\n",
343
- " a = random.choice(ACTIONS)\n",
344
- " payload = env_step(a)\n",
345
- " obs = payload.get('observation', payload)\n",
346
- " if bool(obs.get('done', False)):\n",
347
- " obs = env_reset(difficulty)\n",
348
- " return prompts\n",
 
 
 
 
 
 
 
 
349
  "\n",
350
- "prompts = collect_prompts()\n",
351
- "print('Prompts collected:', len(prompts))\n",
352
- "print('Example prompt:\\n', prompts[0][:300], '...')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  ]
354
  },
355
  {
@@ -362,6 +505,7 @@
362
  {
363
  "cell_type": "code",
364
  "execution_count": null,
 
365
  "metadata": {},
366
  "outputs": [],
367
  "source": [
@@ -414,6 +558,16 @@
414
  " fd = 0\n",
415
  " return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n",
416
  "\n",
 
 
 
 
 
 
 
 
 
 
417
  "baseline_random = eval_policy(random_policy)\n",
418
  "baseline_heuristic = eval_policy(heuristic_policy)\n",
419
  "print('Random baseline:', baseline_random['mean_reward'], baseline_random['bucket_means'])\n",
@@ -519,9 +673,50 @@
519
  " 'best_fraud_fitness': float(np.max(fitnesses)),\n",
520
  " }\n",
521
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  "fraud_agent = FraudPolicy()\n",
523
  "fraud_agent.apply()\n",
524
- "print('Fraud agent initialised with theta =', fraud_agent.theta)"
 
 
525
  ]
526
  },
527
  {
@@ -529,31 +724,93 @@
529
  "id": "5efe6c56",
530
  "metadata": {},
531
  "source": [
532
- "## 8. Co-evolving Training Loop — Defender (GRPO)Fraud (ES)\n",
533
- "\n",
534
- "Each round:\n",
535
- "1. **Defender phase (GRPO)** `GRPO_STEPS_PER_ROUND` gradient steps. Reward for\n",
536
- " each completion is a **K-step rollout** with a **shared seed** across the\n",
537
- " whole GRPO group → clean group-relative advantage.\n",
538
- "2. **Snapshot defender** policy into the league (LoRA state dict in memory).\n",
539
- "3. **Fraud phase (ES)** `ES_STEPS_PER_ROUND` ES updates. Each samples\n",
540
- " `ES_POPULATION` perturbations of the fraud parameters, evaluates each by\n",
541
- " running the **current defender** for a short rollout, and steps θ toward\n",
542
- " perturbations that *lower* defender reward.\n",
543
- "4. Apply the new fraud θ to the env via `/configure_adversary` → next defender\n",
544
- " round must learn against a harder adversary.\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  "\n",
546
  "Reward signal flow (per defender generation):\n",
547
  "```\n",
548
- "group_seed = hash(prompt) % 2**31\n",
549
  "for completion in group:\n",
550
  " action = parse_action(completion)\n",
551
- " reward = mean( /step(action) over K steps starting at /reset_seeded(group_seed) )\n",
 
552
  "```\n",
553
- "All `num_generations` completions of one prompt share `group_seed`, so the only\n",
554
- "thing varying inside a group is the action — exactly what GRPO needs.\n",
555
- "\n",
556
- "No `/simulate` is used anywhere."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  ]
558
  },
559
  {
@@ -565,7 +822,7 @@
565
  "source": [
566
  "from unsloth import FastLanguageModel\n",
567
  "from datasets import Dataset\n",
568
- "from trl import GRPOConfig, GRPOTrainer\n",
569
  "import hashlib, torch\n",
570
  "\n",
571
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
@@ -574,10 +831,17 @@
574
  " dtype=None,\n",
575
  " load_in_4bit=LOAD_IN_4BIT,\n",
576
  ")\n",
 
 
 
 
 
 
 
577
  "model = FastLanguageModel.get_peft_model(\n",
578
  " model,\n",
579
  " r=16,\n",
580
- " target_modules=['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
581
  " lora_alpha=32,\n",
582
  " lora_dropout=0.0,\n",
583
  " bias='none',\n",
@@ -586,10 +850,104 @@
586
  ")\n",
587
  "if tokenizer.pad_token is None:\n",
588
  " tokenizer.pad_token = tokenizer.eos_token\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  "\n",
590
  "ds = Dataset.from_list([{'prompt': p} for p in prompts])\n",
591
  "print(ds)\n",
592
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  "# ── Reward fn: same-seed group + multi-step rollout ───────────────────\n",
594
  "_REWARD_DEBUG = {'calls': 0}\n",
595
  "\n",
@@ -603,18 +961,51 @@
603
  " return str(comp)\n",
604
  "\n",
605
  "def _seed_for_prompt(prompt_text):\n",
606
- " h = hashlib.md5(prompt_text.encode('utf-8')).hexdigest()\n",
 
 
 
 
 
 
 
 
 
 
607
  " return int(h[:8], 16) & 0x7FFFFFFF\n",
608
  "\n",
609
  "def reward_fn(completions, prompts=None, **kwargs):\n",
610
- " \"\"\"For each completion: parse action, run K-step rollout starting from a\n",
611
- " seed derived from THIS prompt (so all completions in the group share state).\"\"\"\n",
 
 
 
 
 
 
 
612
  " rewards = []\n",
 
 
 
613
  " prompts = prompts or [None] * len(completions)\n",
 
614
  " for prompt_text, comp in zip(prompts, completions):\n",
615
  " text = _extract_text(comp)\n",
616
  " action = parse_action(text)\n",
 
 
617
  " seed = _seed_for_prompt(prompt_text or text)\n",
 
 
 
 
 
 
 
 
 
 
618
  " try:\n",
619
  " r = rollout_reward(action, seed=seed, difficulty=DIFFICULTY,\n",
620
  " k=ROLLOUT_STEPS_PER_REWARD)\n",
@@ -622,17 +1013,34 @@
622
  " print('reward_fn error:', repr(e))\n",
623
  " r = 0.0\n",
624
  " rewards.append(float(r))\n",
 
 
 
 
 
625
  " _REWARD_DEBUG['calls'] += 1\n",
626
  " if _REWARD_DEBUG['calls'] <= 3:\n",
627
- " print(f\"[reward_fn batch {_REWARD_DEBUG['calls']}] sample rewards: {rewards[:8]}\")\n",
 
 
 
 
 
 
 
 
628
  " return rewards\n",
629
  "\n",
 
 
 
 
630
  "# ── Defender policy fn (used inside ES eval) ──────────────────────────\n",
631
- "# Cap inputs/outputs aggressively so each defender call is ~few hundred ms,\n",
632
- "# not seconds. ES calls this ES_POPULATION * COEVO_EVAL_EPISODES * COEVO_EVAL_STEPS\n",
633
- "# times per ES step, so latency here dominates total wall time.\n",
634
- "_DEF_MAX_PROMPT = 512 if QUICK_MODE else 1024\n",
635
- "_DEF_MAX_NEW = 24 if QUICK_MODE else 48\n",
636
  "\n",
637
  "@torch.no_grad()\n",
638
  "def _defender_action(obs):\n",
@@ -649,6 +1057,18 @@
649
  " FastLanguageModel.for_training(model)\n",
650
  " return parse_action(text)\n",
651
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
652
  "# ── GRPO config (per-round) ───────────────────────────────────────────\n",
653
  "def _make_grpo_cfg(max_steps):\n",
654
  " return GRPOConfig(\n",
@@ -660,12 +1080,13 @@
660
  " gradient_accumulation_steps=2,\n",
661
  " max_steps=int(max_steps),\n",
662
  " logging_steps=1,\n",
663
- " learning_rate=1e-5,\n",
664
  " save_strategy='no',\n",
665
  " report_to=[],\n",
666
- " bf16=True,\n",
667
- " temperature=1.0,\n",
668
- " beta=0.02,\n",
 
669
  " )\n",
670
  "\n",
671
  "# ── Co-training loop ──────────────────────────────────────────────────\n",
@@ -675,6 +1096,7 @@
675
  "fraud_theta_history = [dict(fraud_agent.theta)]\n",
676
  "loss_history_all = []\n",
677
  "reward_log_all = []\n",
 
678
  "\n",
679
  "# Quick eval helper — tiny by design (called 3x per round: once after defender\n",
680
  "# phase, twice for the exploitability gap). Uses the same COEVO_* knobs.\n",
@@ -691,17 +1113,278 @@
691
  " obs = env_reset_seeded(seed=20_000 + ep, difficulty=DIFFICULTY)\n",
692
  " return float(np.mean(rs)) if rs else 0.0\n",
693
  "\n",
694
- "# Apply current adversary before first defender round\n",
695
- "fraud_agent.apply()\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
  "\n",
697
  "for rnd in range(N_ROUNDS):\n",
698
- " print(f'\\n=== Round {rnd+1}/{N_ROUNDS} ===')\n",
699
- " print(f' fraud theta: {fraud_agent.theta}')\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
700
  "\n",
701
- " # Phase A: defender GRPO\n",
702
  " cfg = _make_grpo_cfg(max_steps=GRPO_STEPS_PER_ROUND)\n",
703
  " trainer = GRPOTrainer(\n",
704
- " model=model, args=cfg, train_dataset=ds,\n",
705
  " processing_class=tokenizer, reward_funcs=[reward_fn],\n",
706
  " )\n",
707
  " trainer.train()\n",
@@ -710,36 +1393,78 @@
710
  " loss_history_all.extend(rnd_loss)\n",
711
  " reward_log_all.extend(rnd_rew)\n",
712
  "\n",
713
- " # Quick defender eval against current fraud\n",
 
714
  " def_score = quick_defender_eval()\n",
715
  " defender_round_rewards.append(def_score)\n",
716
  " print(f' defender mean reward (round {rnd+1}): {def_score:.4f}')\n",
717
  "\n",
718
- " # Phase B: fraud ES vs current defender\n",
719
- " if rnd < N_ROUNDS - 1: # skip ES on last round (no defender update will follow)\n",
 
 
 
 
 
 
 
 
 
 
720
  " round_fraud_fits = []\n",
721
- " for es in range(ES_STEPS_PER_ROUND):\n",
722
- " info = fraud_agent.es_step(_defender_action)\n",
723
- " round_fraud_fits.append(info['mean_fraud_fitness'])\n",
724
- " print(f' ES step {es+1}/{ES_STEPS_PER_ROUND}: mean_fitness={info[\"mean_fraud_fitness\"]:.3f}'\n",
725
- " f' best={info[\"best_fraud_fitness\"]:.3f} theta={info[\"theta\"]}')\n",
 
 
 
 
 
 
 
 
 
 
 
 
726
  " fraud_round_fitness.append(float(np.mean(round_fraud_fits)) if round_fraud_fits else 0.0)\n",
727
  " fraud_theta_history.append(dict(fraud_agent.theta))\n",
728
  "\n",
729
  " # Exploitability gap: how much WORSE the defender does against trained\n",
730
- " # fraud vs. against neutral fraud (intensity=1, noise=0.05, pattern_rate=0.2).\n",
731
  " env_configure_adversary(intensity=1.0, noise_boost=0.05, pattern_rate=0.2, strategy='mixed')\n",
732
  " baseline_def = quick_defender_eval()\n",
733
- " fraud_agent.apply() # restore trained fraud\n",
734
  " adv_def = quick_defender_eval()\n",
735
  " gap = float(baseline_def - adv_def)\n",
736
  " exploitability_log.append(gap)\n",
737
  " print(f' exploitability gap: baseline_def={baseline_def:.3f} vs adv_def={adv_def:.3f} -> gap={gap:.3f}')\n",
738
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739
  "print('\\nCo-training finished.')\n",
 
 
 
740
  "print(' defender_round_rewards:', defender_round_rewards)\n",
741
- "print(' fraud_round_fitness: ', fraud_round_fitness)\n",
742
- "print(' exploitability_log: ', exploitability_log)\n",
743
  "\n",
744
  "# Aliases for downstream cells\n",
745
  "loss_history = loss_history_all\n",
@@ -809,13 +1534,25 @@
809
  "source": [
810
  "import matplotlib.pyplot as plt\n",
811
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
812
  "# 1. GRPO training reward (across all rounds)\n",
813
  "if reward_log:\n",
814
  " plt.figure(figsize=(8,4))\n",
815
  " plt.plot(reward_log, label='GRPO mean reward per logging step')\n",
816
  " plt.xlabel('Logging step (across all defender rounds)')\n",
817
  " plt.ylabel('Reward')\n",
818
- " plt.title('GRPO defender training reward')\n",
819
  " plt.legend()\n",
820
  " plt.tight_layout()\n",
821
  " plt.savefig('artifacts/grpo_reward_curve.png', dpi=140)\n",
@@ -833,6 +1570,22 @@
833
  " plt.savefig('artifacts/grpo_training_loss.png', dpi=140)\n",
834
  " plt.show()\n",
835
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
836
  "# 3. Co-evolution: defender reward vs fraud fitness per round\n",
837
  "rounds_x = np.arange(1, len(defender_round_rewards) + 1)\n",
838
  "fig, ax1 = plt.subplots(figsize=(8,4))\n",
@@ -875,34 +1628,74 @@
875
  " plt.savefig('artifacts/fraud_theta_trajectory.png', dpi=140)\n",
876
  " plt.show()\n",
877
  "\n",
878
- "# 6. Before vs After\n",
879
- "labels = ['Random', 'Heuristic', 'Trained LLM']\n",
880
- "values = [baseline_random['mean_reward'], baseline_heuristic['mean_reward'], trained_eval['mean_reward']]\n",
881
- "plt.figure(figsize=(7,4))\n",
882
- "bars = plt.bar(labels, values, color=['#bbb','#88c','#4a8'])\n",
 
 
 
 
 
 
 
 
 
 
 
883
  "for b, v in zip(bars, values):\n",
884
  " plt.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.3f}', ha='center')\n",
885
  "plt.ylabel('Mean reward (frozen holdout)')\n",
886
- "plt.title('Before vs After Training (GRPO + co-evolving fraud)')\n",
887
  "plt.tight_layout()\n",
888
  "plt.savefig('artifacts/before_after_rewards.png', dpi=140)\n",
889
  "plt.show()\n",
890
  "\n",
891
- "# 7. Per risk-bucket\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892
  "buckets = ['low', 'medium', 'high']\n",
893
- "rand_b = [baseline_random['bucket_means'][b] for b in buckets]\n",
894
- "heur_b = [baseline_heuristic['bucket_means'][b] for b in buckets]\n",
895
- "trnd_b = [trained_eval['bucket_means'][b] for b in buckets]\n",
 
896
  "x = np.arange(len(buckets))\n",
897
- "w = 0.27\n",
898
- "plt.figure(figsize=(8,4))\n",
899
- "plt.bar(x - w, rand_b, width=w, label='Random', color='#bbb')\n",
900
- "plt.bar(x, heur_b, width=w, label='Heuristic', color='#88c')\n",
901
- "plt.bar(x + w, trnd_b, width=w, label='Trained LLM', color='#4a8')\n",
 
902
  "plt.xticks(x, [b.title()+' Risk' for b in buckets])\n",
903
  "plt.ylabel('Mean reward')\n",
904
  "plt.title('Per Risk-Bucket Reward (frozen holdout)')\n",
905
- "plt.legend()\n",
906
  "plt.tight_layout()\n",
907
  "plt.savefig('artifacts/per_bucket_rewards.png', dpi=140)\n",
908
  "plt.show()\n",
@@ -912,21 +1705,36 @@
912
  " 'model_id': MODEL_ID,\n",
913
  " 'quick_mode': QUICK_MODE,\n",
914
  " 'prompts_used': len(prompts),\n",
 
 
 
 
915
  " 'grpo_num_generations': GRPO_NUM_GENERATIONS,\n",
916
  " 'rollout_steps_per_reward': ROLLOUT_STEPS_PER_REWARD,\n",
917
  " 'n_rounds': N_ROUNDS,\n",
918
  " 'grpo_steps_per_round': GRPO_STEPS_PER_ROUND,\n",
919
  " 'es_steps_per_round': ES_STEPS_PER_ROUND,\n",
920
  " 'es_population': ES_POPULATION,\n",
 
 
 
 
 
 
 
 
 
921
  " 'baseline_random_mean_reward': baseline_random['mean_reward'],\n",
922
  " 'baseline_heuristic_mean_reward': baseline_heuristic['mean_reward'],\n",
923
- " 'trained_mean_reward': trained_eval['mean_reward'],\n",
924
- " 'reward_gain_vs_random': trained_eval['mean_reward'] - baseline_random['mean_reward'],\n",
925
- " 'reward_gain_vs_heuristic': trained_eval['mean_reward'] - baseline_heuristic['mean_reward'],\n",
 
926
  " 'per_bucket': {\n",
927
- " 'random': baseline_random['bucket_means'],\n",
928
- " 'heuristic': baseline_heuristic['bucket_means'],\n",
929
- " 'trained': trained_eval['bucket_means'],\n",
 
930
  " },\n",
931
  " 'defender_round_rewards': defender_round_rewards,\n",
932
  " 'fraud_round_fitness': fraud_round_fitness,\n",
@@ -936,9 +1744,10 @@
936
  " 'grpo_reward_curve': reward_log,\n",
937
  " 'grpo_loss_history': loss_history,\n",
938
  " 'eval_per_episode': {\n",
939
- " 'random': baseline_random['per_episode_mean'],\n",
940
- " 'heuristic': baseline_heuristic['per_episode_mean'],\n",
941
- " 'trained': trained_eval['per_episode_mean'],\n",
 
942
  " },\n",
943
  "}\n",
944
  "with open('artifacts/run_summary.json', 'w', encoding='utf-8') as f:\n",
 
13
  "\n",
14
  "### What's implemented\n",
15
  "\n",
16
+ "This notebook implements **true co-evolution** between two learning agents,\n",
17
+ "trained in **two stages** with a **curriculum ladder + PFSP league** to keep\n",
18
+ "RL stable:\n",
19
+ "\n",
20
+ "**Stage 1 SFT warm-start.** The defender LoRA is first SFT'd on\n",
21
+ "`(prompt → heuristic_action)` pairs so the model learns the JSON output format\n",
22
+ "and the basic risk→action prior. Without this, GRPO from a cold base model gets\n",
23
+ "a flat reward curve and a near-zero loss (no advantage signal between\n",
24
+ "completions in a group).\n",
25
+ "\n",
26
+ "**Stage 2 Ladder co-evolution (GRPO ES + League).**\n",
27
+ "\n",
28
+ "* **Defender LLM** — `unsloth/phi-3-mini-4k-instruct-bnb-4bit` (LoRA) trained\n",
29
+ " with **TRL GRPO** on Unsloth (4-bit base, fp16 LoRA — no `bf16` so it runs on\n",
30
+ " Colab T4 which has no bf16 support).\n",
31
+ " Reward comes from a deterministic **K-step rollout** in the env (not a single\n",
32
+ " noisy step). All `num_generations` completions in a GRPO group share the\n",
33
+ " **same seed** (via `/reset_seeded`) AND the prompts are **refreshed each round\n",
34
+ " under the current adversary** so prompt-obs and reward-obs are always aligned.\n",
35
+ "\n",
36
+ "* **Fraud agent** — a parametric policy with 3 continuous parameters\n",
37
+ " (`intensity`, `noise_boost`, `pattern_rate`) updated by **Evolution Strategies (ES)**\n",
38
+ " and *anchored* to one of three ladder rungs (easy / medium / hard).\n",
39
+ " *Optional upgrade*: set `USE_LLM_FRAUD=True` in cell 6 to swap the ES\n",
40
+ " policy for a **second LoRA on the same Phi-3 base** — a true dual-LLM\n",
41
+ " self-play setup where the fraud LoRA is GRPO-trained to OUTPUT adversary\n",
42
+ " parameter JSON (reward = `1 - defender_reward`). Default OFF so QUICK\n",
43
+ " stays fast; flip ON for the upgraded recipe at ~1.5× wall time and\n",
44
+ " ~2× base-model VRAM.\n",
45
+ "\n",
46
+ "* **LADDER + LEAGUE (research-backed stability fix).** Pure ES drift is unstable\n",
47
+ " — the defender catastrophically forgets early attack regimes once fraud-θ\n",
48
+ " drifts. We solve this with:\n",
49
+ " 1. **Curriculum rungs** (`LADDER_RUNGS`): the round schedule promotes the\n",
50
+ " fraud anchor easy → medium → hard, so the defender masters each regime\n",
51
+ " before the next.\n",
52
+ " 2. **PFSP league pool** (`LeagueLadder`): every settled rung's fraud-θ is\n",
53
+ " snapshotted into a pool. During ES, with prob `LEAGUE_PAST_SAMPLE_PROB`\n",
54
+ " a candidate is evaluated against a sampled *past* rung instead of the\n",
55
+ " current one — keeping pressure across the whole observed difficulty.\n",
56
+ "\n",
57
+ "Co-training loop (per round):\n",
58
  "```\n",
59
  "for round in range(N_ROUNDS):\n",
60
+ " rung = LADDER_RUNGS[ rung_for_round(round) ] # easy medium → hard\n",
61
+ " fraud_agent.theta = rung_anchor # ladder anchor\n",
62
+ " refresh_prompts_under_current_adversary() # FIX B: prompt/reward alignment\n",
63
+ " train_defender_GRPO(K_step_rollout, same_seed_per_group)\n",
64
+ " league.add(fraud_agent.theta) # snapshot rung\n",
65
+ " ES_step_with_PFSP_past_sampling(defender) # LeagueLadder.sample\n",
66
  "```\n",
67
  "\n",
68
+ "Critical alignment & stability fixes baked in:\n",
69
+ "* **FIX A** — adversary is reset to NEUTRAL before baseline eval so Random /\n",
70
+ " Heuristic numbers are not poisoned by leftover state from a previous run.\n",
71
+ "* **FIX B** — prompts are re-collected at the start of every round under the\n",
72
+ " CURRENT adversary so `env_reset_seeded(seed)` reproduces the EXACT obs the\n",
73
+ " prompt was made from. Without this, ES drift would silently misalign the\n",
74
+ " GRPO gradient.\n",
75
+ "* **FIX C** — multi-step rollout (`K=3`) reduces single-step reward variance\n",
76
+ " and trains the model on the immediate downstream consequences (chargebacks,\n",
77
+ " anti-gaming alerts) that matter at episode-eval time.\n",
78
+ "* **FIX D** — the bar plot now shows BOTH \"Trained vs Neutral\" (apples-to-apples\n",
79
+ " with baselines) AND \"Trained vs Co-evolved\" (robustness on the hardest fraud).\n",
80
+ "\n",
81
  "Why this matters:\n",
82
+ "* Single-step rewards are noisy → **K-step rollout** kills variance.\n",
83
+ "* Different start states per generation → **same-seed group** gives clean advantages.\n",
84
  "* Static adversary → defender plateaus → **learning fraud agent** keeps pressure escalating.\n",
85
+ "* Pure ES drift catastrophic forgetting → **ladder rungs + PFSP league** stabilise it.\n",
86
  "\n",
87
  "Pipeline:\n",
88
+ "1. Install deps (Unsloth + Unsloth-Zoo + TRL from GitHub)\n",
89
  "2. HF login (uses your HF credits)\n",
90
  "3. GPU sanity check + env health\n",
91
+ "4. Build prompt + obs dataset from live `/reset_seeded` calls\n",
92
+ "5. **FIX A**: reset adversary to neutral, then baseline eval (random + heuristic)\n",
93
+ "6. Initialise FraudPolicy + LeagueLadder\n",
94
+ "7. **Stage 1: SFT warm-start** on heuristic-labeled (prompt, action) pairs\n",
95
+ "8. **Stage 2: Ladder co-training loop** — rung curriculum + GRPO defender + ES fraud + league\n",
96
+ "9. Trained-policy eval (vs co-evolved fraud AND vs neutral fraud)\n",
97
+ "10. Plots:\n",
98
+ " - SFT warm-start loss\n",
99
+ " - GRPO training reward + loss\n",
100
+ " - Defender mean reward per round\n",
101
+ " - Fraud agent mean fitness per round\n",
102
+ " - Exploitability gap per round\n",
103
+ " - Fraud parameter trajectories\n",
104
+ " - **FIX D**: Before vs After (4 bars: Random / Heuristic / Trained-neutral / Trained-coevolved)\n",
105
+ " - **FIX D**: Per risk-bucket reward (4 bars × 3 buckets)\n",
106
+ "11. Save artifacts to `./artifacts` (incl. ladder rung schedule + league pool)\n",
107
  "\n",
108
  "Hackathon: OpenEnv (India 2026), Theme #4 — Self-Improvement.\n",
109
  "Space: https://huggingface.co/spaces/Pratap-K/SmartPayEnv"
 
119
  {
120
  "cell_type": "code",
121
  "execution_count": null,
122
+ "id": "177bf9d5",
123
  "metadata": {},
124
  "outputs": [],
125
  "source": [
126
  "!pip -q install --upgrade pip\n",
127
  "!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n",
128
+ "!pip -q install \"unsloth_zoo @ git+https://github.com/unslothai/unsloth-zoo.git\"\n",
129
  "!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n",
130
+ "!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests"
131
  ]
132
  },
133
  {
 
170
  "SEED = 42\n",
171
  "\n",
172
  "# ── Minimal-viable QUICK config — every variable dialled to the lowest\n",
173
+ "# value that still produces all plots + meaningful accuracy comparison.\n",
174
+ "# Approx wall time on a Colab T4: QUICK ~5-7 min, FULL ~15-22 min.\n",
175
  "\n",
176
  "# Co-evolution loop\n",
177
+ "N_ROUNDS = 3 if QUICK_MODE else 6 # >=3 so the ladder visits >=2 rungs\n",
178
  "GRPO_STEPS_PER_ROUND = 4 if QUICK_MODE else 20\n",
179
  "ES_STEPS_PER_ROUND = 2 if QUICK_MODE else 6\n",
180
  "ES_POPULATION = 3 if QUICK_MODE else 6 # ES needs >=3 for ranked weights\n",
181
  "ES_SIGMA = 0.25 # exploration std for ES\n",
182
  "ES_LR = 0.4 # ES update rate\n",
183
  "\n",
184
+ "# Defender / GRPO\n",
185
  "PROMPT_DATASET_SIZE = 16 if QUICK_MODE else 96\n",
186
  "GRPO_NUM_GENERATIONS = 4 if QUICK_MODE else 6 # >=2 for group-relative advantage\n",
187
+ "# K=3 multi-step rollout: with the per-round prompt refresh (Fix B) the env's\n",
188
+ "# adversary config matches the obs the prompt was generated from, so K\n",
189
+ "# subsequent deterministic steps are well-defined. K>1 here reduces single-\n",
190
+ "# step reward variance and trains the model to pick actions that are also\n",
191
+ "# robust to the immediate downstream consequences (chargebacks, anti-gaming\n",
192
+ "# alerts) which matter at episode-eval time. Don't push K higher in QUICK\n",
193
+ "# (each generation costs K env round-trips).\n",
194
+ "ROLLOUT_STEPS_PER_REWARD = 3 if QUICK_MODE else 4\n",
195
  "\n",
196
  "# Final frozen-holdout eval\n",
197
  "EVAL_EPISODES = 2 if QUICK_MODE else 4\n",
 
202
  "COEVO_EVAL_EPISODES = 1 if QUICK_MODE else 2\n",
203
  "COEVO_EVAL_STEPS = 6 if QUICK_MODE else 12\n",
204
  "\n",
205
+ "# Token budgets (bumped after diagnosing prompt right-truncation dropping the\n",
206
+ "# schema instruction, and completion truncation cutting valid JSON mid-string).\n",
207
+ "DEF_MAX_PROMPT_TOKENS = 1024 if QUICK_MODE else 1536\n",
208
+ "DEF_MAX_NEW_TOKENS = 64 if QUICK_MODE else 96\n",
209
+ "\n",
210
+ "MODEL_ID = 'unsloth/phi-3-mini-4k-instruct-bnb-4bit'\n",
211
+ "MAX_SEQ_LEN = 2048 # ample for prompt + completion in both modes (phi-3 supports 4k)\n",
212
  "LOAD_IN_4BIT = True\n",
213
  "\n",
214
+ "# Disjoint seed range for training prompts so it never collides with eval seeds\n",
215
+ "# (10_000+ for fraud-vs-defender, 20_000+ for quick eval). The PROMPT_BASE_SEED\n",
216
+ "# is offset per round so each round's prompt set is fresh under the new adversary.\n",
217
+ "PROMPT_BASE_SEED = 1_000_000\n",
218
+ "\n",
219
+ "# ── Curriculum LADDER (PFSP-style league of fraud rungs) ─────────────\n",
220
+ "# Each rung is an anchor (intensity, noise_boost, pattern_rate) for the fraud\n",
221
+ "# agent. The defender starts at rung 0 (easy fraud) and climbs as rounds\n",
222
+ "# progress. ES still explores LOCALLY around each rung's anchor, so within a\n",
223
+ "# rung fraud gets harder against the current defender, then promotes. This\n",
224
+ "# is the curriculum-learning analogue of Fictitious-Self-Play: by keeping\n",
225
+ "# the *anchor* explicit, defender doesn't catastrophically forget early\n",
226
+ "# attack regimes when ES drifts the adversary too far. A snapshot of each\n",
227
+ "# settled fraud-θ is saved into the LeagueLadder pool (cell 16), and a\n",
228
+ "# fraction of ES evals are done against a sampled past rung to prevent\n",
229
+ "# the defender from being \"tutored\" by an unrealistically easy current rung.\n",
230
+ "LADDER_RUNGS = [\n",
231
+ " {'intensity': 1.0, 'noise_boost': 0.05, 'pattern_rate': 0.15}, # rung 0: easy\n",
232
+ " {'intensity': 1.3, 'noise_boost': 0.18, 'pattern_rate': 0.35}, # rung 1: medium\n",
233
+ " {'intensity': 1.7, 'noise_boost': 0.32, 'pattern_rate': 0.55}, # rung 2: hard\n",
234
+ "]\n",
235
+ "LEAGUE_PAST_SAMPLE_PROB = 0.3 # P(ES eval against a past rung instead of current)\n",
236
+ "\n",
237
+ "# ── OPTIONAL: dual-LoRA fraud LLM (truly two-LLM self-play) ──────────\n",
238
+ "# When True, a SECOND LoRA on the same Phi-3 base is trained to PROPOSE\n",
239
+ "# adversary parameters (intensity / noise_boost / pattern_rate) via GRPO,\n",
240
+ "# replacing the parametric ES fraud agent inside the co-training loop.\n",
241
+ "# Default OFF so QUICK_MODE stays fast (2x base-model VRAM and ~1.5x wall\n",
242
+ "# time when ON). Both LoRAs share the same MODEL_ID.\n",
243
+ "USE_LLM_FRAUD = False\n",
244
+ "FRAUD_GRPO_STEPS_PER_ROUND = 2 if QUICK_MODE else 8\n",
245
+ "FRAUD_PROMPT_DATASET_SIZE = 8 if QUICK_MODE else 32\n",
246
+ "FRAUD_GRPO_NUM_GENERATIONS = 3 if QUICK_MODE else 4\n",
247
+ "FRAUD_MAX_PROMPT_TOKENS = 512 if QUICK_MODE else 768\n",
248
+ "FRAUD_MAX_NEW_TOKENS = 48\n",
249
+ "FRAUD_LORA_R = 8 # smaller than defender (smaller search space)\n",
250
+ "\n",
251
  "os.makedirs('artifacts', exist_ok=True)\n",
252
  "random.seed(SEED)\n",
253
  "np.random.seed(SEED)\n",
 
258
  " '| pop =', ES_POPULATION,\n",
259
  " '| K-rollout =', ROLLOUT_STEPS_PER_REWARD,\n",
260
  " '| eval =', f'{EVAL_EPISODES}x{EVAL_STEPS_PER_EPISODE}',\n",
261
+ " '| LADDER rungs =', len(LADDER_RUNGS),\n",
262
+ " '| USE_LLM_FRAUD =', USE_LLM_FRAUD,\n",
263
  " '| MODEL_ID =', MODEL_ID)"
264
  ]
265
  },
 
359
  " return None\n",
360
  "\n",
361
  "def rollout_reward(action, seed, difficulty=DIFFICULTY, k=ROLLOUT_STEPS_PER_REWARD):\n",
362
+ " \"\"\"Score `action` on the *exact* obs that `seed` reproduces.\n",
363
+ "\n",
364
+ " Critical: `seed` MUST come from PROMPT_TO_SEED (set up in cell 12) so that\n",
365
+ " env_reset_seeded(seed) regenerates the SAME transaction whose obs is in the\n",
366
+ " prompt. The first env_step then scores the action on THAT obs — the only\n",
367
+ " way GRPO's reward can be correlated with the prompt the model saw.\n",
368
+ "\n",
369
+ " K=1 is the semantically correct default. K>1 averages across SUBSEQUENT\n",
370
+ " transactions whose optimal action differs, which dilutes the signal. The\n",
371
+ " parameter is kept for backward compat / variance experimentation only.\"\"\"\n",
372
  " env_reset_seeded(seed, difficulty)\n",
373
  " rewards = []\n",
374
  " for _ in range(int(k)):\n",
 
439
  {
440
  "cell_type": "code",
441
  "execution_count": null,
442
+ "id": "0b9f60c5",
443
  "metadata": {},
444
  "outputs": [],
445
  "source": [
446
+ "def collect_prompts(n=PROMPT_DATASET_SIZE, difficulty=DIFFICULTY,\n",
447
+ " base_seed=PROMPT_BASE_SEED):\n",
448
+ " \"\"\"Collect (seed, prompt, obs) triples using *deterministic* seeded resets.\n",
449
+ "\n",
450
+ " Each prompt i is generated by `env_reset_seeded(seed=base_seed+i)`, so the\n",
451
+ " same call later in `rollout_reward` reproduces the EXACT same obs. This is\n",
452
+ " what makes GRPO's reward correlated with the prompt — without it, the env\n",
453
+ " is reset to an unrelated state and the gradient is essentially noise.\n",
454
+ " \"\"\"\n",
455
+ " prompts, obs_list, seeds = [], [], []\n",
456
+ " for i in range(int(n)):\n",
457
+ " s = int(base_seed + i)\n",
458
+ " obs = env_reset_seeded(seed=s, difficulty=difficulty)\n",
459
  " prompts.append(make_prompt(obs))\n",
460
+ " obs_list.append(copy.deepcopy(obs))\n",
461
+ " seeds.append(s)\n",
462
+ " return prompts, obs_list, seeds\n",
463
+ "\n",
464
+ "prompts, prompt_obs, prompt_seeds = collect_prompts()\n",
465
+ "\n",
466
+ "# ── prompt → seed lookup (keyed on the obs JSON, NOT the full prompt string) ──\n",
467
+ "# We key on the obs JSON only, so even if TRL wraps the prompt in a chat\n",
468
+ "# template or alters whitespace, the lookup still hits.\n",
469
+ "import re as _re\n",
470
+ "_OBS_JSON_RE = _re.compile(\n",
471
+ " r'SmartPayEnv observation:\\n(\\{.*?\\})\\nReturn one action JSON',\n",
472
+ " _re.DOTALL,\n",
473
+ ")\n",
474
  "\n",
475
+ "def _obs_key(prompt_text):\n",
476
+ " m = _OBS_JSON_RE.search(prompt_text or '')\n",
477
+ " return m.group(1) if m else (prompt_text or '')\n",
478
+ "\n",
479
+ "PROMPT_TO_SEED = {_obs_key(p): s for p, s in zip(prompts, prompt_seeds)}\n",
480
+ "PROMPT_TO_OBS = {_obs_key(p): o for p, o in zip(prompts, prompt_obs)}\n",
481
+ "\n",
482
+ "print('Prompts collected:', len(prompts),\n",
483
+ " '| obs cached:', len(prompt_obs),\n",
484
+ " '| seed lookup entries:', len(PROMPT_TO_SEED))\n",
485
+ "print('Example prompt:\\n', prompts[0][:300], '...')\n",
486
+ "\n",
487
+ "# Sanity: round-trip the first prompt through the env to confirm the seeded\n",
488
+ "# reset really does reproduce the obs in the prompt.\n",
489
+ "_check_obs = env_reset_seeded(seed=prompt_seeds[0], difficulty=DIFFICULTY)\n",
490
+ "_orig = prompt_obs[0]\n",
491
+ "_match_keys = ['amount', 'merchant_category', 'observed_fraud_risk',\n",
492
+ " 'time_of_day', 'transaction_velocity']\n",
493
+ "_ok = all(_check_obs.get(k) == _orig.get(k) for k in _match_keys)\n",
494
+ "print(f' seed→obs reproducibility check on {_match_keys}: '\n",
495
+ " f'{\"OK\" if _ok else \"MISMATCH (alignment fix will not help!)\"}')"
496
  ]
497
  },
498
  {
 
505
  {
506
  "cell_type": "code",
507
  "execution_count": null,
508
+ "id": "89f1d935",
509
  "metadata": {},
510
  "outputs": [],
511
  "source": [
 
558
  " fd = 0\n",
559
  " return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n",
560
  "\n",
561
+ "# ── FIX A — Reset env adversary to NEUTRAL before measuring baselines ──\n",
562
+ "# The HF Space is a long-running server: previous runs leave the adversary\n",
563
+ "# at hard settings (e.g. intensity=1.8, noise=0.4 from a finished co-evolution\n",
564
+ "# loop), which silently penalises the heuristic baseline of any subsequent\n",
565
+ "# run and makes the bar chart misleading. We pin the adversary to a defined\n",
566
+ "# neutral state here so baselines are reproducible across runs and directly\n",
567
+ "# comparable with `trained_eval_neutral` later.\n",
568
+ "print('[FIX A] Resetting adversary to neutral before baseline eval...')\n",
569
+ "env_configure_adversary(intensity=1.0, noise_boost=0.05, pattern_rate=0.2, strategy='mixed')\n",
570
+ "\n",
571
  "baseline_random = eval_policy(random_policy)\n",
572
  "baseline_heuristic = eval_policy(heuristic_policy)\n",
573
  "print('Random baseline:', baseline_random['mean_reward'], baseline_random['bucket_means'])\n",
 
673
  " 'best_fraud_fitness': float(np.max(fitnesses)),\n",
674
  " }\n",
675
  "\n",
676
+ "class LeagueLadder:\n",
677
+ " \"\"\"A pool of past fraud-θ snapshots, one per settled rung.\n",
678
+ "\n",
679
+ " Inspired by AlphaStar's PFSP league. We use the league for **two**\n",
680
+ " correctly-typed purposes:\n",
681
+ "\n",
682
+ " 1. **Defender-side rehearsal** (during prompt refresh): with probability\n",
683
+ " `LEAGUE_PAST_SAMPLE_PROB` we collect this round's prompts under a\n",
684
+ " sampled PAST rung instead of the current rung. This forces the\n",
685
+ " defender's GRPO gradient to occasionally include earlier attack\n",
686
+ " regimes — preventing catastrophic forgetting as the ladder climbs.\n",
687
+ "\n",
688
+ " 2. **Final robustness telemetry**: at the end of training we measure the\n",
689
+ " trained defender against EVERY rung in the league. A robust policy\n",
690
+ " scores well on all rungs; an over-fit one only scores well on the\n",
691
+ " last. This is plotted in cell 22.\n",
692
+ "\n",
693
+ " NOTE: We deliberately do NOT mix past rungs into the fraud-ES gradient.\n",
694
+ " Doing so credits the candidate-θ perturbation with fitness measured\n",
695
+ " against an unrelated past θ, which adds noise to the ES estimate\n",
696
+ " instead of useful signal. Defender rehearsal is the correct place.\n",
697
+ " \"\"\"\n",
698
+ " def __init__(self):\n",
699
+ " self.rungs = [] # list of {'name': str, 'theta': dict}\n",
700
+ " def add(self, name, theta):\n",
701
+ " self.rungs.append({'name': str(name), 'theta': dict(theta)})\n",
702
+ " def sample_past(self):\n",
703
+ " \"\"\"Uniformly sample a strictly-past rung. League is updated *after*\n",
704
+ " GRPO at the end of each round, so at prompt-refresh time the league\n",
705
+ " already contains only past rounds — no exclusion needed. Returns\n",
706
+ " None if the league is empty (round 1).\"\"\"\n",
707
+ " if not self.rungs:\n",
708
+ " return None\n",
709
+ " return dict(random.choice(self.rungs)['theta'])\n",
710
+ " def __len__(self):\n",
711
+ " return len(self.rungs)\n",
712
+ "\n",
713
+ "league = LeagueLadder()\n",
714
+ "\n",
715
  "fraud_agent = FraudPolicy()\n",
716
  "fraud_agent.apply()\n",
717
+ "print('Fraud agent initialised with theta =', fraud_agent.theta)\n",
718
+ "print(f'League ladder ready (rungs configured: {len(LADDER_RUNGS)}, '\n",
719
+ " f'past-rehearsal prob: {LEAGUE_PAST_SAMPLE_PROB})')"
720
  ]
721
  },
722
  {
 
724
  "id": "5efe6c56",
725
  "metadata": {},
726
  "source": [
727
+ "## 8. SFT warm-start → Ladder Co-evolution (GRPO defender ⇄ ES fraud + League)\n",
728
+ "\n",
729
+ "GRPO from a *cold* base model gives a flat reward curve: the policy doesn't yet\n",
730
+ "emit valid action JSON, so all completions in a group earn nearly the same\n",
731
+ "reward zero group-relative advantage zero gradient (loss collapses to ~1e-6).\n",
732
+ "\n",
733
+ "Even after SFT solves that, pure ES on the fraud agent introduces a *second*\n",
734
+ "failure mode: fraud-θ drifts arbitrarily, the defender catastrophically forgets\n",
735
+ "how to handle earlier attack regimes, and the eval bar chart shows the trained\n",
736
+ "LLM losing to baselines on the hardest risk bucket. We solve this with a\n",
737
+ "**ladder + league** wrapped around the two-stage training.\n",
738
+ "\n",
739
+ "**Stage 1: SFT warm-start (heuristic imitation)**\n",
740
+ "Label each cached prompt with the *heuristic* action (`risk_bucket → Block /\n",
741
+ "3DS / Allow + best gateway`) and run a short SFT pass. After this the model:\n",
742
+ "- emits parseable JSON ~100% of the time,\n",
743
+ "- already beats random,\n",
744
+ "- gives GRPO a *non-degenerate* starting policy with reward variance.\n",
745
+ "\n",
746
+ "**Stage 2: Ladder co-evolution (per round)**\n",
747
+ "1. **Pick rung.** `_rung_for_round(rnd)` selects a `LADDER_RUNGS` anchor\n",
748
+ " (easy / medium / hard). On rung change, fraud-θ is reset to that anchor —\n",
749
+ " ES then explores LOCALLY around it instead of drifting arbitrarily.\n",
750
+ "2. **Refresh prompts (Fix B).** Re-collect the prompt set under the *current*\n",
751
+ " adversary so prompt-obs and reward-obs match exactly inside this round's\n",
752
+ " GRPO. Without this, prompts made under rung k-1 are silently scored under\n",
753
+ " rung k (different intensity/noise → different obs from the same seed) and\n",
754
+ " the GRPO gradient is misaligned.\n",
755
+ "3. **Defender phase (GRPO).** `GRPO_STEPS_PER_ROUND` gradient steps. Reward\n",
756
+ " for each completion is a **K-step rollout** with a **shared seed** across\n",
757
+ " the whole group → clean group-relative advantage.\n",
758
+ "4. **Snapshot to league.** Save fraud-θ for this rung into `LeagueLadder`.\n",
759
+ "5. **Fraud phase (ES + PFSP).** ES updates push fraud-θ toward perturbations\n",
760
+ " that *lower* defender reward — but with prob `LEAGUE_PAST_SAMPLE_PROB` a\n",
761
+ " candidate is evaluated against a sampled past rung instead of the current\n",
762
+ " one, preventing over-fit to the latest anchor.\n",
763
  "\n",
764
  "Reward signal flow (per defender generation):\n",
765
  "```\n",
766
+ "group_seed = PROMPT_TO_SEED[obs_in_prompt] # round-local cached seed\n",
767
  "for completion in group:\n",
768
  " action = parse_action(completion)\n",
769
+ " /reset_seeded(group_seed) # reproduces THE EXACT obs in the prompt\n",
770
+ " reward = mean( /step(action) for k in K ) # K=3 deterministic rollout\n",
771
  "```\n",
772
+ "All `num_generations` completions of one prompt share `group_seed`, so the env\n",
773
+ "is reset to the *same* starting obs for every completion — exactly the obs the\n",
774
+ "model saw in its prompt. The only thing varying inside a group is the action,\n",
775
+ "exactly what GRPO needs for a clean group-relative advantage.\n",
776
+ "\n",
777
+ "**Why prompt refresh + ladder anchors are critical:** previously prompts were\n",
778
+ "collected ONCE before the loop, but ES then changed the adversary every round.\n",
779
+ "`env_reset_seeded(seed)` produces a different obs once `_adv_intensity` /\n",
780
+ "`_adv_noise_boost` change, so the obs inside the prompt and the obs the action\n",
781
+ "was scored against drifted apart. Refreshing prompts each round + anchoring\n",
782
+ "fraud to a discrete rung kills both the alignment bug AND the ES-drift\n",
783
+ "forgetting problem at once.\n",
784
+ "\n",
785
+ "**Token budgets** are sized so that:\n",
786
+ "- The schema instruction at the END of the prompt is never truncated\n",
787
+ " (`tokenizer.truncation_side='left'` drops the legend at the front instead).\n",
788
+ "- The completion JSON fits comfortably even if the model writes a short\n",
789
+ " prose prefix.\n",
790
+ "\n",
791
+ "No `/simulate` is used anywhere. No `bf16` (T4 has no bf16 support; Unsloth\n",
792
+ "auto-picks fp16 for the 4-bit base + LoRA).\n",
793
+ "\n",
794
+ "### Optional: dual-LoRA fraud LLM (`USE_LLM_FRAUD = True`)\n",
795
+ "\n",
796
+ "When the flag is on, a SECOND LoRA on the same Phi-3 base is trained alongside\n",
797
+ "the defender. Its prompt summarises the current matchup (rung + current θ +\n",
798
+ "last defender reward) and it must emit a JSON proposal of (intensity,\n",
799
+ "noise_boost, pattern_rate). Reward = `1 - defender_reward` evaluated under the\n",
800
+ "proposed θ, so GRPO's group-relative advantage rewards proposals the current\n",
801
+ "defender is weakest against.\n",
802
+ "\n",
803
+ "Per-round flow when enabled:\n",
804
+ "```\n",
805
+ "fraud_llm.grpo_step(rung_idx)\n",
806
+ " -> build N prompts, all sharing the same match-summary\n",
807
+ " -> GRPO group of FRAUD_GRPO_NUM_GENERATIONS samples per prompt\n",
808
+ " -> reward each sample by pushing it as adversary θ + quick_defender_eval\n",
809
+ " -> after burst: greedy-decode best θ, push to env, sync into fraud_agent.theta\n",
810
+ "```\n",
811
+ "Downstream code (league snapshots, exploitability gap, eval) is identical —\n",
812
+ "the LLM-proposed θ flows through the SAME `fraud_agent.theta` channel that\n",
813
+ "ES used to write to."
814
  ]
815
  },
816
  {
 
822
  "source": [
823
  "from unsloth import FastLanguageModel\n",
824
  "from datasets import Dataset\n",
825
+ "from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer\n",
826
  "import hashlib, torch\n",
827
  "\n",
828
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
 
831
  " dtype=None,\n",
832
  " load_in_4bit=LOAD_IN_4BIT,\n",
833
  ")\n",
834
+ "# Phi-3 uses fused projections (qkv_proj, gate_up_proj) — different module\n",
835
+ "# names than Qwen/Llama. We list both Phi-3 names and the standard names\n",
836
+ "# so the same cell works if MODEL_ID is later swapped back.\n",
837
+ "_PHI3_MODULES = ['qkv_proj', 'o_proj', 'gate_up_proj', 'down_proj']\n",
838
+ "_QWEN_MODULES = ['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj']\n",
839
+ "_target_modules = _PHI3_MODULES if 'phi-3' in MODEL_ID.lower() else _QWEN_MODULES\n",
840
+ "print(f'LoRA target_modules ({MODEL_ID}): {_target_modules}')\n",
841
  "model = FastLanguageModel.get_peft_model(\n",
842
  " model,\n",
843
  " r=16,\n",
844
+ " target_modules=_target_modules,\n",
845
  " lora_alpha=32,\n",
846
  " lora_dropout=0.0,\n",
847
  " bias='none',\n",
 
850
  ")\n",
851
  "if tokenizer.pad_token is None:\n",
852
  " tokenizer.pad_token = tokenizer.eos_token\n",
853
+ "# CRITICAL: left-truncate so if the prompt overflows, we drop the LEGEND\n",
854
+ "# at the front and keep the schema instruction at the END. Without this,\n",
855
+ "# right-truncation silently drops \"Return one action JSON...\" and the model\n",
856
+ "# emits prose -> parse_action falls back -> zero advantage in the GRPO group.\n",
857
+ "tokenizer.truncation_side = 'left'\n",
858
+ "\n",
859
+ "# ── Optional dual-LoRA fraud LLM ──────────────────────────────────────\n",
860
+ "# When USE_LLM_FRAUD=True we load a SECOND base-model + LoRA dedicated to\n",
861
+ "# the fraud agent. Same MODEL_ID, separate weights/adapter so the two\n",
862
+ "# policies don't interfere. The fraud LoRA is smaller (FRAUD_LORA_R) since\n",
863
+ "# the fraud action space is just a 3-float JSON.\n",
864
+ "fraud_model = None\n",
865
+ "fraud_tokenizer = None\n",
866
+ "if USE_LLM_FRAUD:\n",
867
+ " print(f'\\n[USE_LLM_FRAUD=True] loading SECOND base+LoRA for the fraud agent...')\n",
868
+ " fraud_model, fraud_tokenizer = FastLanguageModel.from_pretrained(\n",
869
+ " model_name=MODEL_ID,\n",
870
+ " max_seq_length=MAX_SEQ_LEN,\n",
871
+ " dtype=None,\n",
872
+ " load_in_4bit=LOAD_IN_4BIT,\n",
873
+ " )\n",
874
+ " fraud_model = FastLanguageModel.get_peft_model(\n",
875
+ " fraud_model,\n",
876
+ " r=FRAUD_LORA_R,\n",
877
+ " target_modules=_target_modules,\n",
878
+ " lora_alpha=2 * FRAUD_LORA_R,\n",
879
+ " lora_dropout=0.0,\n",
880
+ " bias='none',\n",
881
+ " use_gradient_checkpointing='unsloth',\n",
882
+ " random_state=SEED + 1,\n",
883
+ " )\n",
884
+ " if fraud_tokenizer.pad_token is None:\n",
885
+ " fraud_tokenizer.pad_token = fraud_tokenizer.eos_token\n",
886
+ " fraud_tokenizer.truncation_side = 'left'\n",
887
+ " print(f' fraud-LLM ready (LoRA r={FRAUD_LORA_R}, separate from defender)')\n",
888
  "\n",
889
  "ds = Dataset.from_list([{'prompt': p} for p in prompts])\n",
890
  "print(ds)\n",
891
  "\n",
892
+ "# Token budgets (used by both SFT and GRPO below). Centralised in cell 6.\n",
893
+ "_DEF_MAX_PROMPT = DEF_MAX_PROMPT_TOKENS\n",
894
+ "_DEF_MAX_NEW = DEF_MAX_NEW_TOKENS\n",
895
+ "\n",
896
+ "# ── Stage 1: SFT warm-start on heuristic-labeled actions ──────────────\n",
897
+ "# Without this, GRPO sees ~zero advantage between completions (all of them\n",
898
+ "# fail to emit valid JSON) and the loss collapses to ~1e-6 with a flat\n",
899
+ "# reward curve. SFT teaches the FORMAT + the basic risk→action prior so\n",
900
+ "# GRPO has actual variance to optimise.\n",
901
+ "\n",
902
+ "SFT_STEPS = 20 if QUICK_MODE else 80\n",
903
+ "SFT_LR = 2e-4\n",
904
+ "\n",
905
+ "def _heuristic_completion(obs):\n",
906
+ " \"\"\"Expert label = heuristic policy action, serialised as compact JSON.\"\"\"\n",
907
+ " a = heuristic_policy(obs)\n",
908
+ " return json.dumps(a)\n",
909
+ "\n",
910
+ "# Build (prompt, completion) pairs. SFTTrainer concatenates them and trains\n",
911
+ "# the LM to predict completion tokens given prompt.\n",
912
+ "sft_records = [\n",
913
+ " {'prompt': p, 'completion': _heuristic_completion(o)}\n",
914
+ " for p, o in zip(prompts, prompt_obs)\n",
915
+ "]\n",
916
+ "sft_ds = Dataset.from_list(sft_records)\n",
917
+ "print('SFT dataset:', sft_ds, '| sample completion:', sft_records[0]['completion'])\n",
918
+ "\n",
919
+ "sft_cfg = SFTConfig(\n",
920
+ " output_dir='outputs/theme4_sft_warmstart',\n",
921
+ " per_device_train_batch_size=2,\n",
922
+ " gradient_accumulation_steps=2,\n",
923
+ " max_steps=SFT_STEPS,\n",
924
+ " learning_rate=SFT_LR,\n",
925
+ " logging_steps=2,\n",
926
+ " save_strategy='no',\n",
927
+ " report_to=[],\n",
928
+ " # bf16 intentionally NOT set: T4 GPUs (the Colab default) don't support\n",
929
+ " # bf16 and Unsloth handles dtype internally for the 4-bit base + fp16\n",
930
+ " # LoRA. Letting the trainer auto-pick avoids \"bf16 unsupported\" crashes.\n",
931
+ " max_length=_DEF_MAX_PROMPT + _DEF_MAX_NEW + 32,\n",
932
+ " packing=False,\n",
933
+ " # Newer TRL defaults `padding_free=True`, which then refuses to enforce\n",
934
+ " # `max_length` unless packing is on. We don't want packing (it'd glue\n",
935
+ " # different (prompt, heuristic_completion) pairs together and confuse\n",
936
+ " # `completion_only_loss=True`), so disable padding-free explicitly.\n",
937
+ " padding_free=False,\n",
938
+ " completion_only_loss=True, # don't waste loss on prompt tokens\n",
939
+ ")\n",
940
+ "sft_trainer = SFTTrainer(\n",
941
+ " model=model,\n",
942
+ " args=sft_cfg,\n",
943
+ " train_dataset=sft_ds,\n",
944
+ " processing_class=tokenizer,\n",
945
+ ")\n",
946
+ "print(f'\\n=== SFT warm-start: {SFT_STEPS} steps on {len(sft_ds)} (prompt, heuristic_action) pairs ===')\n",
947
+ "sft_trainer.train()\n",
948
+ "sft_loss_history = [h.get('loss') for h in sft_trainer.state.log_history if 'loss' in h]\n",
949
+ "print('SFT done. loss curve:', sft_loss_history)\n",
950
+ "\n",
951
  "# ── Reward fn: same-seed group + multi-step rollout ───────────────────\n",
952
  "_REWARD_DEBUG = {'calls': 0}\n",
953
  "\n",
 
961
  " return str(comp)\n",
962
  "\n",
963
  "def _seed_for_prompt(prompt_text):\n",
964
+ " \"\"\"Look up the seed used to generate this prompt's obs (cell 12). When\n",
965
+ " found, env_reset_seeded(seed) reproduces the EXACT obs in the prompt, so\n",
966
+ " the reward is for the action-on-prompt's-obs (the only meaningful signal).\n",
967
+ "\n",
968
+ " Falls back to a hash for unseen prompts (e.g. evaluation), but during\n",
969
+ " GRPO training every prompt should hit the cache.\"\"\"\n",
970
+ " key = _obs_key(prompt_text or '')\n",
971
+ " s = PROMPT_TO_SEED.get(key)\n",
972
+ " if s is not None:\n",
973
+ " return int(s)\n",
974
+ " h = hashlib.md5((prompt_text or '').encode('utf-8')).hexdigest()\n",
975
  " return int(h[:8], 16) & 0x7FFFFFFF\n",
976
  "\n",
977
  "def reward_fn(completions, prompts=None, **kwargs):\n",
978
+ " \"\"\"For each completion: parse action, score it on the PROMPT'S obs by\n",
979
+ " resetting the env to the cached seed for that prompt. All completions in\n",
980
+ " a GRPO group share the same prompt -> same seed -> same starting obs ->\n",
981
+ " only the action varies -> clean group-relative advantage.\n",
982
+ "\n",
983
+ " LEAGUE-AWARE: if the prompt was collected under a *past* rung (rehearsal\n",
984
+ " share), we re-apply that past θ to the env BEFORE the rollout so the\n",
985
+ " obs reproduces exactly. We then restore the global current adversary\n",
986
+ " after the batch (handled by the surrounding loop).\"\"\"\n",
987
  " rewards = []\n",
988
+ " parsed_actions = []\n",
989
+ " n_cache_hit = 0\n",
990
+ " n_past_rehearsal = 0\n",
991
  " prompts = prompts or [None] * len(completions)\n",
992
+ " last_theta_applied = None\n",
993
  " for prompt_text, comp in zip(prompts, completions):\n",
994
  " text = _extract_text(comp)\n",
995
  " action = parse_action(text)\n",
996
+ " parsed_actions.append(action)\n",
997
+ " key = _obs_key(prompt_text or '')\n",
998
  " seed = _seed_for_prompt(prompt_text or text)\n",
999
+ " if key in PROMPT_TO_SEED:\n",
1000
+ " n_cache_hit += 1\n",
1001
+ " # Re-apply the adversary the prompt was made under (only if it differs\n",
1002
+ " # from what we last applied — avoids spamming the env API).\n",
1003
+ " prompt_theta = PROMPT_TO_THETA.get(key)\n",
1004
+ " if prompt_theta is not None and prompt_theta != last_theta_applied:\n",
1005
+ " env_configure_adversary(**prompt_theta, strategy='mixed')\n",
1006
+ " last_theta_applied = prompt_theta\n",
1007
+ " if prompt_theta != _CURRENT_ROUND_THETA.get('theta'):\n",
1008
+ " n_past_rehearsal += 1\n",
1009
  " try:\n",
1010
  " r = rollout_reward(action, seed=seed, difficulty=DIFFICULTY,\n",
1011
  " k=ROLLOUT_STEPS_PER_REWARD)\n",
 
1013
  " print('reward_fn error:', repr(e))\n",
1014
  " r = 0.0\n",
1015
  " rewards.append(float(r))\n",
1016
+ " # Restore current round's adversary after the batch so ES + quick eval\n",
1017
+ " # next called sees the canonical state.\n",
1018
+ " cur = _CURRENT_ROUND_THETA.get('theta')\n",
1019
+ " if cur is not None and cur != last_theta_applied:\n",
1020
+ " env_configure_adversary(**cur, strategy='mixed')\n",
1021
  " _REWARD_DEBUG['calls'] += 1\n",
1022
  " if _REWARD_DEBUG['calls'] <= 3:\n",
1023
+ " n_unique_actions = len({tuple(sorted(a.items())) for a in parsed_actions})\n",
1024
+ " n_unique_rewards = len({round(r, 4) for r in rewards})\n",
1025
+ " print(f\"[reward_fn batch {_REWARD_DEBUG['calls']}] \"\n",
1026
+ " f\"cache_hits={n_cache_hit}/{len(completions)} \"\n",
1027
+ " f\"past_rehearsal_reapplies={n_past_rehearsal} \"\n",
1028
+ " f\"unique_actions={n_unique_actions} \"\n",
1029
+ " f\"unique_rewards={n_unique_rewards} \"\n",
1030
+ " f\"reward_std={float(np.std(rewards)):.4f} \"\n",
1031
+ " f\"sample={rewards[:6]}\")\n",
1032
  " return rewards\n",
1033
  "\n",
1034
+ "# Tracks the round's \"current\" θ so reward_fn can restore it after a\n",
1035
+ "# rehearsal-sample reapply. Populated by the loop below.\n",
1036
+ "_CURRENT_ROUND_THETA = {'theta': None}\n",
1037
+ "\n",
1038
  "# ── Defender policy fn (used inside ES eval) ──────────────────────────\n",
1039
+ "# Token budgets are big enough to (a) NOT truncate the schema instruction at\n",
1040
+ "# the end of the prompt and (b) safely fit a JSON action even if the model\n",
1041
+ "# writes a short prose prefix. With tokenizer.truncation_side='left' set\n",
1042
+ "# above, any overflow drops the legend at the front (lowest-value tokens),\n",
1043
+ "# never the schema instruction at the end.\n",
1044
  "\n",
1045
  "@torch.no_grad()\n",
1046
  "def _defender_action(obs):\n",
 
1057
  " FastLanguageModel.for_training(model)\n",
1058
  " return parse_action(text)\n",
1059
  "\n",
1060
+ "# ── Post-SFT sanity: the warm-started model should now agree with the\n",
1061
+ "# heuristic on most prompts. If it doesn't, GRPO will still help, but\n",
1062
+ "# this is the cheapest signal that SFT actually moved the policy.\n",
1063
+ "_warm_match = 0\n",
1064
+ "_warm_n = min(8, len(prompt_obs))\n",
1065
+ "for _o in prompt_obs[:_warm_n]:\n",
1066
+ " _a_model = _defender_action(_o)\n",
1067
+ " _a_heur = heuristic_policy(_o)\n",
1068
+ " if _a_model == _a_heur:\n",
1069
+ " _warm_match += 1\n",
1070
+ "print(f' SFT sanity: model matches heuristic on {_warm_match}/{_warm_n} sample obs')\n",
1071
+ "\n",
1072
  "# ── GRPO config (per-round) ───────────────────────────────────────────\n",
1073
  "def _make_grpo_cfg(max_steps):\n",
1074
  " return GRPOConfig(\n",
 
1080
  " gradient_accumulation_steps=2,\n",
1081
  " max_steps=int(max_steps),\n",
1082
  " logging_steps=1,\n",
1083
+ " learning_rate=5e-6, # lower than 1e-5 to keep close to SFT prior\n",
1084
  " save_strategy='no',\n",
1085
  " report_to=[],\n",
1086
+ " # bf16 intentionally NOT set — T4 has no bf16 support; Unsloth picks\n",
1087
+ " # the right dtype automatically based on the loaded 4-bit base model.\n",
1088
+ " temperature=1.1, # slight bump so post-SFT logits explore\n",
1089
+ " beta=0.04, # stronger KL: don't drift from SFT'd policy\n",
1090
  " )\n",
1091
  "\n",
1092
  "# ── Co-training loop ──────────────────────────────────────────────────\n",
 
1096
  "fraud_theta_history = [dict(fraud_agent.theta)]\n",
1097
  "loss_history_all = []\n",
1098
  "reward_log_all = []\n",
1099
+ "ladder_round_rung = [] # which ladder rung each round trained against\n",
1100
  "\n",
1101
  "# Quick eval helper — tiny by design (called 3x per round: once after defender\n",
1102
  "# phase, twice for the exploitability gap). Uses the same COEVO_* knobs.\n",
 
1113
  " obs = env_reset_seeded(seed=20_000 + ep, difficulty=DIFFICULTY)\n",
1114
  " return float(np.mean(rs)) if rs else 0.0\n",
1115
  "\n",
1116
+ "def _refresh_prompts_for_round(rnd_idx, current_theta):\n",
1117
+ " \"\"\"FIX B + League rehearsal — re-collect prompts so prompt-obs and\n",
1118
+ " reward-obs match exactly inside this round's GRPO.\n",
1119
+ "\n",
1120
+ " LADDER + LEAGUE TWIST: a fraction `LEAGUE_PAST_SAMPLE_PROB` of prompts\n",
1121
+ " are collected under a *sampled past rung* instead of the current rung.\n",
1122
+ " Crucially, the env's adversary is restored to the CURRENT rung after\n",
1123
+ " refresh — but the prompts collected under the past rung carry an obs\n",
1124
+ " that wouldn't exist under the current adversary. To keep alignment\n",
1125
+ " perfect, we ONLY use the past rung for prompts whose REWARD will also\n",
1126
+ " be computed under that rung. We accomplish this by:\n",
1127
+ " (a) splitting the prompt set into 'current' and 'past' shards,\n",
1128
+ " (b) computing all 'current' prompts first, then ES-time-temporarily\n",
1129
+ " applying the past rung to compute 'past' prompts,\n",
1130
+ " (c) restoring the current rung at the end, and\n",
1131
+ " (d) tagging each prompt's seed with the adversary it was made under,\n",
1132
+ " so reward_fn can re-apply that adversary before scoring.\n",
1133
+ "\n",
1134
+ " For QUICK_MODE (3 rounds) the past pool only fills from round 2 onward,\n",
1135
+ " so round 0 always uses 100% current rung.\n",
1136
+ "\n",
1137
+ " Returns: (Dataset, prompts_list, obs_list).\n",
1138
+ " \"\"\"\n",
1139
+ " base = PROMPT_BASE_SEED + rnd_idx * PROMPT_DATASET_SIZE * 13\n",
1140
+ "\n",
1141
+ " # Decide how many prompts come from a past rung (rehearsal share).\n",
1142
+ " n_past = 0\n",
1143
+ " past_theta = None\n",
1144
+ " if len(league) >= 1:\n",
1145
+ " past_theta = league.sample_past()\n",
1146
+ " if past_theta is not None:\n",
1147
+ " n_past = int(round(PROMPT_DATASET_SIZE * LEAGUE_PAST_SAMPLE_PROB))\n",
1148
+ " n_current = PROMPT_DATASET_SIZE - n_past\n",
1149
+ "\n",
1150
+ " # Phase 1 — current rung prompts\n",
1151
+ " env_configure_adversary(**current_theta, strategy='mixed')\n",
1152
+ " cur_prompts, cur_obs, cur_seeds = collect_prompts(n=n_current, base_seed=base)\n",
1153
+ " cur_theta_per_seed = {s: dict(current_theta) for s in cur_seeds}\n",
1154
+ "\n",
1155
+ " # Phase 2 — past rung rehearsal prompts (if any)\n",
1156
+ " past_prompts, past_obs, past_seeds = [], [], []\n",
1157
+ " past_theta_per_seed = {}\n",
1158
+ " if n_past > 0 and past_theta is not None:\n",
1159
+ " env_configure_adversary(**past_theta, strategy='mixed')\n",
1160
+ " past_prompts, past_obs, past_seeds = collect_prompts(\n",
1161
+ " n=n_past, base_seed=base + 7919 # disjoint sub-range\n",
1162
+ " )\n",
1163
+ " past_theta_per_seed = {s: dict(past_theta) for s in past_seeds}\n",
1164
+ "\n",
1165
+ " # Restore current rung as the env's \"default\" — reward_fn will re-apply\n",
1166
+ " # the per-seed θ before each rollout (see PROMPT_TO_THETA below).\n",
1167
+ " env_configure_adversary(**current_theta, strategy='mixed')\n",
1168
+ "\n",
1169
+ " # Combine\n",
1170
+ " new_prompts = cur_prompts + past_prompts\n",
1171
+ " new_obs = cur_obs + past_obs\n",
1172
+ " new_seeds = cur_seeds + past_seeds\n",
1173
+ " new_theta_per_seed = {**cur_theta_per_seed, **past_theta_per_seed}\n",
1174
+ "\n",
1175
+ " PROMPT_TO_SEED.clear()\n",
1176
+ " PROMPT_TO_SEED.update({_obs_key(p): s for p, s in zip(new_prompts, new_seeds)})\n",
1177
+ " PROMPT_TO_OBS.clear()\n",
1178
+ " PROMPT_TO_OBS.update({_obs_key(p): o for p, o in zip(new_prompts, new_obs)})\n",
1179
+ " PROMPT_TO_THETA.clear()\n",
1180
+ " PROMPT_TO_THETA.update({_obs_key(p): new_theta_per_seed[s]\n",
1181
+ " for p, s in zip(new_prompts, new_seeds)})\n",
1182
+ "\n",
1183
+ " print(f' [FIX B + league] refreshed {len(new_prompts)} prompts: '\n",
1184
+ " f'{n_current} current rung + {n_past} past rung (rehearsal)')\n",
1185
+ " return Dataset.from_list([{'prompt': p} for p in new_prompts]), new_prompts, new_obs\n",
1186
+ "\n",
1187
+ "# ── Per-prompt theta lookup so reward_fn can re-apply the adversary the\n",
1188
+ "# prompt was made under (essential for league rehearsal to stay aligned).\n",
1189
+ "PROMPT_TO_THETA = {}\n",
1190
+ "\n",
1191
+ "def _rung_for_round(rnd_idx):\n",
1192
+ " \"\"\"Distribute ladder rungs evenly across rounds. With N_ROUNDS=3 + 3 rungs\n",
1193
+ " we get rounds [0,1,2] -> rungs [0,1,2]. With N_ROUNDS=6 + 3 rungs we get\n",
1194
+ " rounds [0,1,2,3,4,5] -> rungs [0,0,1,1,2,2].\"\"\"\n",
1195
+ " return min(rnd_idx * len(LADDER_RUNGS) // max(N_ROUNDS, 1), len(LADDER_RUNGS) - 1)\n",
1196
+ "\n",
1197
+ "# ── OPTIONAL: dual-LoRA fraud LLM policy ─────────────────────────────\n",
1198
+ "# When USE_LLM_FRAUD=True, this replaces FraudPolicy.es_step inside the\n",
1199
+ "# co-training loop. It is a SECOND LoRA on the same Phi-3 base, trained\n",
1200
+ "# with TRL GRPO to OUTPUT adversary-parameter JSON. Reward = 1 - defender_reward\n",
1201
+ "# under the proposed θ, so the GRPO group-relative advantage rewards the\n",
1202
+ "# fraud LLM for proposing thetas the current defender is weakest against.\n",
1203
+ "#\n",
1204
+ "# Why this is the right structural upgrade (vs. e.g. fraud LLM emitting\n",
1205
+ "# raw transaction JSON): it reuses the existing /configure_adversary +\n",
1206
+ "# quick_defender_eval pipeline, so we don't need any new env endpoints —\n",
1207
+ "# the fraud LLM's \"action\" is exactly the same dict that ES manipulates.\n",
1208
+ "\n",
1209
+ "_FRAUD_KEYS = ('intensity', 'noise_boost', 'pattern_rate')\n",
1210
+ "\n",
1211
+ "def _fraud_summary_text(rung_idx, current_theta, last_def_score):\n",
1212
+ " \"\"\"Compact, model-friendly summary of the current matchup that the fraud\n",
1213
+ " LLM conditions on. Kept short so the prompt stays under FRAUD_MAX_PROMPT_TOKENS.\"\"\"\n",
1214
+ " return (\n",
1215
+ " f'rung={rung_idx}'\n",
1216
+ " f' | current_theta={ {k: round(current_theta[k], 3) for k in _FRAUD_KEYS} }'\n",
1217
+ " f' | last_defender_reward={last_def_score:.3f}'\n",
1218
+ " f' | bounds={ {k: list(FRAUD_PARAM_BOUNDS[k]) for k in _FRAUD_KEYS} }'\n",
1219
+ " )\n",
1220
+ "\n",
1221
+ "def make_fraud_prompt(summary):\n",
1222
+ " return (\n",
1223
+ " 'You design adversary parameters for a payments env. The defender LLM\\n'\n",
1224
+ " 'is an LLM-based fraud detector. Your goal: pick (intensity, noise_boost,\\n'\n",
1225
+ " 'pattern_rate) so the defender\\'s reward is MINIMISED while staying inside\\n'\n",
1226
+ " 'the bounds. Higher intensity = harder fraud, higher noise_boost = stealthier\\n'\n",
1227
+ " 'risk score, higher pattern_rate = more bursty attacks.\\n'\n",
1228
+ " f'Match summary: {summary}\\n'\n",
1229
+ " 'Return ONE JSON: {\"intensity\": <float>, \"noise_boost\": <float>, \"pattern_rate\": <float>}.'\n",
1230
+ " )\n",
1231
+ "\n",
1232
+ "_FRAUD_JSON_RE = re.compile(r'\\{[^{}]*\\}')\n",
1233
+ "\n",
1234
+ "def parse_fraud_theta(text, default_theta):\n",
1235
+ " \"\"\"Extract {intensity, noise_boost, pattern_rate} JSON, fall back to the\n",
1236
+ " given default + clip to bounds. Same defensive pattern as parse_action.\"\"\"\n",
1237
+ " m = _FRAUD_JSON_RE.search(text or '')\n",
1238
+ " if not m:\n",
1239
+ " return _clip_theta(dict(default_theta))\n",
1240
+ " try:\n",
1241
+ " raw = json.loads(m.group(0))\n",
1242
+ " out = dict(default_theta)\n",
1243
+ " for k in _FRAUD_KEYS:\n",
1244
+ " if k in raw:\n",
1245
+ " out[k] = float(raw[k])\n",
1246
+ " return _clip_theta(out)\n",
1247
+ " except Exception:\n",
1248
+ " return _clip_theta(dict(default_theta))\n",
1249
+ "\n",
1250
+ "class FraudLLMPolicy:\n",
1251
+ " \"\"\"Dual-LoRA fraud agent: an LLM that proposes adversary θ via GRPO.\n",
1252
+ " Replaces FraudPolicy.es_step when USE_LLM_FRAUD=True.\"\"\"\n",
1253
+ " def __init__(self, fmodel, ftokenizer, defender_fn, current_theta_fn):\n",
1254
+ " self.model = fmodel\n",
1255
+ " self.tokenizer = ftokenizer\n",
1256
+ " self.defender_fn = defender_fn\n",
1257
+ " self.current_theta_fn = current_theta_fn # ()->dict, latest θ\n",
1258
+ " self.last_def_score = 0.5\n",
1259
+ " self.loss_history = []\n",
1260
+ " self.reward_history = []\n",
1261
+ " self.theta_history = []\n",
1262
+ "\n",
1263
+ " @torch.no_grad()\n",
1264
+ " def _generate_one(self, summary):\n",
1265
+ " FastLanguageModel.for_inference(self.model)\n",
1266
+ " device = next(self.model.parameters()).device\n",
1267
+ " prompt = make_fraud_prompt(summary)\n",
1268
+ " inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True,\n",
1269
+ " max_length=FRAUD_MAX_PROMPT_TOKENS).to(device)\n",
1270
+ " out = self.model.generate(\n",
1271
+ " **inputs, max_new_tokens=FRAUD_MAX_NEW_TOKENS, do_sample=False,\n",
1272
+ " pad_token_id=self.tokenizer.pad_token_id,\n",
1273
+ " )\n",
1274
+ " text = self.tokenizer.decode(out[0][inputs['input_ids'].shape[1]:],\n",
1275
+ " skip_special_tokens=True)\n",
1276
+ " FastLanguageModel.for_training(self.model)\n",
1277
+ " return parse_fraud_theta(text, self.current_theta_fn())\n",
1278
+ "\n",
1279
+ " def grpo_step(self, rung_idx):\n",
1280
+ " \"\"\"One GRPO burst: build a tiny prompt set conditioned on the current\n",
1281
+ " match summary, train fraud LoRA to output θ with reward = 1 - defender_reward.\"\"\"\n",
1282
+ " cur_theta = self.current_theta_fn()\n",
1283
+ " # All prompts in the burst share the same summary (it doesn't change\n",
1284
+ " # within a single ES-replacement step). num_generations supplies the\n",
1285
+ " # group-relative variance via sampling, exactly like defender GRPO.\n",
1286
+ " summary = _fraud_summary_text(rung_idx, cur_theta, self.last_def_score)\n",
1287
+ " prompt = make_fraud_prompt(summary)\n",
1288
+ " ds_fraud = Dataset.from_list(\n",
1289
+ " [{'prompt': prompt} for _ in range(FRAUD_PROMPT_DATASET_SIZE)]\n",
1290
+ " )\n",
1291
+ "\n",
1292
+ " def fraud_reward_fn(completions, prompts=None, **_):\n",
1293
+ " rewards = []\n",
1294
+ " for comp in completions:\n",
1295
+ " text = (comp if isinstance(comp, str)\n",
1296
+ " else (comp[0].get('content','') if isinstance(comp, list)\n",
1297
+ " else comp.get('content','')))\n",
1298
+ " proposed = parse_fraud_theta(text, cur_theta)\n",
1299
+ " # Push proposal to env, measure defender reward under it.\n",
1300
+ " env_configure_adversary(**proposed, strategy='mixed')\n",
1301
+ " def_score = quick_defender_eval()\n",
1302
+ " rewards.append(float(1.0 - def_score)) # fraud wants low def_reward\n",
1303
+ " # Restore current θ so the OUTER loop's next call sees canonical state.\n",
1304
+ " env_configure_adversary(**cur_theta, strategy='mixed')\n",
1305
+ " return rewards\n",
1306
+ "\n",
1307
+ " cfg = GRPOConfig(\n",
1308
+ " output_dir='outputs/theme4_fraud_grpo',\n",
1309
+ " num_generations=FRAUD_GRPO_NUM_GENERATIONS,\n",
1310
+ " max_prompt_length=FRAUD_MAX_PROMPT_TOKENS,\n",
1311
+ " max_completion_length=FRAUD_MAX_NEW_TOKENS,\n",
1312
+ " per_device_train_batch_size=1,\n",
1313
+ " gradient_accumulation_steps=2,\n",
1314
+ " max_steps=int(FRAUD_GRPO_STEPS_PER_ROUND),\n",
1315
+ " logging_steps=1,\n",
1316
+ " learning_rate=5e-6,\n",
1317
+ " save_strategy='no',\n",
1318
+ " report_to=[],\n",
1319
+ " temperature=1.1,\n",
1320
+ " beta=0.04,\n",
1321
+ " )\n",
1322
+ " trainer = GRPOTrainer(\n",
1323
+ " model=self.model, args=cfg, train_dataset=ds_fraud,\n",
1324
+ " processing_class=self.tokenizer, reward_funcs=[fraud_reward_fn],\n",
1325
+ " )\n",
1326
+ " trainer.train()\n",
1327
+ " self.loss_history.extend(\n",
1328
+ " [h.get('loss') for h in trainer.state.log_history if 'loss' in h]\n",
1329
+ " )\n",
1330
+ " self.reward_history.extend(\n",
1331
+ " [h.get('reward') for h in trainer.state.log_history if 'reward' in h]\n",
1332
+ " )\n",
1333
+ "\n",
1334
+ " # Greedy generation = the LoRA's \"best guess\" θ after this burst.\n",
1335
+ " new_theta = self._generate_one(summary)\n",
1336
+ " self.theta_history.append(dict(new_theta))\n",
1337
+ " env_configure_adversary(**new_theta, strategy='mixed')\n",
1338
+ " # Refresh last-defender-score under the chosen θ (used in the NEXT\n",
1339
+ " # round's summary) so the fraud LLM gets a calibrated signal.\n",
1340
+ " self.last_def_score = float(quick_defender_eval())\n",
1341
+ " return {'theta': new_theta, 'def_reward_under_new_theta': self.last_def_score}\n",
1342
+ "\n",
1343
+ "# Instantiate fraud LLM policy ONCE if enabled. Defender_fn is set later\n",
1344
+ "# (closures capture the latest defender LoRA each call automatically).\n",
1345
+ "fraud_llm = None\n",
1346
+ "if USE_LLM_FRAUD and fraud_model is not None:\n",
1347
+ " fraud_llm = FraudLLMPolicy(\n",
1348
+ " fmodel=fraud_model,\n",
1349
+ " ftokenizer=fraud_tokenizer,\n",
1350
+ " defender_fn=_defender_action,\n",
1351
+ " current_theta_fn=lambda: dict(fraud_agent.theta),\n",
1352
+ " )\n",
1353
+ " print(f'[USE_LLM_FRAUD] FraudLLMPolicy ready '\n",
1354
+ " f'(GRPO steps/round={FRAUD_GRPO_STEPS_PER_ROUND}, '\n",
1355
+ " f'num_generations={FRAUD_GRPO_NUM_GENERATIONS})')\n",
1356
  "\n",
1357
  "for rnd in range(N_ROUNDS):\n",
1358
+ " rung_idx = _rung_for_round(rnd)\n",
1359
+ " rung_anchor = LADDER_RUNGS[rung_idx]\n",
1360
+ " ladder_round_rung.append(rung_idx)\n",
1361
+ " print(f'\\n=== Round {rnd+1}/{N_ROUNDS} | LADDER RUNG {rung_idx} ({rung_anchor}) ===')\n",
1362
+ "\n",
1363
+ " # Anchor the fraud agent at this rung's defaults at the START of the round\n",
1364
+ " # (only on rung CHANGE — within a rung, ES keeps drifting locally).\n",
1365
+ " if rnd == 0 or rung_idx != _rung_for_round(rnd - 1):\n",
1366
+ " fraud_agent.theta = dict(rung_anchor)\n",
1367
+ " fraud_agent.history.append(dict(fraud_agent.theta))\n",
1368
+ " fraud_theta_history.append(dict(fraud_agent.theta))\n",
1369
+ " print(f' ladder anchor applied: θ <- {fraud_agent.theta}')\n",
1370
+ " fraud_agent.apply()\n",
1371
+ " print(f' current fraud θ: {fraud_agent.theta}')\n",
1372
+ "\n",
1373
+ " # Track current-round θ so reward_fn knows what to restore between\n",
1374
+ " # rehearsal-sample reapplies.\n",
1375
+ " _CURRENT_ROUND_THETA['theta'] = dict(fraud_agent.theta)\n",
1376
+ "\n",
1377
+ " # FIX B + LEAGUE rehearsal — refresh prompts under the CURRENT adversary\n",
1378
+ " # (and a `LEAGUE_PAST_SAMPLE_PROB` share under a sampled past rung, with\n",
1379
+ " # per-prompt θ recorded so reward_fn can re-apply it correctly).\n",
1380
+ " ds_round, prompts_round, prompt_obs_round = _refresh_prompts_for_round(\n",
1381
+ " rnd, current_theta=fraud_agent.theta\n",
1382
+ " )\n",
1383
  "\n",
1384
+ " # Phase A: defender GRPO on this round's freshly-aligned prompts.\n",
1385
  " cfg = _make_grpo_cfg(max_steps=GRPO_STEPS_PER_ROUND)\n",
1386
  " trainer = GRPOTrainer(\n",
1387
+ " model=model, args=cfg, train_dataset=ds_round,\n",
1388
  " processing_class=tokenizer, reward_funcs=[reward_fn],\n",
1389
  " )\n",
1390
  " trainer.train()\n",
 
1393
  " loss_history_all.extend(rnd_loss)\n",
1394
  " reward_log_all.extend(rnd_rew)\n",
1395
  "\n",
1396
+ " # Make sure env is back at current rung after GRPO before quick_eval.\n",
1397
+ " fraud_agent.apply()\n",
1398
  " def_score = quick_defender_eval()\n",
1399
  " defender_round_rewards.append(def_score)\n",
1400
  " print(f' defender mean reward (round {rnd+1}): {def_score:.4f}')\n",
1401
  "\n",
1402
+ " # Snapshot settled fraud at this rung into the league (used by next\n",
1403
+ " # round's prompt rehearsal share).\n",
1404
+ " league.add(name=f'round{rnd+1}_rung{rung_idx}', theta=fraud_agent.theta)\n",
1405
+ " print(f' league snapshot taken: now {len(league)} rung(s) in pool')\n",
1406
+ "\n",
1407
+ " # Phase B: fraud update vs current defender.\n",
1408
+ " # USE_LLM_FRAUD=False (default) -> parametric ES on FraudPolicy\n",
1409
+ " # USE_LLM_FRAUD=True -> GRPO on the fraud LoRA (FraudLLMPolicy)\n",
1410
+ " # In both cases the resulting θ is pushed to the env via /configure_adversary\n",
1411
+ " # and `fraud_agent.theta` is kept in sync so downstream code (snapshots,\n",
1412
+ " # exploitability gap, eval) remains identical.\n",
1413
+ " if rnd < N_ROUNDS - 1:\n",
1414
  " round_fraud_fits = []\n",
1415
+ " if USE_LLM_FRAUD and fraud_llm is not None:\n",
1416
+ " # Fraud LLM does ONE GRPO burst per round (FRAUD_GRPO_STEPS_PER_ROUND\n",
1417
+ " # steps inside it). Mirror θ back into fraud_agent so later code\n",
1418
+ " # (which still queries fraud_agent.theta) sees the new value.\n",
1419
+ " print(f' [USE_LLM_FRAUD] fraud LoRA GRPO step...')\n",
1420
+ " info = fraud_llm.grpo_step(rung_idx=rung_idx)\n",
1421
+ " new_theta = info['theta']\n",
1422
+ " fraud_agent.theta = dict(new_theta)\n",
1423
+ " fraud_agent.history.append(dict(fraud_agent.theta))\n",
1424
+ " round_fraud_fits.append(1.0 - info['def_reward_under_new_theta'])\n",
1425
+ " print(f' proposed θ={new_theta} | def_reward={info[\"def_reward_under_new_theta\"]:.3f}')\n",
1426
+ " else:\n",
1427
+ " for es in range(ES_STEPS_PER_ROUND):\n",
1428
+ " info = fraud_agent.es_step(_defender_action)\n",
1429
+ " round_fraud_fits.append(info['mean_fraud_fitness'])\n",
1430
+ " print(f' ES step {es+1}/{ES_STEPS_PER_ROUND}: mean_fitness={info[\"mean_fraud_fitness\"]:.3f}'\n",
1431
+ " f' best={info[\"best_fraud_fitness\"]:.3f} theta={info[\"theta\"]}')\n",
1432
  " fraud_round_fitness.append(float(np.mean(round_fraud_fits)) if round_fraud_fits else 0.0)\n",
1433
  " fraud_theta_history.append(dict(fraud_agent.theta))\n",
1434
  "\n",
1435
  " # Exploitability gap: how much WORSE the defender does against trained\n",
1436
+ " # fraud vs. against neutral fraud.\n",
1437
  " env_configure_adversary(intensity=1.0, noise_boost=0.05, pattern_rate=0.2, strategy='mixed')\n",
1438
  " baseline_def = quick_defender_eval()\n",
1439
+ " fraud_agent.apply()\n",
1440
  " adv_def = quick_defender_eval()\n",
1441
  " gap = float(baseline_def - adv_def)\n",
1442
  " exploitability_log.append(gap)\n",
1443
  " print(f' exploitability gap: baseline_def={baseline_def:.3f} vs adv_def={adv_def:.3f} -> gap={gap:.3f}')\n",
1444
  "\n",
1445
+ "# ── Final league robustness telemetry ────────────────────────────────\n",
1446
+ "# Measure the trained defender against EVERY rung that was snapshotted.\n",
1447
+ "# A robust policy (good ladder-curriculum) scores well across rungs;\n",
1448
+ "# an over-fit one only scores well on the last. This is plotted in cell 22.\n",
1449
+ "print('\\n[league] measuring trained defender vs each league rung...')\n",
1450
+ "league_eval_rewards = []\n",
1451
+ "for rung in league.rungs:\n",
1452
+ " env_configure_adversary(**rung['theta'], strategy='mixed')\n",
1453
+ " score = quick_defender_eval()\n",
1454
+ " league_eval_rewards.append({'name': rung['name'], 'theta': rung['theta'],\n",
1455
+ " 'defender_reward': float(score)})\n",
1456
+ " print(f\" {rung['name']}: defender_reward={score:.3f} θ={rung['theta']}\")\n",
1457
+ "\n",
1458
+ "# Restore co-evolved fraud at the end so cell 20's trained_eval starts there.\n",
1459
+ "fraud_agent.apply()\n",
1460
+ "\n",
1461
  "print('\\nCo-training finished.')\n",
1462
+ "print(' ladder rung schedule :', ladder_round_rung)\n",
1463
+ "print(' league pool size :', len(league),\n",
1464
+ " '|', [r['name'] for r in league.rungs])\n",
1465
  "print(' defender_round_rewards:', defender_round_rewards)\n",
1466
+ "print(' fraud_round_fitness :', fraud_round_fitness)\n",
1467
+ "print(' exploitability_log :', exploitability_log)\n",
1468
  "\n",
1469
  "# Aliases for downstream cells\n",
1470
  "loss_history = loss_history_all\n",
 
1534
  "source": [
1535
  "import matplotlib.pyplot as plt\n",
1536
  "\n",
1537
+ "# 0. SFT warm-start loss\n",
1538
+ "if sft_loss_history:\n",
1539
+ " plt.figure(figsize=(8,4))\n",
1540
+ " plt.plot(sft_loss_history, marker='o', color='#a48', label='SFT loss')\n",
1541
+ " plt.xlabel('Logging step')\n",
1542
+ " plt.ylabel('Loss')\n",
1543
+ " plt.title('Stage 1 — SFT warm-start (heuristic imitation)')\n",
1544
+ " plt.legend()\n",
1545
+ " plt.tight_layout()\n",
1546
+ " plt.savefig('artifacts/sft_loss_curve.png', dpi=140)\n",
1547
+ " plt.show()\n",
1548
+ "\n",
1549
  "# 1. GRPO training reward (across all rounds)\n",
1550
  "if reward_log:\n",
1551
  " plt.figure(figsize=(8,4))\n",
1552
  " plt.plot(reward_log, label='GRPO mean reward per logging step')\n",
1553
  " plt.xlabel('Logging step (across all defender rounds)')\n",
1554
  " plt.ylabel('Reward')\n",
1555
+ " plt.title('Stage 2 — GRPO defender training reward')\n",
1556
  " plt.legend()\n",
1557
  " plt.tight_layout()\n",
1558
  " plt.savefig('artifacts/grpo_reward_curve.png', dpi=140)\n",
 
1570
  " plt.savefig('artifacts/grpo_training_loss.png', dpi=140)\n",
1571
  " plt.show()\n",
1572
  "\n",
1573
+ "# 2b. (Optional) Fraud-LLM GRPO loss + reward — only when USE_LLM_FRAUD=True\n",
1574
+ "if USE_LLM_FRAUD and fraud_llm is not None and fraud_llm.loss_history:\n",
1575
+ " fig, ax1 = plt.subplots(figsize=(8,4))\n",
1576
+ " ax1.plot(fraud_llm.loss_history, color='#c44', label='Fraud-LoRA GRPO loss')\n",
1577
+ " ax1.set_xlabel('Logging step (across all fraud rounds)')\n",
1578
+ " ax1.set_ylabel('Loss', color='#c44')\n",
1579
+ " if fraud_llm.reward_history:\n",
1580
+ " ax2 = ax1.twinx()\n",
1581
+ " ax2.plot(fraud_llm.reward_history, color='#48a',\n",
1582
+ " label='Fraud-LoRA GRPO reward (1 - def_reward)')\n",
1583
+ " ax2.set_ylabel('Reward', color='#48a')\n",
1584
+ " plt.title('Stage 2 — Fraud LoRA GRPO (dual-LLM mode)')\n",
1585
+ " fig.tight_layout()\n",
1586
+ " plt.savefig('artifacts/fraud_llm_grpo_curves.png', dpi=140)\n",
1587
+ " plt.show()\n",
1588
+ "\n",
1589
  "# 3. Co-evolution: defender reward vs fraud fitness per round\n",
1590
  "rounds_x = np.arange(1, len(defender_round_rewards) + 1)\n",
1591
  "fig, ax1 = plt.subplots(figsize=(8,4))\n",
 
1628
  " plt.savefig('artifacts/fraud_theta_trajectory.png', dpi=140)\n",
1629
  " plt.show()\n",
1630
  "\n",
1631
+ "# 6. Before vs After ── FIX D ──\n",
1632
+ "# Now shows FOUR bars so the comparison is fair AND informative:\n",
1633
+ "# * Random / Heuristic — baselines, eval'd vs neutral fraud (Fix A)\n",
1634
+ "# * Trained LLM (vs Neutral) — apples-to-apples with baselines (PRIMARY)\n",
1635
+ "# * Trained LLM (vs Co-Evo) — robustness against the hardest fraud seen\n",
1636
+ "labels = ['Random\\n(neutral)', 'Heuristic\\n(neutral)',\n",
1637
+ " 'Trained LLM\\n(neutral)', 'Trained LLM\\n(co-evolved)']\n",
1638
+ "values = [\n",
1639
+ " baseline_random['mean_reward'],\n",
1640
+ " baseline_heuristic['mean_reward'],\n",
1641
+ " trained_eval_neutral['mean_reward'],\n",
1642
+ " trained_eval['mean_reward'],\n",
1643
+ "]\n",
1644
+ "colors = ['#bbb','#88c','#4a8','#268']\n",
1645
+ "plt.figure(figsize=(8.5, 4.5))\n",
1646
+ "bars = plt.bar(labels, values, color=colors)\n",
1647
  "for b, v in zip(bars, values):\n",
1648
  " plt.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.3f}', ha='center')\n",
1649
  "plt.ylabel('Mean reward (frozen holdout)')\n",
1650
+ "plt.title('Before vs After Training (SFT + GRPO ladder co-evolution)')\n",
1651
  "plt.tight_layout()\n",
1652
  "plt.savefig('artifacts/before_after_rewards.png', dpi=140)\n",
1653
  "plt.show()\n",
1654
  "\n",
1655
+ "# 7a. Trained defender vs each LEAGUE rung (ladder robustness)\n",
1656
+ "# A \"good\" ladder run shows the trained defender scoring at-or-above the\n",
1657
+ "# heuristic baseline across ALL rungs (not just the latest). A spike on the\n",
1658
+ "# last rung only would be evidence of catastrophic forgetting.\n",
1659
+ "if league_eval_rewards:\n",
1660
+ " rung_names = [r['name'] for r in league_eval_rewards]\n",
1661
+ " rung_rewards = [r['defender_reward'] for r in league_eval_rewards]\n",
1662
+ " plt.figure(figsize=(8.5, 4))\n",
1663
+ " bars = plt.bar(rung_names, rung_rewards, color='#4a8')\n",
1664
+ " for b, v in zip(bars, rung_rewards):\n",
1665
+ " plt.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}',\n",
1666
+ " ha='center', fontsize=9)\n",
1667
+ " plt.axhline(baseline_heuristic['mean_reward'], color='#88c',\n",
1668
+ " linestyle='--', label=f\"Heuristic (neutral): {baseline_heuristic['mean_reward']:.3f}\")\n",
1669
+ " plt.axhline(baseline_random['mean_reward'], color='#aaa',\n",
1670
+ " linestyle=':', label=f\"Random (neutral): {baseline_random['mean_reward']:.3f}\")\n",
1671
+ " plt.xticks(rotation=20, ha='right', fontsize=8)\n",
1672
+ " plt.ylabel('Trained defender mean reward')\n",
1673
+ " plt.title('Ladder robustness: Trained LLM vs each league rung')\n",
1674
+ " plt.legend(fontsize=8)\n",
1675
+ " plt.tight_layout()\n",
1676
+ " plt.savefig('artifacts/league_robustness.png', dpi=140)\n",
1677
+ " plt.show()\n",
1678
+ "\n",
1679
+ "# 7. Per risk-bucket ── FIX D ──\n",
1680
+ "# Same 4-way comparison broken out by Low / Medium / High risk so you can\n",
1681
+ "# see if the trained model lifts performance in the hard buckets where\n",
1682
+ "# heuristic + random give up.\n",
1683
  "buckets = ['low', 'medium', 'high']\n",
1684
+ "rand_b = [baseline_random['bucket_means'][b] for b in buckets]\n",
1685
+ "heur_b = [baseline_heuristic['bucket_means'][b] for b in buckets]\n",
1686
+ "trN_b = [trained_eval_neutral['bucket_means'][b] for b in buckets]\n",
1687
+ "trC_b = [trained_eval['bucket_means'][b] for b in buckets]\n",
1688
  "x = np.arange(len(buckets))\n",
1689
+ "w = 0.20\n",
1690
+ "plt.figure(figsize=(9.5, 4.5))\n",
1691
+ "plt.bar(x - 1.5*w, rand_b, width=w, label='Random (neutral)', color='#bbb')\n",
1692
+ "plt.bar(x - 0.5*w, heur_b, width=w, label='Heuristic (neutral)', color='#88c')\n",
1693
+ "plt.bar(x + 0.5*w, trN_b, width=w, label='Trained LLM (neutral)', color='#4a8')\n",
1694
+ "plt.bar(x + 1.5*w, trC_b, width=w, label='Trained LLM (co-evolved)', color='#268')\n",
1695
  "plt.xticks(x, [b.title()+' Risk' for b in buckets])\n",
1696
  "plt.ylabel('Mean reward')\n",
1697
  "plt.title('Per Risk-Bucket Reward (frozen holdout)')\n",
1698
+ "plt.legend(loc='best', fontsize=8)\n",
1699
  "plt.tight_layout()\n",
1700
  "plt.savefig('artifacts/per_bucket_rewards.png', dpi=140)\n",
1701
  "plt.show()\n",
 
1705
  " 'model_id': MODEL_ID,\n",
1706
  " 'quick_mode': QUICK_MODE,\n",
1707
  " 'prompts_used': len(prompts),\n",
1708
+ " 'training_recipe': 'SFT(heuristic-imitation) -> ladder GRPO(rung-curriculum) ⇄ ES fraud (PFSP league)',\n",
1709
+ " 'sft_steps': SFT_STEPS,\n",
1710
+ " 'sft_lr': SFT_LR,\n",
1711
+ " 'sft_loss_history': sft_loss_history,\n",
1712
  " 'grpo_num_generations': GRPO_NUM_GENERATIONS,\n",
1713
  " 'rollout_steps_per_reward': ROLLOUT_STEPS_PER_REWARD,\n",
1714
  " 'n_rounds': N_ROUNDS,\n",
1715
  " 'grpo_steps_per_round': GRPO_STEPS_PER_ROUND,\n",
1716
  " 'es_steps_per_round': ES_STEPS_PER_ROUND,\n",
1717
  " 'es_population': ES_POPULATION,\n",
1718
+ " 'ladder_rungs': LADDER_RUNGS,\n",
1719
+ " 'ladder_round_rung': ladder_round_rung,\n",
1720
+ " 'league_pool': [r['name'] for r in league.rungs],\n",
1721
+ " 'league_past_sample_prob': LEAGUE_PAST_SAMPLE_PROB,\n",
1722
+ " 'league_eval_rewards': league_eval_rewards,\n",
1723
+ " 'use_llm_fraud': USE_LLM_FRAUD,\n",
1724
+ " 'fraud_llm_grpo_loss_history': (fraud_llm.loss_history if (USE_LLM_FRAUD and fraud_llm is not None) else []),\n",
1725
+ " 'fraud_llm_grpo_reward_history': (fraud_llm.reward_history if (USE_LLM_FRAUD and fraud_llm is not None) else []),\n",
1726
+ " 'fraud_llm_theta_history': (fraud_llm.theta_history if (USE_LLM_FRAUD and fraud_llm is not None) else []),\n",
1727
  " 'baseline_random_mean_reward': baseline_random['mean_reward'],\n",
1728
  " 'baseline_heuristic_mean_reward': baseline_heuristic['mean_reward'],\n",
1729
+ " 'trained_mean_reward_neutral_fraud': trained_eval_neutral['mean_reward'],\n",
1730
+ " 'trained_mean_reward_coevolved_fraud': trained_eval['mean_reward'],\n",
1731
+ " 'reward_gain_vs_random': trained_eval_neutral['mean_reward'] - baseline_random['mean_reward'],\n",
1732
+ " 'reward_gain_vs_heuristic': trained_eval_neutral['mean_reward'] - baseline_heuristic['mean_reward'],\n",
1733
  " 'per_bucket': {\n",
1734
+ " 'random': baseline_random['bucket_means'],\n",
1735
+ " 'heuristic': baseline_heuristic['bucket_means'],\n",
1736
+ " 'trained_neutral': trained_eval_neutral['bucket_means'],\n",
1737
+ " 'trained_coevolved': trained_eval['bucket_means'],\n",
1738
  " },\n",
1739
  " 'defender_round_rewards': defender_round_rewards,\n",
1740
  " 'fraud_round_fitness': fraud_round_fitness,\n",
 
1744
  " 'grpo_reward_curve': reward_log,\n",
1745
  " 'grpo_loss_history': loss_history,\n",
1746
  " 'eval_per_episode': {\n",
1747
+ " 'random': baseline_random['per_episode_mean'],\n",
1748
+ " 'heuristic': baseline_heuristic['per_episode_mean'],\n",
1749
+ " 'trained_neutral': trained_eval_neutral['per_episode_mean'],\n",
1750
+ " 'trained_coevolved': trained_eval['per_episode_mean'],\n",
1751
  " },\n",
1752
  "}\n",
1753
  "with open('artifacts/run_summary.json', 'w', encoding='utf-8') as f:\n",
server/SmartPayEnv_environment.py CHANGED
@@ -504,8 +504,14 @@ class SmartpayenvEnvironment(Environment):
504
  base_reward = (0.4 * route_score) + (0.4 * fs) + (0.2 * rs)
505
 
506
  # League-style regret: penalize underperforming against moving challenger.
 
 
 
 
 
 
507
  challenger_regret = max(0.0, self._state.challenger_skill - base_reward)
508
- regret_penalty = 0.35 * challenger_regret
509
 
510
  # Anti-gaming check: repeatedly overusing manual review without quality gains.
511
  gaming_penalty = 0.0
@@ -513,8 +519,13 @@ class SmartpayenvEnvironment(Environment):
513
  self._state.anti_gaming_alerts += 1
514
  gaming_penalty = min(0.12, 0.02 * self._state.anti_gaming_alerts)
515
 
516
- # Curriculum bonus: reward robust performance under higher difficulty pressure.
517
- robustness_bonus = 0.06 * self._state.curriculum_level * max(0.0, base_reward - 0.55)
 
 
 
 
 
518
 
519
  # Norm punishment for delayed liabilities + self-improvement terms.
520
  final_reward = base_reward - (cb_amt / 150.0) - regret_penalty - gaming_penalty + robustness_bonus
 
504
  base_reward = (0.4 * route_score) + (0.4 * fs) + (0.2 * rs)
505
 
506
  # League-style regret: penalize underperforming against moving challenger.
507
+ # NOTE: coefficient was 0.35 — too crushing as a learning signal. A fresh
508
+ # GRPO policy with base_reward=0.3 would lose ~0.12 here, while a strong
509
+ # policy with base_reward=0.7 lost almost nothing. That's the wrong slope:
510
+ # it punished bad policies more than good ones, suppressing the gradient
511
+ # at the very start of training. 0.15 keeps the league-style pressure but
512
+ # leaves enough reward range for early learning.
513
  challenger_regret = max(0.0, self._state.challenger_skill - base_reward)
514
+ regret_penalty = 0.15 * challenger_regret
515
 
516
  # Anti-gaming check: repeatedly overusing manual review without quality gains.
517
  gaming_penalty = 0.0
 
519
  self._state.anti_gaming_alerts += 1
520
  gaming_penalty = min(0.12, 0.02 * self._state.anti_gaming_alerts)
521
 
522
+ # Curriculum bonus: reward robust performance.
523
+ # NOTE: was `0.06 * curriculum_level * ...` which is exactly 0.0 until the
524
+ # self-improvement loop has already lifted curriculum_level above 0 —
525
+ # a chicken-and-egg that gave bad policies no upside signal at all. The
526
+ # `(1.0 + curriculum_level)` factor activates the bonus from step 1
527
+ # (worth +0.10 * (base-0.5) immediately) and *grows* with curriculum.
528
+ robustness_bonus = 0.10 * (1.0 + self._state.curriculum_level) * max(0.0, base_reward - 0.5)
529
 
530
  # Norm punishment for delayed liabilities + self-improvement terms.
531
  final_reward = base_reward - (cb_amt / 150.0) - regret_penalty - gaming_penalty + robustness_bonus