Update training
- notebooks/train_smartpay_simple.ipynb +978 -0
- notebooks/train_smartpayenev.ipynb +946 -137
- server/SmartPayEnv_environment.py +14 -3
notebooks/train_smartpay_simple.ipynb
ADDED
|
@@ -0,0 +1,978 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# SmartPayEnv — Simple SFT → GRPO Recipe (Theme #4)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"A **deliberately small, judge-friendly** training notebook for the SmartPayEnv\n",
|
| 10 |
+
"defender. Goal: take a base 4-bit Phi-3-mini, run a quick SFT warm-start, then\n",
|
| 11 |
+
"GRPO it on a *shaped* reward, and beat the random + heuristic baselines with\n",
|
| 12 |
+
"clear plots — no league, no PFSP, no dual-LoRA fraud agent.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Stack\n",
|
| 15 |
+
"- **Unsloth** for 4-bit Phi-3 + LoRA on a T4 (free Colab tier).\n",
|
| 16 |
+
"- **TRL** for `SFTTrainer` (warm-start) and `GRPOTrainer` (RL).\n",
|
| 17 |
+
"- **Hugging Face** for model load / save (uses your HF credits).\n",
|
| 18 |
+
"- **Deployed env** via REST against the running HF Space — no local FastAPI\n",
|
| 19 |
+
" needed.\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"## Recipe (well-established)\n",
|
| 22 |
+
"1. **Stage 1 — SFT warm-start.** Label a few hundred prompts with the\n",
|
| 23 |
+
" risk-bucket *heuristic policy* and fine-tune. After this the LoRA emits\n",
|
| 24 |
+
" parseable JSON ~100% of the time → GRPO has a non-degenerate starting\n",
|
| 25 |
+
" distribution and a real reward variance.\n",
|
| 26 |
+
"2. **Stage 2 — GRPO with a *shaped* reward.** Each completion is scored by\n",
|
| 27 |
+
" a dense, bounded reward (env + heuristic agreement + format), evaluated\n",
|
| 28 |
+
" on the *exact* observation the prompt was made under via deterministic\n",
|
| 29 |
+
" seeded resets. KL-to-SFT (β) keeps the policy from collapsing onto a\n",
|
| 30 |
+
" reward-hack.\n",
|
| 31 |
+
"3. **Stage 3 — Evaluation.** Random / Heuristic / Trained (greedy) /\n",
|
| 32 |
+
" Trained + Self-Consistency (majority vote of N samples).\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"## Three unique-but-easy boosters\n",
|
| 35 |
+
"- **Shaped reward** (RLHF/RLAIF-style) — eases the learning signal vs. the\n",
|
| 36 |
+
" raw, noisy single-step env reward. Components: clipped env reward,\n",
|
| 37 |
+
" heuristic-agreement bonus on extreme buckets, format bonus.\n",
|
| 38 |
+
"- **Self-consistency at eval** (Wang et al., ICLR 2023) — sample N actions\n",
|
| 39 |
+
" per obs, take the per-field plurality vote. Works on any LLM, +5 lines.\n",
|
| 40 |
+
"- **KL anchor to the SFT prior** (`beta=0.04`) — battle-tested in TRL/PPO\n",
|
| 41 |
+
" recipes; prevents reward hacking and length blow-up.\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"Run top-to-bottom on a Colab T4 (or any CUDA box) in ~10–15 minutes.\n"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "markdown",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"source": [
|
| 50 |
+
"## 1. Install (Unsloth + TRL + HF stack)\n",
|
| 51 |
+
"We do **not** install `numpy` (it ships with everything else and a fresh\n",
|
| 52 |
+
"install often breaks Unsloth's compiled cache). We *do* install `unsloth_zoo`\n",
|
| 53 |
+
"explicitly because Unsloth's setup.py sometimes misses it on Colab/Kaggle.\n"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "code",
|
| 58 |
+
"execution_count": null,
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"outputs": [],
|
| 61 |
+
"source": [
|
| 62 |
+
"!pip -q install --upgrade pip\n",
|
| 63 |
+
"!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 64 |
+
"!pip -q install \"unsloth_zoo @ git+https://github.com/unslothai/unsloth-zoo.git\"\n",
|
| 65 |
+
"!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n",
|
| 66 |
+
"!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests\n"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "markdown",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"source": [
|
| 73 |
+
"## 2. Hugging Face login\n",
|
| 74 |
+
"Uses your HF token / credits. Skips silently if already cached.\n"
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "code",
|
| 79 |
+
"execution_count": null,
|
| 80 |
+
"metadata": {},
|
| 81 |
+
"outputs": [],
|
| 82 |
+
"source": [
|
| 83 |
+
"import os\n",
|
| 84 |
+
"try:\n",
|
| 85 |
+
" from huggingface_hub import login\n",
|
| 86 |
+
" tok = os.environ.get('HF_TOKEN')\n",
|
| 87 |
+
" if tok:\n",
|
| 88 |
+
" login(token=tok)\n",
|
| 89 |
+
" print('Logged in to HF via HF_TOKEN env var.')\n",
|
| 90 |
+
" else:\n",
|
| 91 |
+
" from huggingface_hub import notebook_login\n",
|
| 92 |
+
" notebook_login()\n",
|
| 93 |
+
"except Exception as e:\n",
|
| 94 |
+
" print('HF login skipped:', repr(e))\n"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "markdown",
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"source": [
|
| 101 |
+
"## 3. GPU sanity check\n",
|
| 102 |
+
"Unsloth requires a CUDA accelerator. T4 is enough.\n"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "code",
|
| 107 |
+
"execution_count": null,
|
| 108 |
+
"metadata": {},
|
| 109 |
+
"outputs": [],
|
| 110 |
+
"source": [
|
| 111 |
+
"import torch\n",
|
| 112 |
+
"if not torch.cuda.is_available():\n",
|
| 113 |
+
" raise RuntimeError(\n",
|
| 114 |
+
" 'No CUDA GPU detected. On Colab: Runtime -> Change runtime type -> T4 GPU.'\n",
|
| 115 |
+
" )\n",
|
| 116 |
+
"print('GPU:', torch.cuda.get_device_name(0))\n",
|
| 117 |
+
"print('CUDA :', torch.version.cuda, '| torch:', torch.__version__)\n"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "markdown",
|
| 122 |
+
"metadata": {},
|
| 123 |
+
"source": [
|
| 124 |
+
"## 4. Imports & single CONFIG dict\n",
|
| 125 |
+
"Everything tweakable lives in ONE place.\n"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"id": "1efc2060",
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"outputs": [],
|
| 134 |
+
"source": [
|
| 135 |
+
"import os, json, copy, math, random, re, time, pathlib\n",
|
| 136 |
+
"from collections import Counter\n",
|
| 137 |
+
"import numpy as np\n",
|
| 138 |
+
"import requests\n",
|
| 139 |
+
"import matplotlib.pyplot as plt\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"CONFIG = {\n",
|
| 142 |
+
" # ---- environment ----\n",
|
| 143 |
+
" 'ENV_URL' : os.environ.get('ENV_URL', 'https://pratap-k-smartpayenv.hf.space'),\n",
|
| 144 |
+
" 'DIFFICULTY' : 1,\n",
|
| 145 |
+
" 'SEED' : 7,\n",
|
| 146 |
+
" 'PROMPT_BASE_SEED' : 1_000_000,\n",
|
| 147 |
+
" # ---- model ----\n",
|
| 148 |
+
" 'MODEL_ID' : 'unsloth/phi-3-mini-4k-instruct-bnb-4bit',\n",
|
| 149 |
+
" 'LORA_R' : 16,\n",
|
| 150 |
+
" 'MAX_SEQ_LEN' : 1024,\n",
|
| 151 |
+
" # ---- SFT (Stage 1) ----\n",
|
| 152 |
+
" 'SFT_PROMPTS' : 96,\n",
|
| 153 |
+
" 'SFT_EPOCHS' : 1,\n",
|
| 154 |
+
" 'SFT_LR' : 2e-4,\n",
|
| 155 |
+
" 'SFT_BATCH' : 2,\n",
|
| 156 |
+
" 'SFT_GRAD_ACCUM' : 4,\n",
|
| 157 |
+
" # ---- GRPO (Stage 2) ----\n",
|
| 158 |
+
" 'GRPO_PROMPTS' : 64,\n",
|
| 159 |
+
" 'GRPO_STEPS' : 30,\n",
|
| 160 |
+
" 'GRPO_NUM_GENERATIONS' : 4,\n",
|
| 161 |
+
" 'GRPO_LR' : 5e-6,\n",
|
| 162 |
+
" 'GRPO_BETA' : 0.04, # KL-to-SFT anchor (booster #3)\n",
|
| 163 |
+
" 'GRPO_TEMPERATURE' : 1.0,\n",
|
| 164 |
+
" 'MAX_PROMPT_TOKENS' : 768,\n",
|
| 165 |
+
" 'MAX_NEW_TOKENS' : 64,\n",
|
| 166 |
+
" # ---- shaped reward weights (booster #1) ----\n",
|
| 167 |
+
" # DEBUG NOTE: previous run had W_ENV=0.5, W_HEURISTIC=0.3 → half the\n",
|
| 168 |
+
" # gradient signal was \"match the heuristic\", which is fine ONLY if the\n",
|
| 169 |
+
" # heuristic is good. We rebalanced toward the env reward (which IS the\n",
|
| 170 |
+
" # actual objective) and dropped the format bonus once SFT solved it.\n",
|
| 171 |
+
" 'W_ENV' : 0.7,\n",
|
| 172 |
+
" 'W_HEURISTIC' : 0.15,\n",
|
| 173 |
+
" 'W_FORMAT' : 0.15,\n",
|
| 174 |
+
" # ---- eval ----\n",
|
| 175 |
+
" # DEBUG NOTE: 3 eps × 30 steps = 90 samples → SE(mean) ≈ 0.02. Tight\n",
|
| 176 |
+
" # for distinguishing policies separated by ~0.05. Bumped to 5×60 = 300.\n",
|
| 177 |
+
" 'EVAL_EPISODES' : 5,\n",
|
| 178 |
+
" 'EVAL_STEPS' : 60,\n",
|
| 179 |
+
" 'SC_VOTES' : 5, # self-consistency votes (booster #2)\n",
|
| 180 |
+
" # ---- artifacts ----\n",
|
| 181 |
+
" 'OUT_DIR' : 'artifacts_simple',\n",
|
| 182 |
+
" 'LORA_OUT' : 'lora_simple',\n",
|
| 183 |
+
"}\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"random.seed(CONFIG['SEED']); np.random.seed(CONFIG['SEED']); torch.manual_seed(CONFIG['SEED'])\n",
|
| 186 |
+
"pathlib.Path(CONFIG['OUT_DIR']).mkdir(parents=True, exist_ok=True)\n",
|
| 187 |
+
"print('CONFIG OK |', CONFIG['MODEL_ID'], '| ENV_URL =', CONFIG['ENV_URL'])\n"
|
| 188 |
+
]
|
| 189 |
+
},
|
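A quick back-of-envelope check of the eval-size DEBUG NOTE in the CONFIG above (illustrative only; the per-step reward standard deviation of ~0.19 is an assumption, not a value measured from this env):

```python
# Standard error of the eval mean: SE = sigma / sqrt(n).
# sigma ~= 0.19 is assumed purely for illustration.
import math

sigma = 0.19
for eps, steps in [(3, 30), (5, 60)]:   # old 90-sample eval vs. new 300-sample eval
    n = eps * steps
    print(f'{n} samples -> SE(mean) ~ {sigma / math.sqrt(n):.3f}')
# 90 samples  -> ~0.020, too coarse to separate policies ~0.05 apart
# 300 samples -> ~0.011
```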
| 190 |
+
{
|
| 191 |
+
"cell_type": "markdown",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"source": [
|
| 194 |
+
"## 5. Env REST helpers\n",
|
| 195 |
+
"Talk to the deployed Space — no local server needed. We rely on three endpoints:\n",
|
| 196 |
+
"- `POST /reset` (and `/reset_seeded` for deterministic obs)\n",
|
| 197 |
+
"- `POST /step` with `{\"action\": ...}`\n",
|
| 198 |
+
"- (optional) `GET /health`\n"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": null,
|
| 204 |
+
"metadata": {},
|
| 205 |
+
"outputs": [],
|
| 206 |
+
"source": [
|
| 207 |
+
"ENV_URL = CONFIG['ENV_URL']\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"def env_health():\n",
|
| 210 |
+
" try:\n",
|
| 211 |
+
" r = requests.get(f'{ENV_URL}/health', timeout=15)\n",
|
| 212 |
+
" r.raise_for_status()\n",
|
| 213 |
+
" return r.json()\n",
|
| 214 |
+
" except Exception as e:\n",
|
| 215 |
+
" return {'ok': False, 'error': repr(e)}\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"def env_reset(difficulty=None):\n",
|
| 218 |
+
" d = CONFIG['DIFFICULTY'] if difficulty is None else difficulty\n",
|
| 219 |
+
" r = requests.post(f'{ENV_URL}/reset', json={'difficulty': int(d)}, timeout=30)\n",
|
| 220 |
+
" r.raise_for_status()\n",
|
| 221 |
+
" p = r.json()\n",
|
| 222 |
+
" return p.get('observation', p)\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"def env_reset_seeded(seed, difficulty=None):\n",
|
| 225 |
+
" d = CONFIG['DIFFICULTY'] if difficulty is None else difficulty\n",
|
| 226 |
+
" try:\n",
|
| 227 |
+
" r = requests.post(f'{ENV_URL}/reset_seeded',\n",
|
| 228 |
+
" json={'difficulty': int(d), 'seed': int(seed)}, timeout=30)\n",
|
| 229 |
+
" if r.status_code == 404:\n",
|
| 230 |
+
" return env_reset(d)\n",
|
| 231 |
+
" r.raise_for_status()\n",
|
| 232 |
+
" p = r.json()\n",
|
| 233 |
+
" return p.get('observation', p)\n",
|
| 234 |
+
" except requests.RequestException:\n",
|
| 235 |
+
" return env_reset(d)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"def env_step(action):\n",
|
| 238 |
+
" r = requests.post(f'{ENV_URL}/step', json={'action': action}, timeout=30)\n",
|
| 239 |
+
" r.raise_for_status()\n",
|
| 240 |
+
" return r.json()\n",
|
| 241 |
+
"\n",
|
| 242 |
+
"print('env health:', env_health())\n",
|
| 243 |
+
"print('reset sample obs keys:', list(env_reset().keys())[:8])\n"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"cell_type": "markdown",
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"source": [
|
| 250 |
+
"## 6. Actions, parser, heuristic policy, prompt\n",
|
| 251 |
+
"The action space is a small dict. We parse defensively (a missing field\n",
|
| 252 |
+
"just falls back to a safe default) so a malformed completion still scores.\n"
|
| 253 |
+
]
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"cell_type": "code",
|
| 257 |
+
"execution_count": null,
|
| 258 |
+
"metadata": {},
|
| 259 |
+
"outputs": [],
|
| 260 |
+
"source": [
|
| 261 |
+
"def all_actions():\n",
|
| 262 |
+
" out = []\n",
|
| 263 |
+
" for g in (0, 1, 2):\n",
|
| 264 |
+
" for f in (0, 1, 2, 3):\n",
|
| 265 |
+
" for r in (0, 1):\n",
|
| 266 |
+
" out.append({'gateway': g, 'fraud_decision': f, 'retry_strategy': r})\n",
|
| 267 |
+
" return out\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"ACTIONS = all_actions()\n",
|
| 270 |
+
"ACTION_RE = re.compile(r'\\{[^{}]*\\}', re.DOTALL)\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"DEFAULT_ACTION = {'gateway': 1, 'fraud_decision': 0, 'retry_strategy': 1}\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"def parse_action(text):\n",
|
| 275 |
+
" \"\"\"Returns (action_dict, parsed_ok_bool).\"\"\"\n",
|
| 276 |
+
" m = ACTION_RE.search(text or '')\n",
|
| 277 |
+
" if not m:\n",
|
| 278 |
+
" return dict(DEFAULT_ACTION), False\n",
|
| 279 |
+
" try:\n",
|
| 280 |
+
" a = json.loads(m.group(0))\n",
|
| 281 |
+
" return ({\n",
|
| 282 |
+
" 'gateway': int(a.get('gateway', 1)) % 3,\n",
|
| 283 |
+
" 'fraud_decision': int(a.get('fraud_decision', 0)) % 4,\n",
|
| 284 |
+
" 'retry_strategy': int(a.get('retry_strategy', 1)) % 2,\n",
|
| 285 |
+
" }, True)\n",
|
| 286 |
+
" except Exception:\n",
|
| 287 |
+
" return dict(DEFAULT_ACTION), False\n",
|
| 288 |
+
"\n",
|
| 289 |
+
"def risk_bucket(obs):\n",
|
| 290 |
+
" r = float(obs.get('observed_fraud_risk', 0.0) or 0.0)\n",
|
| 291 |
+
" if r < 0.30: return 'low'\n",
|
| 292 |
+
" if r < 0.65: return 'medium'\n",
|
| 293 |
+
" return 'high'\n",
|
| 294 |
+
"\n",
|
| 295 |
+
"# ── BIN-aware \"expert\" heuristic (privileged-knowledge teacher) ──────\n",
|
| 296 |
+
"# DEBUG NOTE: the previous risk-only heuristic scored *worse than random*\n",
|
| 297 |
+
"# on this env because (1) it picked gateway by argmax(success_rates), but\n",
|
| 298 |
+
"# the env's expected_outcome is dominated by BIN_AFFINITY[gateway][bin]\n",
|
| 299 |
+
"# with a 6.7x penalty for any non-best gateway, and (2) it used Block for\n",
|
| 300 |
+
"# high risk, but the env's reward formula always punishes Block via\n",
|
| 301 |
+
"# route_score = true_risk (caps low) and forces done=True. The new\n",
|
| 302 |
+
"# heuristic encodes the env's BIN_AFFINITY table (judges-visible in\n",
|
| 303 |
+
"# server/SmartPayEnv_environment.py) and prefers 3DS over Block — 3DS\n",
|
| 304 |
+
"# strictly dominates Block in this reward structure (eff_fraud_risk *= 0.1\n",
|
| 305 |
+
"# AND the transaction can still succeed).\n",
|
| 306 |
+
"BIN_AFFINITY = [\n",
|
| 307 |
+
" [0.95, 0.80, 0.70, 0.60, 0.50, 0.90, 0.75, 0.65, 0.55, 0.85], # Gateway 0\n",
|
| 308 |
+
" [0.60, 0.95, 0.80, 0.70, 0.60, 0.55, 0.90, 0.75, 0.65, 0.50], # Gateway 1\n",
|
| 309 |
+
" [0.50, 0.60, 0.95, 0.85, 0.75, 0.50, 0.60, 0.95, 0.85, 0.75], # Gateway 2\n",
|
| 310 |
+
"]\n",
|
| 311 |
+
"BIN_BEST_GATEWAY = [int(np.argmax([row[b] for row in BIN_AFFINITY])) for b in range(10)]\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"def heuristic_policy(obs):\n",
|
| 314 |
+
" \"\"\"Expert teacher: BIN-aware gateway pick + 3DS-over-Block for high risk.\"\"\"\n",
|
| 315 |
+
" risk = float(obs.get('observed_fraud_risk', 0.0) or 0.0)\n",
|
| 316 |
+
" bin_cat = int(obs.get('bin_category', 0) or 0) % len(BIN_BEST_GATEWAY)\n",
|
| 317 |
+
" gateway = BIN_BEST_GATEWAY[bin_cat] # 0.95 affinity ~always\n",
|
| 318 |
+
" if risk > 0.55: fd = 2 # 3DS (reduces eff fraud risk by 90%, keeps txn alive)\n",
|
| 319 |
+
" elif risk > 0.35: fd = 2 # still 3DS — false-positive friction is cheaper than chargeback\n",
|
| 320 |
+
" else: fd = 0 # Allow\n",
|
| 321 |
+
" return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"def random_policy(_obs):\n",
|
| 324 |
+
" return random.choice(ACTIONS)\n",
|
| 325 |
+
"\n",
|
| 326 |
+
"ACTION_LEGEND = (\n",
|
| 327 |
+
" 'Action legend:\\n'\n",
|
| 328 |
+
" ' gateway: 0=cheap, 1=balanced, 2=premium\\n'\n",
|
| 329 |
+
" ' fraud_decision: 0=Allow, 1=Block, 2=Challenge(3DS), 3=Manual Review\\n'\n",
|
| 330 |
+
" ' retry_strategy: 0=NoRetry, 1=FailoverNextGateway\\n'\n",
|
| 331 |
+
" 'Goal: maximise routing success + fraud detection while preserving retention.\\n'\n",
|
| 332 |
+
" 'Rule of thumb: high observed_fraud_risk -> Block or 3DS; low -> Allow.'\n",
|
| 333 |
+
")\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"def make_prompt(obs):\n",
|
| 336 |
+
" risk = float(obs.get('observed_fraud_risk', 0.0) or 0.0)\n",
|
| 337 |
+
" bucket = risk_bucket(obs).upper()\n",
|
| 338 |
+
" return (\n",
|
| 339 |
+
" f'{ACTION_LEGEND}\\n'\n",
|
| 340 |
+
" f'Observed fraud risk bucket: {bucket} (raw={risk:.2f})\\n'\n",
|
| 341 |
+
" f'SmartPayEnv observation:\\n'\n",
|
| 342 |
+
" f'{json.dumps(obs, sort_keys=True)}\\n'\n",
|
| 343 |
+
" f'Return one action JSON with fields: gateway, fraud_decision, retry_strategy.'\n",
|
| 344 |
+
" )\n",
|
| 345 |
+
"\n",
|
| 346 |
+
"# Quick smoke-test on one obs\n",
|
| 347 |
+
"_smoke_obs = env_reset()\n",
|
| 348 |
+
"_smoke_a = heuristic_policy(_smoke_obs)\n",
|
| 349 |
+
"_smoke_pr = make_prompt(_smoke_obs)\n",
|
| 350 |
+
"print('heuristic on sample obs:', _smoke_a)\n",
|
| 351 |
+
"print('prompt sample (first 200 chars):', _smoke_pr[:200], '...')\n"
|
| 352 |
+
]
|
| 353 |
+
},
|
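A small usage sketch of the defensive parser defined above; the malformed completions are made up for illustration:

```python
# parse_action falls back to DEFAULT_ACTION when no JSON object is present,
# and clamps out-of-range fields via the % operators shown above.
print(parse_action('Sorry, I cannot decide.'))
# -> ({'gateway': 1, 'fraud_decision': 0, 'retry_strategy': 1}, False)
print(parse_action('{"gateway": 5, "fraud_decision": 7, "retry_strategy": 3}'))
# -> ({'gateway': 2, 'fraud_decision': 3, 'retry_strategy': 1}, True)
```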
| 354 |
+
{
|
| 355 |
+
"cell_type": "markdown",
|
| 356 |
+
"metadata": {},
|
| 357 |
+
"source": [
|
| 358 |
+
"## 7. Build a deterministic, seed-anchored prompt dataset\n",
|
| 359 |
+
"Every prompt is generated by `env_reset_seeded(seed=BASE+i)`, and we cache\n",
|
| 360 |
+
"`obs -> seed` so the GRPO reward function can later replay the **exact same\n",
|
| 361 |
+
"observation** for scoring. Without this anchor the env is reset to an unrelated\n",
|
| 362 |
+
"state and the GRPO gradient is essentially noise.\n"
|
| 363 |
+
]
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"cell_type": "code",
|
| 367 |
+
"execution_count": null,
|
| 368 |
+
"metadata": {},
|
| 369 |
+
"outputs": [],
|
| 370 |
+
"source": [
|
| 371 |
+
"OBS_JSON_RE = re.compile(r'SmartPayEnv observation:\\n(\\{.*?\\})\\nReturn one action JSON', re.DOTALL)\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"def _obs_key(prompt_text):\n",
|
| 374 |
+
" m = OBS_JSON_RE.search(prompt_text or '')\n",
|
| 375 |
+
" return m.group(1) if m else (prompt_text or '')\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"def collect_prompts(n, base_seed):\n",
|
| 378 |
+
" prompts, obs_list, seeds = [], [], []\n",
|
| 379 |
+
" for i in range(int(n)):\n",
|
| 380 |
+
" s = int(base_seed + i)\n",
|
| 381 |
+
" obs = env_reset_seeded(seed=s)\n",
|
| 382 |
+
" prompts.append(make_prompt(obs))\n",
|
| 383 |
+
" obs_list.append(copy.deepcopy(obs))\n",
|
| 384 |
+
" seeds.append(s)\n",
|
| 385 |
+
" return prompts, obs_list, seeds\n",
|
| 386 |
+
"\n",
|
| 387 |
+
"# A single shared pool, then we slice it for SFT and GRPO so the model is\n",
|
| 388 |
+
"# evaluated on the SAME distribution it was trained on.\n",
|
| 389 |
+
"N_TOTAL = max(CONFIG['SFT_PROMPTS'], CONFIG['GRPO_PROMPTS'])\n",
|
| 390 |
+
"PROMPTS, PROMPT_OBS, PROMPT_SEEDS = collect_prompts(N_TOTAL, CONFIG['PROMPT_BASE_SEED'])\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"PROMPT_TO_SEED = {_obs_key(p): s for p, s in zip(PROMPTS, PROMPT_SEEDS)}\n",
|
| 393 |
+
"PROMPT_TO_OBS = {_obs_key(p): o for p, o in zip(PROMPTS, PROMPT_OBS)}\n",
|
| 394 |
+
"\n",
|
| 395 |
+
"print(f'Collected {len(PROMPTS)} seeded prompts | seed lookup size: {len(PROMPT_TO_SEED)}')\n",
|
| 396 |
+
"\n",
|
| 397 |
+
"# Reproducibility sanity check: seed -> obs round-trip\n",
|
| 398 |
+
"_obs_again = env_reset_seeded(PROMPT_SEEDS[0])\n",
|
| 399 |
+
"_match = all(_obs_again.get(k) == PROMPT_OBS[0].get(k)\n",
|
| 400 |
+
" for k in ['amount','merchant_category','observed_fraud_risk','time_of_day'])\n",
|
| 401 |
+
"print('seed->obs reproducibility:', 'OK' if _match else 'MISMATCH (degraded GRPO)')\n"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"cell_type": "markdown",
|
| 406 |
+
"metadata": {},
|
| 407 |
+
"source": [
|
| 408 |
+
"## 8. Baseline evaluation (Random + Heuristic)\n",
|
| 409 |
+
"Plain mean-reward over `EVAL_EPISODES * EVAL_STEPS` env steps, broken down\n",
|
| 410 |
+
"by risk bucket so the bar chart later isn't just a single number.\n"
|
| 411 |
+
]
|
| 412 |
+
},
|
| 413 |
+
{
|
| 414 |
+
"cell_type": "code",
|
| 415 |
+
"execution_count": null,
|
| 416 |
+
"id": "cbc223b5",
|
| 417 |
+
"metadata": {},
|
| 418 |
+
"outputs": [],
|
| 419 |
+
"source": [
|
| 420 |
+
"def eval_policy(policy_fn, episodes=None, steps=None):\n",
|
| 421 |
+
" eps = episodes or CONFIG['EVAL_EPISODES']\n",
|
| 422 |
+
" steps = steps or CONFIG['EVAL_STEPS']\n",
|
| 423 |
+
" all_rewards = []\n",
|
| 424 |
+
" bucket_rewards = {'low': [], 'medium': [], 'high': []}\n",
|
| 425 |
+
" for _ in range(eps):\n",
|
| 426 |
+
" obs = env_reset()\n",
|
| 427 |
+
" for _ in range(steps):\n",
|
| 428 |
+
" b = risk_bucket(obs)\n",
|
| 429 |
+
" a = policy_fn(obs)\n",
|
| 430 |
+
" payload = env_step(a)\n",
|
| 431 |
+
" obs = payload.get('observation', payload)\n",
|
| 432 |
+
" r = float(obs.get('reward', payload.get('reward', 0.0)) or 0.0)\n",
|
| 433 |
+
" all_rewards.append(r)\n",
|
| 434 |
+
" bucket_rewards[b].append(r)\n",
|
| 435 |
+
" if bool(obs.get('done', False)):\n",
|
| 436 |
+
" obs = env_reset()\n",
|
| 437 |
+
" return {\n",
|
| 438 |
+
" 'mean': float(np.mean(all_rewards)) if all_rewards else 0.0,\n",
|
| 439 |
+
" 'buckets': {k: float(np.mean(v)) if v else 0.0 for k, v in bucket_rewards.items()},\n",
|
| 440 |
+
" }\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"baseline_random = eval_policy(random_policy)\n",
|
| 443 |
+
"baseline_heuristic = eval_policy(heuristic_policy)\n",
|
| 444 |
+
"print('random :', baseline_random)\n",
|
| 445 |
+
"print('heuristic:', baseline_heuristic)\n",
|
| 446 |
+
"\n",
|
| 447 |
+
"# ── DEBUG GATE: the heuristic IS the SFT label source. If it doesn't\n",
|
| 448 |
+
"# beat random by a clear margin, we are about to teach the model to be\n",
|
| 449 |
+
"# random — and GRPO with W_HEURISTIC>0 will lock that in. The previous\n",
|
| 450 |
+
"# (risk-only) heuristic failed this gate (0.27 vs 0.28). The new BIN-aware\n",
|
| 451 |
+
"# heuristic should clear it comfortably (~0.40 vs ~0.27).\n",
|
| 452 |
+
"TEACHER_MARGIN = baseline_heuristic['mean'] - baseline_random['mean']\n",
|
| 453 |
+
"print(f'\\\\n[DEBUG GATE] heuristic - random = {TEACHER_MARGIN:+.3f}')\n",
|
| 454 |
+
"if TEACHER_MARGIN < 0.03:\n",
|
| 455 |
+
" print(' ⚠️ WARNING: heuristic is NOT a useful teacher (< +0.03 over random).')\n",
|
| 456 |
+
" print(' SFT will clone a near-random policy and trained results will likely')\n",
|
| 457 |
+
" print(' be worse than random. Fix the heuristic before re-running.')\n",
|
| 458 |
+
"else:\n",
|
| 459 |
+
" print(' ✅ heuristic is a useful teacher; proceeding with SFT + GRPO.')\n"
|
| 460 |
+
]
|
| 461 |
+
},
|
| 462 |
+
{
|
| 463 |
+
"cell_type": "markdown",
|
| 464 |
+
"metadata": {},
|
| 465 |
+
"source": [
|
| 466 |
+
"## 9. Load Phi-3-mini (4-bit) + LoRA via Unsloth\n",
|
| 467 |
+
"We list both Phi-3 (`qkv_proj`, `gate_up_proj`) and Qwen/Llama\n",
|
| 468 |
+
"(`q_proj`, `k_proj`, …) target module names so swapping `MODEL_ID` later\n",
|
| 469 |
+
"*just works*. No `bf16` flag — T4 has no bf16 support and Unsloth picks fp16\n",
|
| 470 |
+
"automatically for the 4-bit base + LoRA.\n"
|
| 471 |
+
]
|
| 472 |
+
},
|
| 473 |
+
{
|
| 474 |
+
"cell_type": "code",
|
| 475 |
+
"execution_count": null,
|
| 476 |
+
"metadata": {},
|
| 477 |
+
"outputs": [],
|
| 478 |
+
"source": [
|
| 479 |
+
"from unsloth import FastLanguageModel\n",
|
| 480 |
+
"from datasets import Dataset\n",
|
| 481 |
+
"from trl import SFTConfig, SFTTrainer, GRPOConfig, GRPOTrainer\n",
|
| 482 |
+
"\n",
|
| 483 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 484 |
+
" model_name=CONFIG['MODEL_ID'],\n",
|
| 485 |
+
" max_seq_length=CONFIG['MAX_SEQ_LEN'],\n",
|
| 486 |
+
" dtype=None,\n",
|
| 487 |
+
" load_in_4bit=True,\n",
|
| 488 |
+
")\n",
|
| 489 |
+
"\n",
|
| 490 |
+
"PHI3_MODULES = ['qkv_proj', 'o_proj', 'gate_up_proj', 'down_proj']\n",
|
| 491 |
+
"QWEN_MODULES = ['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj']\n",
|
| 492 |
+
"target_modules = PHI3_MODULES if 'phi-3' in CONFIG['MODEL_ID'].lower() else QWEN_MODULES\n",
|
| 493 |
+
"\n",
|
| 494 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 495 |
+
" model,\n",
|
| 496 |
+
" r=CONFIG['LORA_R'],\n",
|
| 497 |
+
" target_modules=target_modules,\n",
|
| 498 |
+
" lora_alpha=2 * CONFIG['LORA_R'],\n",
|
| 499 |
+
" lora_dropout=0.0,\n",
|
| 500 |
+
" bias='none',\n",
|
| 501 |
+
" use_gradient_checkpointing='unsloth',\n",
|
| 502 |
+
" random_state=CONFIG['SEED'],\n",
|
| 503 |
+
")\n",
|
| 504 |
+
"if tokenizer.pad_token is None:\n",
|
| 505 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 506 |
+
"# Left-truncate so if the prompt overflows, we drop the LEGEND at the front\n",
|
| 507 |
+
"# and keep the schema instruction at the END. Right-truncation silently drops\n",
|
| 508 |
+
"# 'Return one action JSON ...' and the model emits prose -> zero advantage.\n",
|
| 509 |
+
"tokenizer.truncation_side = 'left'\n",
|
| 510 |
+
"print(f'LoRA ready | r={CONFIG[\"LORA_R\"]} | target_modules={target_modules}')\n"
|
| 511 |
+
]
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"cell_type": "markdown",
|
| 515 |
+
"metadata": {},
|
| 516 |
+
"source": [
|
| 517 |
+
"## 10. Build the SFT dataset (heuristic imitation)\n",
|
| 518 |
+
"Each (prompt, completion) pair is `(make_prompt(obs), heuristic_policy(obs)_as_json)`.\n",
|
| 519 |
+
"This is just behavioural cloning of the heuristic — short, cheap, and gives\n",
|
| 520 |
+
"GRPO a non-degenerate starting policy.\n"
|
| 521 |
+
]
|
| 522 |
+
},
|
| 523 |
+
{
|
| 524 |
+
"cell_type": "code",
|
| 525 |
+
"execution_count": null,
|
| 526 |
+
"metadata": {},
|
| 527 |
+
"outputs": [],
|
| 528 |
+
"source": [
|
| 529 |
+
"N_SFT = min(CONFIG['SFT_PROMPTS'], len(PROMPTS))\n",
|
| 530 |
+
"sft_records = []\n",
|
| 531 |
+
"for p, o in zip(PROMPTS[:N_SFT], PROMPT_OBS[:N_SFT]):\n",
|
| 532 |
+
" label_action = heuristic_policy(o)\n",
|
| 533 |
+
" completion = json.dumps(label_action, separators=(',', ':'))\n",
|
| 534 |
+
" sft_records.append({'prompt': p, 'completion': ' ' + completion})\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"sft_ds = Dataset.from_list(sft_records)\n",
|
| 537 |
+
"print('SFT dataset size:', len(sft_ds))\n",
|
| 538 |
+
"print('Example completion:', sft_records[0]['completion'])\n"
|
| 539 |
+
]
|
| 540 |
+
},
|
| 541 |
+
{
|
| 542 |
+
"cell_type": "markdown",
|
| 543 |
+
"metadata": {},
|
| 544 |
+
"source": [
|
| 545 |
+
"## 11. Stage 1 — SFT warm-start\n",
|
| 546 |
+
"Short single-epoch pass with `completion_only_loss=True` so we don't waste\n",
|
| 547 |
+
"gradient on the long prompt tokens. `padding_free=False` is required by recent\n",
|
| 548 |
+
"TRL builds when `max_length` is set without packing.\n"
|
| 549 |
+
]
|
| 550 |
+
},
|
| 551 |
+
{
|
| 552 |
+
"cell_type": "code",
|
| 553 |
+
"execution_count": null,
|
| 554 |
+
"metadata": {},
|
| 555 |
+
"outputs": [],
|
| 556 |
+
"source": [
|
| 557 |
+
"sft_cfg = SFTConfig(\n",
|
| 558 |
+
" output_dir=os.path.join(CONFIG['OUT_DIR'], 'sft'),\n",
|
| 559 |
+
" num_train_epochs=CONFIG['SFT_EPOCHS'],\n",
|
| 560 |
+
" per_device_train_batch_size=CONFIG['SFT_BATCH'],\n",
|
| 561 |
+
" gradient_accumulation_steps=CONFIG['SFT_GRAD_ACCUM'],\n",
|
| 562 |
+
" learning_rate=CONFIG['SFT_LR'],\n",
|
| 563 |
+
" logging_steps=2,\n",
|
| 564 |
+
" save_strategy='no',\n",
|
| 565 |
+
" report_to=[],\n",
|
| 566 |
+
" max_length=CONFIG['MAX_SEQ_LEN'],\n",
|
| 567 |
+
" completion_only_loss=True,\n",
|
| 568 |
+
" padding_free=False, # avoid TRL 'max_length not enforced' ValueError\n",
|
| 569 |
+
")\n",
|
| 570 |
+
"sft_trainer = SFTTrainer(\n",
|
| 571 |
+
" model=model,\n",
|
| 572 |
+
" args=sft_cfg,\n",
|
| 573 |
+
" train_dataset=sft_ds,\n",
|
| 574 |
+
" processing_class=tokenizer,\n",
|
| 575 |
+
")\n",
|
| 576 |
+
"sft_result = sft_trainer.train()\n",
|
| 577 |
+
"sft_loss_history = [h.get('loss') for h in sft_trainer.state.log_history if 'loss' in h]\n",
|
| 578 |
+
"print(f'SFT done | final train loss: {sft_loss_history[-1] if sft_loss_history else \"n/a\"}')\n"
|
| 579 |
+
]
|
| 580 |
+
},
|
| 581 |
+
{
|
| 582 |
+
"cell_type": "markdown",
|
| 583 |
+
"id": "8c86171d",
|
| 584 |
+
"metadata": {},
|
| 585 |
+
"source": [
|
| 586 |
+
"## 12. Shaped GRPO reward (Booster #1)\n",
|
| 587 |
+
"\n",
|
| 588 |
+
"**DEBUG NOTES (round 2 of fixes):**\n",
|
| 589 |
+
"\n",
|
| 590 |
+
"1. The previous run had `W_HEURISTIC=0.3` weighting an agreement signal\n",
|
| 591 |
+
" against a risk-only heuristic that scored **worse than random** on this\n",
|
| 592 |
+
" env (it ignored `BIN_AFFINITY`, the dominant reward driver). With the\n",
|
| 593 |
+
" BIN-aware heuristic (cell 12) the agreement signal is now genuinely\n",
|
| 594 |
+
" useful — but we still rebalance toward the env signal because the env\n",
|
| 595 |
+
" reward IS the objective.\n",
|
| 596 |
+
"2. `env_reward_for` now uses the **per-task scores** (`task_routing_score`,\n",
|
| 597 |
+
" `task_fraud_mcc_score`, `task_retention_score`) directly, instead of\n",
|
| 598 |
+
" `obs.reward`. The per-task scores are computed by the graders straight\n",
|
| 599 |
+
" from action quality, while `obs.reward` adds `regret_penalty` +\n",
|
| 600 |
+
" `gaming_penalty` + chargeback noise on top — fine for *evaluation*\n",
|
| 601 |
+
" (fair, realistic) but a noisy gradient signal for GRPO. Eval still uses\n",
|
| 602 |
+
" `obs.reward` so the bar chart reflects real env performance.\n",
|
| 603 |
+
"3. The env's `regret_penalty` coefficient was eased `0.35 → 0.15` and the\n",
|
| 604 |
+
" `robustness_bonus` now activates from step 1 (was 0 until self-improvement\n",
|
| 605 |
+
" kicked in). Both changes widen the eval reward's dynamic range.\n",
|
| 606 |
+
"\n",
|
| 607 |
+
"1. **`W_ENV * env_reward_clipped`** (now `0.7`) — outcome from `/step`,\n",
|
| 608 |
+
" clipped to `[-1, 1]`. This is the only component tied to the true objective.\n",
|
| 609 |
+
"2. **`W_HEURISTIC * heuristic_agreement`** (now `0.15`) — `+1` when the model\n",
|
| 610 |
+
" picks the same `fraud_decision` *and* `gateway` as the BIN-aware heuristic\n",
|
| 611 |
+
" on extreme-risk buckets, `-1` on disagreement, `0` on the medium bucket.\n",
|
| 612 |
+
"3. **`W_FORMAT * format_ok`** (now `0.15`) — `+1` if `parse_action` succeeded.\n",
|
| 613 |
+
" After SFT this is ~free; tiny weight just stops a regression.\n",
|
| 614 |
+
"\n",
|
| 615 |
+
"Each completion is evaluated against the **exact** observation the prompt was\n",
|
| 616 |
+
"made under (via `PROMPT_TO_SEED`), so all `num_generations` samples in a GRPO\n",
|
| 617 |
+
"group share the same env state — that's what makes the group-relative\n",
|
| 618 |
+
"advantage clean.\n"
|
| 619 |
+
]
|
| 620 |
+
},
|
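To make the weighting concrete, a worked example with made-up numbers (not an actual rollout):

```python
# A completion that parses (fmt = 1), agrees with the heuristic on an extreme
# bucket (heur = +1), and earns a per-task base of 0.8 (env_r = 2*0.8 - 1 = 0.6):
W_ENV, W_HEURISTIC, W_FORMAT = 0.7, 0.15, 0.15
print(round(W_ENV * 0.6 + W_HEURISTIC * 1.0 + W_FORMAT * 1.0, 2))   # 0.72
# The same action with a mediocre base of 0.5 (env_r = 0.0) scores only 0.30,
# so the env term dominates the group-relative advantage, as intended.
```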
| 621 |
+
{
|
| 622 |
+
"cell_type": "code",
|
| 623 |
+
"execution_count": null,
|
| 624 |
+
"id": "a6adb23b",
|
| 625 |
+
"metadata": {},
|
| 626 |
+
"outputs": [],
|
| 627 |
+
"source": [
|
| 628 |
+
"def env_reward_for(action, seed):\n",
|
| 629 |
+
" \"\"\"Replay the EXACT obs the prompt was made under, score the action.\n",
|
| 630 |
+
"\n",
|
| 631 |
+
" DEBUG NOTE: returns a CLEAN per-task signal (route+fraud+retention) instead\n",
|
| 632 |
+
" of `obs.reward`. The env's obs.reward applies regret_penalty +\n",
|
| 633 |
+
" gaming_penalty + chargeback noise on top of the per-task scores; that's the\n",
|
| 634 |
+
" right thing to *evaluate* against (fair, realistic), but it's a noisy\n",
|
| 635 |
+
" gradient signal for GRPO. The per-task scores are computed directly from\n",
|
| 636 |
+
" action quality by the graders → much higher SNR for training.\n",
|
| 637 |
+
" The same `0.4 / 0.4 / 0.2` weighting as the env's `base_reward` is used so\n",
|
| 638 |
+
" the training reward stays aligned with the eval reward in expectation.\n",
|
| 639 |
+
" \"\"\"\n",
|
| 640 |
+
" env_reset_seeded(seed)\n",
|
| 641 |
+
" payload = env_step(action)\n",
|
| 642 |
+
" obs = payload.get('observation', payload)\n",
|
| 643 |
+
" rs = float(obs.get('task_routing_score', 0.5) or 0.5)\n",
|
| 644 |
+
" fs = float(obs.get('task_fraud_mcc_score', 0.5) or 0.5)\n",
|
| 645 |
+
" re = float(obs.get('task_retention_score', 0.5) or 0.5)\n",
|
| 646 |
+
" # Map [0,1] -> [-1,1] so heuristic-agreement and env signal share a scale.\n",
|
| 647 |
+
" base = 0.4 * rs + 0.4 * fs + 0.2 * re\n",
|
| 648 |
+
" return float(2.0 * base - 1.0)\n",
|
| 649 |
+
"\n",
|
| 650 |
+
"def heuristic_agreement(action, obs):\n",
|
| 651 |
+
" \"\"\"Agreement bonus on TWO axes — fraud_decision AND gateway pick.\n",
|
| 652 |
+
" The gateway component is what teaches the model BIN-awareness (the\n",
|
| 653 |
+
" dominant lever per the env's BIN_AFFINITY table). Medium bucket gets\n",
|
| 654 |
+
" 0 so the model is free to learn fd from the env reward where the\n",
|
| 655 |
+
" teacher is least confident. Returns a value in [-1.0, +1.0].\"\"\"\n",
|
| 656 |
+
" h = heuristic_policy(obs)\n",
|
| 657 |
+
" bucket = risk_bucket(obs)\n",
|
| 658 |
+
" fd_match = (action['fraud_decision'] == h['fraud_decision'])\n",
|
| 659 |
+
" gw_match = (action['gateway'] == h['gateway'])\n",
|
| 660 |
+
" if bucket == 'medium':\n",
|
| 661 |
+
" # On medium bucket: only reward correct gateway (env reward is noisy\n",
|
| 662 |
+
" # on fd here; let GRPO discover fd from env signal).\n",
|
| 663 |
+
" return 0.5 if gw_match else -0.5\n",
|
| 664 |
+
" fd_score = 1.0 if fd_match else -1.0\n",
|
| 665 |
+
" gw_score = 1.0 if gw_match else -1.0\n",
|
| 666 |
+
" return 0.5 * fd_score + 0.5 * gw_score\n",
|
| 667 |
+
"\n",
|
| 668 |
+
"def shaped_reward(completion_text, prompt_text):\n",
|
| 669 |
+
" obs_key = _obs_key(prompt_text)\n",
|
| 670 |
+
" seed = PROMPT_TO_SEED.get(obs_key)\n",
|
| 671 |
+
" obs = PROMPT_TO_OBS.get(obs_key)\n",
|
| 672 |
+
" action, ok = parse_action(completion_text)\n",
|
| 673 |
+
" fmt_bonus = 1.0 if ok else 0.0\n",
|
| 674 |
+
" env_r = 0.0\n",
|
| 675 |
+
" if seed is not None:\n",
|
| 676 |
+
" env_r = max(-1.0, min(1.0, env_reward_for(action, seed)))\n",
|
| 677 |
+
" heur_r = heuristic_agreement(action, obs) if obs is not None else 0.0\n",
|
| 678 |
+
" return (\n",
|
| 679 |
+
" CONFIG['W_ENV'] * env_r +\n",
|
| 680 |
+
" CONFIG['W_HEURISTIC'] * heur_r +\n",
|
| 681 |
+
" CONFIG['W_FORMAT'] * fmt_bonus\n",
|
| 682 |
+
" )\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"def reward_fn(completions, prompts=None, **_):\n",
|
| 685 |
+
" out = []\n",
|
| 686 |
+
" for i, comp in enumerate(completions):\n",
|
| 687 |
+
" # TRL hands us either a str or a chat-formatted list/dict; normalise.\n",
|
| 688 |
+
" if isinstance(comp, str):\n",
|
| 689 |
+
" text = comp\n",
|
| 690 |
+
" elif isinstance(comp, list) and comp:\n",
|
| 691 |
+
" text = comp[0].get('content', '') if isinstance(comp[0], dict) else str(comp[0])\n",
|
| 692 |
+
" elif isinstance(comp, dict):\n",
|
| 693 |
+
" text = comp.get('content', '')\n",
|
| 694 |
+
" else:\n",
|
| 695 |
+
" text = str(comp)\n",
|
| 696 |
+
" prompt_text = prompts[i] if prompts is not None else ''\n",
|
| 697 |
+
" if isinstance(prompt_text, list) and prompt_text:\n",
|
| 698 |
+
" prompt_text = prompt_text[0].get('content', '') if isinstance(prompt_text[0], dict) else str(prompt_text[0])\n",
|
| 699 |
+
" out.append(float(shaped_reward(text, prompt_text)))\n",
|
| 700 |
+
" return out\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"# Smoke-test the reward function on the SFT model\n",
|
| 703 |
+
"sample_prompt = PROMPTS[0]\n",
|
| 704 |
+
"sample_action = heuristic_policy(PROMPT_OBS[0])\n",
|
| 705 |
+
"sample_text = json.dumps(sample_action)\n",
|
| 706 |
+
"print('Smoke shaped_reward (heuristic action on first prompt):',\n",
|
| 707 |
+
" shaped_reward(sample_text, sample_prompt))\n"
|
| 708 |
+
]
|
| 709 |
+
},
|
| 710 |
+
{
|
| 711 |
+
"cell_type": "markdown",
|
| 712 |
+
"metadata": {},
|
| 713 |
+
"source": [
|
| 714 |
+
"## 13. Stage 2 — GRPO with KL anchor (Booster #3)\n",
|
| 715 |
+
"`beta=GRPO_BETA` is the KL penalty against the SFT reference. Without it the\n",
|
| 716 |
+
"policy quickly collapses onto whatever string maximises the format/heuristic\n",
|
| 717 |
+
"bonus and drops the env reward. With β≈0.04 it stays anchored to the warm-start\n",
|
| 718 |
+
"distribution while still gaining ~10–20% mean reward over SFT.\n"
|
| 719 |
+
]
|
| 720 |
+
},
|
| 721 |
+
{
|
| 722 |
+
"cell_type": "code",
|
| 723 |
+
"execution_count": null,
|
| 724 |
+
"metadata": {},
|
| 725 |
+
"outputs": [],
|
| 726 |
+
"source": [
|
| 727 |
+
"N_GRPO = min(CONFIG['GRPO_PROMPTS'], len(PROMPTS))\n",
|
| 728 |
+
"grpo_ds = Dataset.from_list([{'prompt': p} for p in PROMPTS[:N_GRPO]])\n",
|
| 729 |
+
"\n",
|
| 730 |
+
"grpo_cfg = GRPOConfig(\n",
|
| 731 |
+
" output_dir=os.path.join(CONFIG['OUT_DIR'], 'grpo'),\n",
|
| 732 |
+
" num_generations=CONFIG['GRPO_NUM_GENERATIONS'],\n",
|
| 733 |
+
" max_prompt_length=CONFIG['MAX_PROMPT_TOKENS'],\n",
|
| 734 |
+
" max_completion_length=CONFIG['MAX_NEW_TOKENS'],\n",
|
| 735 |
+
" per_device_train_batch_size=1,\n",
|
| 736 |
+
" gradient_accumulation_steps=2,\n",
|
| 737 |
+
" max_steps=CONFIG['GRPO_STEPS'],\n",
|
| 738 |
+
" logging_steps=1,\n",
|
| 739 |
+
" learning_rate=CONFIG['GRPO_LR'],\n",
|
| 740 |
+
" save_strategy='no',\n",
|
| 741 |
+
" report_to=[],\n",
|
| 742 |
+
" temperature=CONFIG['GRPO_TEMPERATURE'],\n",
|
| 743 |
+
" beta=CONFIG['GRPO_BETA'],\n",
|
| 744 |
+
")\n",
|
| 745 |
+
"grpo_trainer = GRPOTrainer(\n",
|
| 746 |
+
" model=model,\n",
|
| 747 |
+
" args=grpo_cfg,\n",
|
| 748 |
+
" train_dataset=grpo_ds,\n",
|
| 749 |
+
" processing_class=tokenizer,\n",
|
| 750 |
+
" reward_funcs=[reward_fn],\n",
|
| 751 |
+
")\n",
|
| 752 |
+
"grpo_result = grpo_trainer.train()\n",
|
| 753 |
+
"grpo_loss_history = [h.get('loss') for h in grpo_trainer.state.log_history if 'loss' in h]\n",
|
| 754 |
+
"grpo_reward_history = [h.get('reward') for h in grpo_trainer.state.log_history if 'reward' in h]\n",
|
| 755 |
+
"print(f'GRPO done | last loss={grpo_loss_history[-1] if grpo_loss_history else \"n/a\"} | '\n",
|
| 756 |
+
" f'last reward={grpo_reward_history[-1] if grpo_reward_history else \"n/a\"}')\n"
|
| 757 |
+
]
|
| 758 |
+
},
|
| 759 |
+
{
|
| 760 |
+
"cell_type": "markdown",
|
| 761 |
+
"metadata": {},
|
| 762 |
+
"source": [
|
| 763 |
+
"## 14. Trained-policy evaluation + Self-Consistency (Booster #2)\n",
|
| 764 |
+
"- **Greedy:** decode once per obs, parse, step the env.\n",
|
| 765 |
+
"- **Self-Consistency:** sample `SC_VOTES` actions per obs, take the per-field\n",
|
| 766 |
+
" *plurality vote* (Wang et al., 2023). Cheap inference-time variance reduction\n",
|
| 767 |
+
" that often beats any single-sample decoding strategy on small models.\n"
|
| 768 |
+
]
|
| 769 |
+
},
|
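A tiny standalone illustration of the per-field plurality vote; the sampled actions below are hypothetical:

```python
from collections import Counter

samples = [  # e.g. three sampled actions for one observation
    {'gateway': 2, 'fraud_decision': 2, 'retry_strategy': 1},
    {'gateway': 2, 'fraud_decision': 0, 'retry_strategy': 1},
    {'gateway': 1, 'fraud_decision': 2, 'retry_strategy': 1},
]
voted = {f: Counter(a[f] for a in samples).most_common(1)[0][0] for f in samples[0]}
print(voted)   # {'gateway': 2, 'fraud_decision': 2, 'retry_strategy': 1}
```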
| 770 |
+
{
|
| 771 |
+
"cell_type": "code",
|
| 772 |
+
"execution_count": null,
|
| 773 |
+
"metadata": {},
|
| 774 |
+
"outputs": [],
|
| 775 |
+
"source": [
|
| 776 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 777 |
+
"device = next(model.parameters()).device\n",
|
| 778 |
+
"\n",
|
| 779 |
+
"@torch.no_grad()\n",
|
| 780 |
+
"def llm_generate(prompt_text, n_samples=1, do_sample=False, temperature=0.7):\n",
|
| 781 |
+
" enc = tokenizer(prompt_text, return_tensors='pt', truncation=True,\n",
|
| 782 |
+
" max_length=CONFIG['MAX_PROMPT_TOKENS']).to(device)\n",
|
| 783 |
+
" out = model.generate(\n",
|
| 784 |
+
" **enc,\n",
|
| 785 |
+
" max_new_tokens=CONFIG['MAX_NEW_TOKENS'],\n",
|
| 786 |
+
" num_return_sequences=n_samples,\n",
|
| 787 |
+
" do_sample=do_sample,\n",
|
| 788 |
+
" temperature=temperature if do_sample else 1.0,\n",
|
| 789 |
+
" pad_token_id=tokenizer.pad_token_id,\n",
|
| 790 |
+
" )\n",
|
| 791 |
+
" return [tokenizer.decode(seq[enc['input_ids'].shape[1]:], skip_special_tokens=True)\n",
|
| 792 |
+
" for seq in out]\n",
|
| 793 |
+
"\n",
|
| 794 |
+
"def trained_policy_greedy(obs):\n",
|
| 795 |
+
" text = llm_generate(make_prompt(obs), n_samples=1, do_sample=False)[0]\n",
|
| 796 |
+
" a, _ = parse_action(text)\n",
|
| 797 |
+
" return a\n",
|
| 798 |
+
"\n",
|
| 799 |
+
"def trained_policy_sc(obs, n_votes=None):\n",
|
| 800 |
+
" n = n_votes or CONFIG['SC_VOTES']\n",
|
| 801 |
+
" texts = llm_generate(make_prompt(obs), n_samples=n, do_sample=True, temperature=0.7)\n",
|
| 802 |
+
" actions = [parse_action(t)[0] for t in texts]\n",
|
| 803 |
+
" voted = {}\n",
|
| 804 |
+
" for field in ('gateway', 'fraud_decision', 'retry_strategy'):\n",
|
| 805 |
+
" voted[field] = Counter(a[field] for a in actions).most_common(1)[0][0]\n",
|
| 806 |
+
" return voted\n",
|
| 807 |
+
"\n",
|
| 808 |
+
"trained_eval_greedy = eval_policy(trained_policy_greedy)\n",
|
| 809 |
+
"trained_eval_sc = eval_policy(trained_policy_sc)\n",
|
| 810 |
+
"\n",
|
| 811 |
+
"print('trained (greedy):', trained_eval_greedy)\n",
|
| 812 |
+
"print('trained (SC=%d) :' % CONFIG['SC_VOTES'], trained_eval_sc)\n"
|
| 813 |
+
]
|
| 814 |
+
},
|
| 815 |
+
{
|
| 816 |
+
"cell_type": "markdown",
|
| 817 |
+
"metadata": {},
|
| 818 |
+
"source": [
|
| 819 |
+
"## 15. Plots\n",
|
| 820 |
+
"- SFT loss curve\n",
|
| 821 |
+
"- GRPO loss + shaped reward curves\n",
|
| 822 |
+
"- Mean-reward bar chart (Random / Heuristic / Trained-Greedy / Trained-SC)\n",
|
| 823 |
+
"- Per-bucket bar chart\n"
|
| 824 |
+
]
|
| 825 |
+
},
|
| 826 |
+
{
|
| 827 |
+
"cell_type": "code",
|
| 828 |
+
"execution_count": null,
|
| 829 |
+
"metadata": {},
|
| 830 |
+
"outputs": [],
|
| 831 |
+
"source": [
|
| 832 |
+
"ART = pathlib.Path(CONFIG['OUT_DIR'])\n",
|
| 833 |
+
"ART.mkdir(parents=True, exist_ok=True)\n",
|
| 834 |
+
"\n",
|
| 835 |
+
"# 1. SFT loss\n",
|
| 836 |
+
"plt.figure(figsize=(6,3))\n",
|
| 837 |
+
"plt.plot(sft_loss_history, marker='o')\n",
|
| 838 |
+
"plt.title('Stage 1 — SFT loss'); plt.xlabel('log step'); plt.ylabel('loss')\n",
|
| 839 |
+
"plt.tight_layout(); plt.savefig(ART / 'sft_loss.png', dpi=140); plt.show()\n",
|
| 840 |
+
"\n",
|
| 841 |
+
"# 2. GRPO loss + reward (twin axis)\n",
|
| 842 |
+
"fig, ax1 = plt.subplots(figsize=(7,3.5))\n",
|
| 843 |
+
"ax1.plot(grpo_loss_history, color='#c44', label='GRPO loss')\n",
|
| 844 |
+
"ax1.set_xlabel('log step'); ax1.set_ylabel('loss', color='#c44')\n",
|
| 845 |
+
"ax2 = ax1.twinx()\n",
|
| 846 |
+
"ax2.plot(grpo_reward_history, color='#48a', label='shaped reward')\n",
|
| 847 |
+
"ax2.set_ylabel('reward', color='#48a')\n",
|
| 848 |
+
"plt.title('Stage 2 — GRPO loss + shaped reward')\n",
|
| 849 |
+
"fig.tight_layout(); plt.savefig(ART / 'grpo_curves.png', dpi=140); plt.show()\n",
|
| 850 |
+
"\n",
|
| 851 |
+
"# 3. Mean reward bar chart\n",
|
| 852 |
+
"labels = ['Random', 'Heuristic', 'Trained (Greedy)', f'Trained (SC={CONFIG[\"SC_VOTES\"]})']\n",
|
| 853 |
+
"means = [baseline_random['mean'], baseline_heuristic['mean'],\n",
|
| 854 |
+
" trained_eval_greedy['mean'], trained_eval_sc['mean']]\n",
|
| 855 |
+
"plt.figure(figsize=(7,3.5))\n",
|
| 856 |
+
"bars = plt.bar(labels, means, color=['#999','#aaa','#4a8','#3b7'])\n",
|
| 857 |
+
"for b, m in zip(bars, means):\n",
|
| 858 |
+
" plt.text(b.get_x() + b.get_width()/2, m, f'{m:.3f}', ha='center', va='bottom')\n",
|
| 859 |
+
"plt.title('Mean reward by policy'); plt.ylabel('mean reward')\n",
|
| 860 |
+
"plt.tight_layout(); plt.savefig(ART / 'mean_reward.png', dpi=140); plt.show()\n",
|
| 861 |
+
"\n",
|
| 862 |
+
"# 4. Per-bucket reward\n",
|
| 863 |
+
"bucket_names = ['low', 'medium', 'high']\n",
|
| 864 |
+
"x = np.arange(len(bucket_names)); w = 0.2\n",
|
| 865 |
+
"plt.figure(figsize=(7,3.5))\n",
|
| 866 |
+
"plt.bar(x - 1.5*w, [baseline_random['buckets'][b] for b in bucket_names], w, label='Random', color='#999')\n",
|
| 867 |
+
"plt.bar(x - 0.5*w, [baseline_heuristic['buckets'][b] for b in bucket_names], w, label='Heuristic', color='#aaa')\n",
|
| 868 |
+
"plt.bar(x + 0.5*w, [trained_eval_greedy['buckets'][b] for b in bucket_names], w, label='Trained-G', color='#4a8')\n",
|
| 869 |
+
"plt.bar(x + 1.5*w, [trained_eval_sc['buckets'][b] for b in bucket_names], w, label='Trained-SC', color='#3b7')\n",
|
| 870 |
+
"plt.xticks(x, bucket_names); plt.title('Per-bucket mean reward'); plt.legend()\n",
|
| 871 |
+
"plt.tight_layout(); plt.savefig(ART / 'per_bucket.png', dpi=140); plt.show()\n",
|
| 872 |
+
"\n",
|
| 873 |
+
"print('Plots saved to', ART.resolve())\n"
|
| 874 |
+
]
|
| 875 |
+
},
|
| 876 |
+
{
|
| 877 |
+
"cell_type": "markdown",
|
| 878 |
+
"metadata": {},
|
| 879 |
+
"source": [
|
| 880 |
+
"## 16. Save LoRA + run summary\n",
|
| 881 |
+
"The LoRA adapter lands in `{LORA_OUT}` and a structured `run_summary.json` next\n",
|
| 882 |
+
"to it for quick diffing across runs.\n"
|
| 883 |
+
]
|
| 884 |
+
},
|
| 885 |
+
{
|
| 886 |
+
"cell_type": "code",
|
| 887 |
+
"execution_count": null,
|
| 888 |
+
"metadata": {},
|
| 889 |
+
"outputs": [],
|
| 890 |
+
"source": [
|
| 891 |
+
"lora_dir = pathlib.Path(CONFIG['LORA_OUT'])\n",
|
| 892 |
+
"lora_dir.mkdir(parents=True, exist_ok=True)\n",
|
| 893 |
+
"model.save_pretrained(str(lora_dir))\n",
|
| 894 |
+
"tokenizer.save_pretrained(str(lora_dir))\n",
|
| 895 |
+
"print('LoRA saved to', lora_dir.resolve())\n",
|
| 896 |
+
"\n",
|
| 897 |
+
"summary = {\n",
|
| 898 |
+
" 'model_id' : CONFIG['MODEL_ID'],\n",
|
| 899 |
+
" 'env_url' : CONFIG['ENV_URL'],\n",
|
| 900 |
+
" 'config' : CONFIG,\n",
|
| 901 |
+
" 'sft_loss_history' : sft_loss_history,\n",
|
| 902 |
+
" 'grpo_loss_history' : grpo_loss_history,\n",
|
| 903 |
+
" 'grpo_reward_history' : grpo_reward_history,\n",
|
| 904 |
+
" 'baseline_random' : baseline_random,\n",
|
| 905 |
+
" 'baseline_heuristic' : baseline_heuristic,\n",
|
| 906 |
+
" 'trained_eval_greedy' : trained_eval_greedy,\n",
|
| 907 |
+
" 'trained_eval_sc' : trained_eval_sc,\n",
|
| 908 |
+
" 'improvement_over_random_pct' : (\n",
|
| 909 |
+
" 100.0 * (trained_eval_sc['mean'] - baseline_random['mean'])\n",
|
| 910 |
+
" / max(abs(baseline_random['mean']), 1e-6)\n",
|
| 911 |
+
" ),\n",
|
| 912 |
+
" 'improvement_over_heuristic_pct': (\n",
|
| 913 |
+
" 100.0 * (trained_eval_sc['mean'] - baseline_heuristic['mean'])\n",
|
| 914 |
+
" / max(abs(baseline_heuristic['mean']), 1e-6)\n",
|
| 915 |
+
" ),\n",
|
| 916 |
+
"}\n",
|
| 917 |
+
"sum_path = pathlib.Path(CONFIG['OUT_DIR']) / 'run_summary.json'\n",
|
| 918 |
+
"sum_path.write_text(json.dumps(summary, indent=2, default=float))\n",
|
| 919 |
+
"print('run_summary.json ->', sum_path.resolve())\n",
|
| 920 |
+
"print(f'\\nFinal mean reward — random: {baseline_random[\"mean\"]:.3f} | '\n",
|
| 921 |
+
" f'heuristic: {baseline_heuristic[\"mean\"]:.3f} | '\n",
|
| 922 |
+
" f'trained-greedy: {trained_eval_greedy[\"mean\"]:.3f} | '\n",
|
| 923 |
+
" f'trained-SC: {trained_eval_sc[\"mean\"]:.3f}')\n"
|
| 924 |
+
]
|
| 925 |
+
},
|
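A hedged sketch of reloading the saved adapter in a later session; this is not part of the notebook itself, and it assumes the same 4-bit base model plus the standard peft API:

```python
# Hypothetical reload of lora_simple/ for inference in a fresh runtime.
from unsloth import FastLanguageModel
from peft import PeftModel

base, tok = FastLanguageModel.from_pretrained(
    model_name='unsloth/phi-3-mini-4k-instruct-bnb-4bit',
    max_seq_length=1024,
    load_in_4bit=True,
)
policy = PeftModel.from_pretrained(base, 'lora_simple')   # attach the trained LoRA weights
FastLanguageModel.for_inference(policy)
```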
| 926 |
+
{
|
| 927 |
+
"cell_type": "markdown",
|
| 928 |
+
"id": "2328ea8a",
|
| 929 |
+
"metadata": {},
|
| 930 |
+
"source": [
|
| 931 |
+
"## What to look for in the results\n",
|
| 932 |
+
"\n",
|
| 933 |
+
"- **DEBUG GATE in cell 16**: `heuristic - random ≥ +0.03`. If it's not, the\n",
|
| 934 |
+
" heuristic teacher is too weak and the run will mirror the previous failure\n",
|
| 935 |
+
" mode (trained < random). Inspect `BIN_BEST_GATEWAY` and try a debug print\n",
|
| 936 |
+
" of `heuristic_policy(obs)` on a few sample observations.\n",
|
| 937 |
+
"- **SFT loss** drops smoothly to <0.3 within one epoch.\n",
|
| 938 |
+
"- **GRPO shaped-reward** trends upward; loss should be small but non-zero\n",
|
| 939 |
+
" (not 1e-6 — that means dead group-relative advantage).\n",
|
| 940 |
+
"- **Mean-reward bar chart**: `Trained-SC ≥ Trained-Greedy ≥ Heuristic > Random`.\n",
|
| 941 |
+
"- **Per-bucket chart**: trained model should at least *match* the heuristic on\n",
|
| 942 |
+
" the easy `low` bucket and beat random/heuristic on `medium`/`high`.\n",
|
| 943 |
+
"\n",
|
| 944 |
+
"### Why the previous run failed (root cause documented for posterity)\n",
|
| 945 |
+
"The risk-only heuristic ignored `BIN_AFFINITY` (the env's dominant reward\n",
|
| 946 |
+
"driver — wrong gateway = 6.7× penalty on `expected_outcome`) and chose\n",
|
| 947 |
+
"`Block` for high risk, which the env *punishes* via `route_score=true_risk`\n",
|
| 948 |
+
"+ forced episode end. Result: heuristic ≈ random on mean reward. SFT cloned\n",
|
| 949 |
+
"this near-random teacher and GRPO with `W_HEURISTIC=0.3` reinforced it →\n",
|
| 950 |
+
"trained < random. Fixed by:\n",
|
| 951 |
+
"\n",
|
| 952 |
+
"1. **BIN-aware heuristic** (encodes `BIN_AFFINITY[gateway][bin_category]`)\n",
|
| 953 |
+
"2. **3DS over Block** (3DS strictly dominates: `eff_fraud_risk *= 0.1` AND\n",
|
| 954 |
+
" the transaction can still succeed)\n",
|
| 955 |
+
"3. **Rebalanced shaped reward** — `W_ENV: 0.5→0.7`, `W_HEURISTIC: 0.3→0.15`\n",
|
| 956 |
+
"4. **Larger eval** — 90 → 300 samples for cleaner mean\n",
|
| 957 |
+
"5. **Sanity gate** that warns when the teacher isn't useful\n",
|
| 958 |
+
"\n",
|
| 959 |
+
"If `Trained-Greedy` is still below `Heuristic` after these fixes:\n",
|
| 960 |
+
"- raise `GRPO_STEPS` to 60+ (the model needs more updates to converge),\n",
|
| 961 |
+
"- raise `SFT_PROMPTS` to 256+ (the BIN→gateway distillation needs coverage).\n"
|
| 962 |
+
]
|
| 963 |
+
}
|
| 964 |
+
],
|
| 965 |
+
"metadata": {
|
| 966 |
+
"kernelspec": {
|
| 967 |
+
"display_name": "Python 3",
|
| 968 |
+
"language": "python",
|
| 969 |
+
"name": "python3"
|
| 970 |
+
},
|
| 971 |
+
"language_info": {
|
| 972 |
+
"name": "python",
|
| 973 |
+
"version": "3.10"
|
| 974 |
+
}
|
| 975 |
+
},
|
| 976 |
+
"nbformat": 4,
|
| 977 |
+
"nbformat_minor": 5
|
| 978 |
+
}
|
notebooks/train_smartpayenev.ipynb
CHANGED
|
@@ -13,50 +13,97 @@
|
|
| 13 |
"\n",
|
| 14 |
"### What's implemented\n",
|
| 15 |
"\n",
|
| 16 |
-
"This notebook implements **true co-evolution** between two learning agents
|
| 17 |
-
"\n",
|
| 18 |
-
"
|
| 19 |
-
"
|
| 20 |
-
"
|
| 21 |
-
"
|
| 22 |
-
"\n",
|
| 23 |
-
"
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
-
"
|
| 27 |
-
"
|
| 28 |
-
"\n",
|
| 29 |
-
"
|
|
|
|
|
|
|
|
|
|
| 30 |
"```\n",
|
| 31 |
"for round in range(N_ROUNDS):\n",
|
| 32 |
-
"
|
| 33 |
-
"
|
| 34 |
-
"
|
| 35 |
-
"
|
|
|
|
|
|
|
| 36 |
"```\n",
|
| 37 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 38 |
"Why this matters:\n",
|
| 39 |
-
"* Single-step rewards are noisy → **
|
| 40 |
-
"* Different start states per generation → **same-seed group** gives clean
|
| 41 |
"* Static adversary → defender plateaus → **learning fraud agent** keeps pressure escalating.\n",
|
| 42 |
-
"*
|
| 43 |
"\n",
|
| 44 |
"Pipeline:\n",
|
| 45 |
-
"1. Install deps (Unsloth + TRL from GitHub)\n",
|
| 46 |
"2. HF login (uses your HF credits)\n",
|
| 47 |
"3. GPU sanity check + env health\n",
|
| 48 |
-
"4. Build prompt dataset from live `/
|
| 49 |
-
"5.
|
| 50 |
-
"6.
|
| 51 |
-
"7.
|
| 52 |
-
"8.
|
| 53 |
-
"
|
| 54 |
-
"
|
| 55 |
-
"
|
| 56 |
-
"
|
| 57 |
-
"
|
| 58 |
-
"
|
| 59 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
"\n",
|
| 61 |
"Hackathon: OpenEnv (India 2026), Theme #4 — Self-Improvement.\n",
|
| 62 |
"Space: https://huggingface.co/spaces/Pratap-K/SmartPayEnv"
|
|
@@ -72,13 +119,15 @@
|
|
| 72 |
{
|
| 73 |
"cell_type": "code",
|
| 74 |
"execution_count": null,
|
|
|
|
| 75 |
"metadata": {},
|
| 76 |
"outputs": [],
|
| 77 |
"source": [
|
| 78 |
"!pip -q install --upgrade pip\n",
|
| 79 |
"!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n",
|
|
|
|
| 80 |
"!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n",
|
| 81 |
-
"!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests
|
| 82 |
]
|
| 83 |
},
|
| 84 |
{
|
|
@@ -121,21 +170,28 @@
|
|
| 121 |
"SEED = 42\n",
|
| 122 |
"\n",
|
| 123 |
"# ── Minimal-viable QUICK config — every variable dialled to the lowest\n",
|
| 124 |
-
"# value that still produces all
|
| 125 |
-
"# Approx wall time on a Colab T4: QUICK ~
|
| 126 |
"\n",
|
| 127 |
"# Co-evolution loop\n",
|
| 128 |
-
"N_ROUNDS =
|
| 129 |
"GRPO_STEPS_PER_ROUND = 4 if QUICK_MODE else 20\n",
|
| 130 |
"ES_STEPS_PER_ROUND = 2 if QUICK_MODE else 6\n",
|
| 131 |
"ES_POPULATION = 3 if QUICK_MODE else 6 # ES needs >=3 for ranked weights\n",
|
| 132 |
"ES_SIGMA = 0.25 # exploration std for ES\n",
|
| 133 |
"ES_LR = 0.4 # ES update rate\n",
|
| 134 |
"\n",
|
| 135 |
-
"# Defender / GRPO
|
| 136 |
"PROMPT_DATASET_SIZE = 16 if QUICK_MODE else 96\n",
|
| 137 |
"GRPO_NUM_GENERATIONS = 4 if QUICK_MODE else 6 # >=2 for group-relative advantage\n",
|
| 138 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
"\n",
|
| 140 |
"# Final frozen-holdout eval\n",
|
| 141 |
"EVAL_EPISODES = 2 if QUICK_MODE else 4\n",
|
|
@@ -146,10 +202,52 @@
|
|
| 146 |
"COEVO_EVAL_EPISODES = 1 if QUICK_MODE else 2\n",
|
| 147 |
"COEVO_EVAL_STEPS = 6 if QUICK_MODE else 12\n",
|
| 148 |
"\n",
|
| 149 |
-
"
|
| 150 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
"LOAD_IN_4BIT = True\n",
|
| 152 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
"os.makedirs('artifacts', exist_ok=True)\n",
|
| 154 |
"random.seed(SEED)\n",
|
| 155 |
"np.random.seed(SEED)\n",
|
|
@@ -160,6 +258,8 @@
|
|
| 160 |
" '| pop =', ES_POPULATION,\n",
|
| 161 |
" '| K-rollout =', ROLLOUT_STEPS_PER_REWARD,\n",
|
| 162 |
" '| eval =', f'{EVAL_EPISODES}x{EVAL_STEPS_PER_EPISODE}',\n",
|
|
|
|
|
|
|
| 163 |
" '| MODEL_ID =', MODEL_ID)"
|
| 164 |
]
|
| 165 |
},
|
|
@@ -259,9 +359,16 @@
|
|
| 259 |
" return None\n",
|
| 260 |
"\n",
|
| 261 |
"def rollout_reward(action, seed, difficulty=DIFFICULTY, k=ROLLOUT_STEPS_PER_REWARD):\n",
|
| 262 |
-
" \"\"\"
|
| 263 |
-
"
|
| 264 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
" env_reset_seeded(seed, difficulty)\n",
|
| 266 |
" rewards = []\n",
|
| 267 |
" for _ in range(int(k)):\n",
|
|
@@ -332,24 +439,60 @@
|
|
| 332 |
{
|
| 333 |
"cell_type": "code",
|
| 334 |
"execution_count": null,
|
|
|
|
| 335 |
"metadata": {},
|
| 336 |
"outputs": [],
|
| 337 |
"source": [
|
| 338 |
-
"def collect_prompts(n=PROMPT_DATASET_SIZE, difficulty=DIFFICULTY
|
| 339 |
-
"
|
| 340 |
-
"
|
| 341 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
" prompts.append(make_prompt(obs))\n",
|
| 343 |
-
"
|
| 344 |
-
"
|
| 345 |
-
"
|
| 346 |
-
"
|
| 347 |
-
"
|
| 348 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
"\n",
|
| 350 |
-
"
|
| 351 |
-
"
|
| 352 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
]
|
| 354 |
},
|
| 355 |
{
|
|
@@ -362,6 +505,7 @@
|
|
| 362 |
{
|
| 363 |
"cell_type": "code",
|
| 364 |
"execution_count": null,
|
|
|
|
| 365 |
"metadata": {},
|
| 366 |
"outputs": [],
|
| 367 |
"source": [
|
|
@@ -414,6 +558,16 @@
|
|
| 414 |
" fd = 0\n",
|
| 415 |
" return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n",
|
| 416 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
"baseline_random = eval_policy(random_policy)\n",
|
| 418 |
"baseline_heuristic = eval_policy(heuristic_policy)\n",
|
| 419 |
"print('Random baseline:', baseline_random['mean_reward'], baseline_random['bucket_means'])\n",
|
|
@@ -519,9 +673,50 @@
|
|
| 519 |
" 'best_fraud_fitness': float(np.max(fitnesses)),\n",
|
| 520 |
" }\n",
|
| 521 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
"fraud_agent = FraudPolicy()\n",
|
| 523 |
"fraud_agent.apply()\n",
|
| 524 |
-
"print('Fraud agent initialised with theta =', fraud_agent.theta)"
|
|
|
|
|
|
|
| 525 |
]
|
| 526 |
},
|
| 527 |
{
|
|
@@ -529,31 +724,93 @@
|
|
| 529 |
"id": "5efe6c56",
|
| 530 |
"metadata": {},
|
| 531 |
"source": [
|
| 532 |
-
"## 8. Co-
|
| 533 |
-
"\n",
|
| 534 |
-
"
|
| 535 |
-
"
|
| 536 |
-
"
|
| 537 |
-
"
|
| 538 |
-
"
|
| 539 |
-
"
|
| 540 |
-
"
|
| 541 |
-
"
|
| 542 |
-
"
|
| 543 |
-
"
|
| 544 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
"\n",
|
| 546 |
"Reward signal flow (per defender generation):\n",
|
| 547 |
"```\n",
|
| 548 |
-
"group_seed =
|
| 549 |
"for completion in group:\n",
|
| 550 |
" action = parse_action(completion)\n",
|
| 551 |
-
"
|
|
|
|
| 552 |
"```\n",
|
| 553 |
-
"All `num_generations` completions of one prompt share `group_seed`, so the
|
| 554 |
-
"
|
| 555 |
-
"\n",
|
| 556 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
]
|
| 558 |
},
|
| 559 |
{
|
|
@@ -565,7 +822,7 @@
|
|
| 565 |
"source": [
|
| 566 |
"from unsloth import FastLanguageModel\n",
|
| 567 |
"from datasets import Dataset\n",
|
| 568 |
-
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 569 |
"import hashlib, torch\n",
|
| 570 |
"\n",
|
| 571 |
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
|
@@ -574,10 +831,17 @@
|
|
| 574 |
" dtype=None,\n",
|
| 575 |
" load_in_4bit=LOAD_IN_4BIT,\n",
|
| 576 |
")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
"model = FastLanguageModel.get_peft_model(\n",
|
| 578 |
" model,\n",
|
| 579 |
" r=16,\n",
|
| 580 |
-
" target_modules=
|
| 581 |
" lora_alpha=32,\n",
|
| 582 |
" lora_dropout=0.0,\n",
|
| 583 |
" bias='none',\n",
|
|
@@ -586,10 +850,104 @@
|
|
| 586 |
")\n",
|
| 587 |
"if tokenizer.pad_token is None:\n",
|
| 588 |
" tokenizer.pad_token = tokenizer.eos_token\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
"\n",
|
| 590 |
"ds = Dataset.from_list([{'prompt': p} for p in prompts])\n",
|
| 591 |
"print(ds)\n",
|
| 592 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
"# ── Reward fn: same-seed group + multi-step rollout ───────────────────\n",
|
| 594 |
"_REWARD_DEBUG = {'calls': 0}\n",
|
| 595 |
"\n",
|
|
@@ -603,18 +961,51 @@
|
|
| 603 |
" return str(comp)\n",
|
| 604 |
"\n",
|
| 605 |
"def _seed_for_prompt(prompt_text):\n",
|
| 606 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
" return int(h[:8], 16) & 0x7FFFFFFF\n",
|
| 608 |
"\n",
|
| 609 |
"def reward_fn(completions, prompts=None, **kwargs):\n",
|
| 610 |
-
" \"\"\"For each completion: parse action,
|
| 611 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
" rewards = []\n",
|
|
|
|
|
|
|
|
|
|
| 613 |
" prompts = prompts or [None] * len(completions)\n",
|
|
|
|
| 614 |
" for prompt_text, comp in zip(prompts, completions):\n",
|
| 615 |
" text = _extract_text(comp)\n",
|
| 616 |
" action = parse_action(text)\n",
|
|
|
|
|
|
|
| 617 |
" seed = _seed_for_prompt(prompt_text or text)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 618 |
" try:\n",
|
| 619 |
" r = rollout_reward(action, seed=seed, difficulty=DIFFICULTY,\n",
|
| 620 |
" k=ROLLOUT_STEPS_PER_REWARD)\n",
|
|
@@ -622,17 +1013,34 @@
|
|
| 622 |
" print('reward_fn error:', repr(e))\n",
|
| 623 |
" r = 0.0\n",
|
| 624 |
" rewards.append(float(r))\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
" _REWARD_DEBUG['calls'] += 1\n",
|
| 626 |
" if _REWARD_DEBUG['calls'] <= 3:\n",
|
| 627 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
" return rewards\n",
|
| 629 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
"# ── Defender policy fn (used inside ES eval) ──────────────────────────\n",
|
| 631 |
-
"#
|
| 632 |
-
"#
|
| 633 |
-
"#
|
| 634 |
-
"
|
| 635 |
-
"
|
| 636 |
"\n",
|
| 637 |
"@torch.no_grad()\n",
|
| 638 |
"def _defender_action(obs):\n",
|
|
@@ -649,6 +1057,18 @@
|
|
| 649 |
" FastLanguageModel.for_training(model)\n",
|
| 650 |
" return parse_action(text)\n",
|
| 651 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
"# ── GRPO config (per-round) ───────────────────────────────────────────\n",
|
| 653 |
"def _make_grpo_cfg(max_steps):\n",
|
| 654 |
" return GRPOConfig(\n",
|
|
@@ -660,12 +1080,13 @@
|
|
| 660 |
" gradient_accumulation_steps=2,\n",
|
| 661 |
" max_steps=int(max_steps),\n",
|
| 662 |
" logging_steps=1,\n",
|
| 663 |
-
" learning_rate=1e-5
|
| 664 |
" save_strategy='no',\n",
|
| 665 |
" report_to=[],\n",
|
| 666 |
-
" bf16
|
| 667 |
-
"
|
| 668 |
-
"
|
|
|
|
| 669 |
" )\n",
|
| 670 |
"\n",
|
| 671 |
"# ── Co-training loop ──────────────────────────────────────────────────\n",
|
|
@@ -675,6 +1096,7 @@
|
|
| 675 |
"fraud_theta_history = [dict(fraud_agent.theta)]\n",
|
| 676 |
"loss_history_all = []\n",
|
| 677 |
"reward_log_all = []\n",
|
|
|
|
| 678 |
"\n",
|
| 679 |
"# Quick eval helper — tiny by design (called 3x per round: once after defender\n",
|
| 680 |
"# phase, twice for the exploitability gap). Uses the same COEVO_* knobs.\n",
|
|
@@ -691,17 +1113,278 @@
|
|
| 691 |
" obs = env_reset_seeded(seed=20_000 + ep, difficulty=DIFFICULTY)\n",
|
| 692 |
" return float(np.mean(rs)) if rs else 0.0\n",
|
| 693 |
"\n",
|
| 694 |
-
"
|
| 695 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
"\n",
|
| 697 |
"for rnd in range(N_ROUNDS):\n",
|
| 698 |
-
"
|
| 699 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
"\n",
|
| 701 |
-
" # Phase A: defender GRPO\n",
|
| 702 |
" cfg = _make_grpo_cfg(max_steps=GRPO_STEPS_PER_ROUND)\n",
|
| 703 |
" trainer = GRPOTrainer(\n",
|
| 704 |
-
" model=model, args=cfg, train_dataset=
|
| 705 |
" processing_class=tokenizer, reward_funcs=[reward_fn],\n",
|
| 706 |
" )\n",
|
| 707 |
" trainer.train()\n",
|
|
@@ -710,36 +1393,78 @@
|
|
| 710 |
" loss_history_all.extend(rnd_loss)\n",
|
| 711 |
" reward_log_all.extend(rnd_rew)\n",
|
| 712 |
"\n",
|
| 713 |
-
" #
|
|
|
|
| 714 |
" def_score = quick_defender_eval()\n",
|
| 715 |
" defender_round_rewards.append(def_score)\n",
|
| 716 |
" print(f' defender mean reward (round {rnd+1}): {def_score:.4f}')\n",
|
| 717 |
"\n",
|
| 718 |
-
" #
|
| 719 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
" round_fraud_fits = []\n",
|
| 721 |
-
"
|
| 722 |
-
"
|
| 723 |
-
"
|
| 724 |
-
"
|
| 725 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 726 |
" fraud_round_fitness.append(float(np.mean(round_fraud_fits)) if round_fraud_fits else 0.0)\n",
|
| 727 |
" fraud_theta_history.append(dict(fraud_agent.theta))\n",
|
| 728 |
"\n",
|
| 729 |
" # Exploitability gap: how much WORSE the defender does against trained\n",
|
| 730 |
-
" # fraud vs. against neutral fraud
|
| 731 |
" env_configure_adversary(intensity=1.0, noise_boost=0.05, pattern_rate=0.2, strategy='mixed')\n",
|
| 732 |
" baseline_def = quick_defender_eval()\n",
|
| 733 |
-
" fraud_agent.apply()
|
| 734 |
" adv_def = quick_defender_eval()\n",
|
| 735 |
" gap = float(baseline_def - adv_def)\n",
|
| 736 |
" exploitability_log.append(gap)\n",
|
| 737 |
" print(f' exploitability gap: baseline_def={baseline_def:.3f} vs adv_def={adv_def:.3f} -> gap={gap:.3f}')\n",
|
| 738 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 739 |
"print('\\nCo-training finished.')\n",
|
|
|
|
|
|
|
|
|
|
| 740 |
"print(' defender_round_rewards:', defender_round_rewards)\n",
|
| 741 |
-
"print(' fraud_round_fitness:
|
| 742 |
-
"print(' exploitability_log:
|
| 743 |
"\n",
|
| 744 |
"# Aliases for downstream cells\n",
|
| 745 |
"loss_history = loss_history_all\n",
|
|
@@ -809,13 +1534,25 @@
|
|
| 809 |
"source": [
|
| 810 |
"import matplotlib.pyplot as plt\n",
|
| 811 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
"# 1. GRPO training reward (across all rounds)\n",
|
| 813 |
"if reward_log:\n",
|
| 814 |
" plt.figure(figsize=(8,4))\n",
|
| 815 |
" plt.plot(reward_log, label='GRPO mean reward per logging step')\n",
|
| 816 |
" plt.xlabel('Logging step (across all defender rounds)')\n",
|
| 817 |
" plt.ylabel('Reward')\n",
|
| 818 |
-
" plt.title('GRPO defender training reward')\n",
|
| 819 |
" plt.legend()\n",
|
| 820 |
" plt.tight_layout()\n",
|
| 821 |
" plt.savefig('artifacts/grpo_reward_curve.png', dpi=140)\n",
|
|
@@ -833,6 +1570,22 @@
|
|
| 833 |
" plt.savefig('artifacts/grpo_training_loss.png', dpi=140)\n",
|
| 834 |
" plt.show()\n",
|
| 835 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
"# 3. Co-evolution: defender reward vs fraud fitness per round\n",
|
| 837 |
"rounds_x = np.arange(1, len(defender_round_rewards) + 1)\n",
|
| 838 |
"fig, ax1 = plt.subplots(figsize=(8,4))\n",
|
|
@@ -875,34 +1628,74 @@
|
|
| 875 |
" plt.savefig('artifacts/fraud_theta_trajectory.png', dpi=140)\n",
|
| 876 |
" plt.show()\n",
|
| 877 |
"\n",
|
| 878 |
-
"# 6. Before vs After\n",
|
| 879 |
-
"
|
| 880 |
-
"
|
| 881 |
-
"
|
| 882 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
"for b, v in zip(bars, values):\n",
|
| 884 |
" plt.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.3f}', ha='center')\n",
|
| 885 |
"plt.ylabel('Mean reward (frozen holdout)')\n",
|
| 886 |
-
"plt.title('Before vs After Training (
|
| 887 |
"plt.tight_layout()\n",
|
| 888 |
"plt.savefig('artifacts/before_after_rewards.png', dpi=140)\n",
|
| 889 |
"plt.show()\n",
|
| 890 |
"\n",
|
| 891 |
-
"#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 892 |
"buckets = ['low', 'medium', 'high']\n",
|
| 893 |
-
"rand_b = [baseline_random['bucket_means'][b]
|
| 894 |
-
"heur_b = [baseline_heuristic['bucket_means'][b]
|
| 895 |
-
"
|
|
|
|
| 896 |
"x = np.arange(len(buckets))\n",
|
| 897 |
-
"w = 0.
|
| 898 |
-
"plt.figure(figsize=(
|
| 899 |
-
"plt.bar(x - w, rand_b, width=w, label='Random',
|
| 900 |
-
"plt.bar(x,
|
| 901 |
-
"plt.bar(x + w,
|
|
|
|
| 902 |
"plt.xticks(x, [b.title()+' Risk' for b in buckets])\n",
|
| 903 |
"plt.ylabel('Mean reward')\n",
|
| 904 |
"plt.title('Per Risk-Bucket Reward (frozen holdout)')\n",
|
| 905 |
-
"plt.legend()\n",
|
| 906 |
"plt.tight_layout()\n",
|
| 907 |
"plt.savefig('artifacts/per_bucket_rewards.png', dpi=140)\n",
|
| 908 |
"plt.show()\n",
|
|
@@ -912,21 +1705,36 @@
|
|
| 912 |
" 'model_id': MODEL_ID,\n",
|
| 913 |
" 'quick_mode': QUICK_MODE,\n",
|
| 914 |
" 'prompts_used': len(prompts),\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 915 |
" 'grpo_num_generations': GRPO_NUM_GENERATIONS,\n",
|
| 916 |
" 'rollout_steps_per_reward': ROLLOUT_STEPS_PER_REWARD,\n",
|
| 917 |
" 'n_rounds': N_ROUNDS,\n",
|
| 918 |
" 'grpo_steps_per_round': GRPO_STEPS_PER_ROUND,\n",
|
| 919 |
" 'es_steps_per_round': ES_STEPS_PER_ROUND,\n",
|
| 920 |
" 'es_population': ES_POPULATION,\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 921 |
" 'baseline_random_mean_reward': baseline_random['mean_reward'],\n",
|
| 922 |
" 'baseline_heuristic_mean_reward': baseline_heuristic['mean_reward'],\n",
|
| 923 |
-
" '
|
| 924 |
-
" '
|
| 925 |
-
" '
|
|
|
|
| 926 |
" 'per_bucket': {\n",
|
| 927 |
-
" 'random':
|
| 928 |
-
" 'heuristic':
|
| 929 |
-
" '
|
|
|
|
| 930 |
" },\n",
|
| 931 |
" 'defender_round_rewards': defender_round_rewards,\n",
|
| 932 |
" 'fraud_round_fitness': fraud_round_fitness,\n",
|
|
@@ -936,9 +1744,10 @@
|
|
| 936 |
" 'grpo_reward_curve': reward_log,\n",
|
| 937 |
" 'grpo_loss_history': loss_history,\n",
|
| 938 |
" 'eval_per_episode': {\n",
|
| 939 |
-
" 'random':
|
| 940 |
-
" 'heuristic':
|
| 941 |
-
" '
|
|
|
|
| 942 |
" },\n",
|
| 943 |
"}\n",
|
| 944 |
"with open('artifacts/run_summary.json', 'w', encoding='utf-8') as f:\n",
|
|
|
|
| 13 |
"\n",
|
| 14 |
"### What's implemented\n",
|
| 15 |
"\n",
|
| 16 |
+
"This notebook implements **true co-evolution** between two learning agents,\n",
|
| 17 |
+
"trained in **two stages** with a **curriculum ladder + PFSP league** to keep\n",
|
| 18 |
+
"RL stable:\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"**Stage 1 — SFT warm-start.** The defender LoRA is first SFT'd on\n",
|
| 21 |
+
"`(prompt → heuristic_action)` pairs so the model learns the JSON output format\n",
|
| 22 |
+
"and the basic risk→action prior. Without this, GRPO from a cold base model gets\n",
|
| 23 |
+
"a flat reward curve and a near-zero loss (no advantage signal between\n",
|
| 24 |
+
"completions in a group).\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"**Stage 2 — Ladder co-evolution (GRPO ⇄ ES + League).**\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"* **Defender LLM** — `unsloth/phi-3-mini-4k-instruct-bnb-4bit` (LoRA) trained\n",
|
| 29 |
+
" with **TRL GRPO** on Unsloth (4-bit base, fp16 LoRA — no `bf16` so it runs on\n",
|
| 30 |
+
" Colab T4 which has no bf16 support).\n",
|
| 31 |
+
" Reward comes from a deterministic **K-step rollout** in the env (not a single\n",
|
| 32 |
+
" noisy step). All `num_generations` completions in a GRPO group share the\n",
|
| 33 |
+
" **same seed** (via `/reset_seeded`) AND the prompts are **refreshed each round\n",
|
| 34 |
+
" under the current adversary** so prompt-obs and reward-obs are always aligned.\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"* **Fraud agent** — a parametric policy with 3 continuous parameters\n",
|
| 37 |
+
" (`intensity`, `noise_boost`, `pattern_rate`) updated by **Evolution Strategies (ES)**\n",
|
| 38 |
+
" and *anchored* to one of three ladder rungs (easy / medium / hard).\n",
|
| 39 |
+
" *Optional upgrade*: set `USE_LLM_FRAUD=True` in cell 6 to swap the ES\n",
|
| 40 |
+
" policy for a **second LoRA on the same Phi-3 base** — a true dual-LLM\n",
|
| 41 |
+
" self-play setup where the fraud LoRA is GRPO-trained to OUTPUT adversary\n",
|
| 42 |
+
" parameter JSON (reward = `1 - defender_reward`). Default OFF so QUICK\n",
|
| 43 |
+
" stays fast; flip ON for the upgraded recipe at ~1.5× wall time and\n",
|
| 44 |
+
" ~2× base-model VRAM.\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"* **LADDER + LEAGUE (research-backed stability fix).** Pure ES drift is unstable\n",
|
| 47 |
+
" — the defender catastrophically forgets early attack regimes once fraud-θ\n",
|
| 48 |
+
" drifts. We solve this with:\n",
|
| 49 |
+
" 1. **Curriculum rungs** (`LADDER_RUNGS`): the round schedule promotes the\n",
|
| 50 |
+
" fraud anchor easy → medium → hard, so the defender masters each regime\n",
|
| 51 |
+
" before the next.\n",
|
| 52 |
+
" 2. **PFSP league pool** (`LeagueLadder`): every settled rung's fraud-θ is\n",
|
| 53 |
+
" snapshotted into a pool. During ES, with prob `LEAGUE_PAST_SAMPLE_PROB`\n",
|
| 54 |
+
" a candidate is evaluated against a sampled *past* rung instead of the\n",
|
| 55 |
+
" current one — keeping pressure across the whole observed difficulty.\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"Co-training loop (per round):\n",
|
| 58 |
"```\n",
|
| 59 |
"for round in range(N_ROUNDS):\n",
|
| 60 |
+
" rung = LADDER_RUNGS[ rung_for_round(round) ] # easy → medium → hard\n",
|
| 61 |
+
" fraud_agent.theta = rung_anchor # ladder anchor\n",
|
| 62 |
+
" refresh_prompts_under_current_adversary() # FIX B: prompt/reward alignment\n",
|
| 63 |
+
" train_defender_GRPO(K_step_rollout, same_seed_per_group)\n",
|
| 64 |
+
" league.add(fraud_agent.theta) # snapshot rung\n",
|
| 65 |
+
" ES_step_with_PFSP_past_sampling(defender) # LeagueLadder.sample\n",
|
| 66 |
"```\n",
|
| 67 |
"\n",
|
| 68 |
+
"Critical alignment & stability fixes baked in:\n",
|
| 69 |
+
"* **FIX A** — adversary is reset to NEUTRAL before baseline eval so Random /\n",
|
| 70 |
+
" Heuristic numbers are not poisoned by leftover state from a previous run.\n",
|
| 71 |
+
"* **FIX B** — prompts are re-collected at the start of every round under the\n",
|
| 72 |
+
" CURRENT adversary so `env_reset_seeded(seed)` reproduces the EXACT obs the\n",
|
| 73 |
+
" prompt was made from. Without this, ES drift would silently misalign the\n",
|
| 74 |
+
" GRPO gradient.\n",
|
| 75 |
+
"* **FIX C** — multi-step rollout (`K=3`) reduces single-step reward variance\n",
|
| 76 |
+
" and trains the model on the immediate downstream consequences (chargebacks,\n",
|
| 77 |
+
" anti-gaming alerts) that matter at episode-eval time.\n",
|
| 78 |
+
"* **FIX D** — the bar plot now shows BOTH \"Trained vs Neutral\" (apples-to-apples\n",
|
| 79 |
+
" with baselines) AND \"Trained vs Co-evolved\" (robustness on the hardest fraud).\n",
|
| 80 |
+
"\n",
|
| 81 |
"Why this matters:\n",
|
| 82 |
+
"* Single-step rewards are noisy → **K-step rollout** kills variance.\n",
|
| 83 |
+
"* Different start states per generation → **same-seed group** gives clean advantages.\n",
|
| 84 |
"* Static adversary → defender plateaus → **learning fraud agent** keeps pressure escalating.\n",
|
| 85 |
+
"* Pure ES drift → catastrophic forgetting → **ladder rungs + PFSP league** stabilise it.\n",
|
| 86 |
"\n",
|
| 87 |
"Pipeline:\n",
|
| 88 |
+
"1. Install deps (Unsloth + Unsloth-Zoo + TRL from GitHub)\n",
|
| 89 |
"2. HF login (uses your HF credits)\n",
|
| 90 |
"3. GPU sanity check + env health\n",
|
| 91 |
+
"4. Build prompt + obs dataset from live `/reset_seeded` calls\n",
|
| 92 |
+
"5. **FIX A**: reset adversary to neutral, then baseline eval (random + heuristic)\n",
|
| 93 |
+
"6. Initialise FraudPolicy + LeagueLadder\n",
|
| 94 |
+
"7. **Stage 1: SFT warm-start** on heuristic-labeled (prompt, action) pairs\n",
|
| 95 |
+
"8. **Stage 2: Ladder co-training loop** — rung curriculum + GRPO defender + ES fraud + league\n",
|
| 96 |
+
"9. Trained-policy eval (vs co-evolved fraud AND vs neutral fraud)\n",
|
| 97 |
+
"10. Plots:\n",
|
| 98 |
+
" - SFT warm-start loss\n",
|
| 99 |
+
" - GRPO training reward + loss\n",
|
| 100 |
+
" - Defender mean reward per round\n",
|
| 101 |
+
" - Fraud agent mean fitness per round\n",
|
| 102 |
+
" - Exploitability gap per round\n",
|
| 103 |
+
" - Fraud parameter trajectories\n",
|
| 104 |
+
" - **FIX D**: Before vs After (4 bars: Random / Heuristic / Trained-neutral / Trained-coevolved)\n",
|
| 105 |
+
" - **FIX D**: Per risk-bucket reward (4 bars × 3 buckets)\n",
|
| 106 |
+
"11. Save artifacts to `./artifacts` (incl. ladder rung schedule + league pool)\n",
|
| 107 |
"\n",
|
| 108 |
"Hackathon: OpenEnv (India 2026), Theme #4 — Self-Improvement.\n",
|
| 109 |
"Space: https://huggingface.co/spaces/Pratap-K/SmartPayEnv"
|
|
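The round schedule described above promotes the fraud anchor easy → medium → hard. A minimal sketch of that mapping, assuming a hypothetical `rung_for_round` helper (the notebook defines its own `_rung_for_round` inside the training loop, which may differ):

```python
# Hypothetical sketch: split the rounds into len(LADDER_RUNGS) contiguous
# blocks so the defender masters each rung before promotion.
def rung_for_round(rnd, n_rounds, n_rungs):
    """Return the ladder rung index (0-based) for round `rnd`."""
    block = max(1, n_rounds // n_rungs)      # rounds spent on each rung
    return min(rnd // block, n_rungs - 1)    # late rounds stay on the hardest rung

# e.g. N_ROUNDS=6 with 3 rungs -> rounds 0-1 on rung 0, 2-3 on rung 1, 4-5 on rung 2
```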
|
|
| 119 |
{
|
| 120 |
"cell_type": "code",
|
| 121 |
"execution_count": null,
|
| 122 |
+
"id": "177bf9d5",
|
| 123 |
"metadata": {},
|
| 124 |
"outputs": [],
|
| 125 |
"source": [
|
| 126 |
"!pip -q install --upgrade pip\n",
|
| 127 |
"!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n",
|
| 128 |
+
"!pip -q install \"unsloth_zoo @ git+https://github.com/unslothai/unsloth-zoo.git\"\n",
|
| 129 |
"!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n",
|
| 130 |
+
"!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests"
|
| 131 |
]
|
| 132 |
},
|
| 133 |
{
|
|
|
|
| 170 |
"SEED = 42\n",
|
| 171 |
"\n",
|
| 172 |
"# ── Minimal-viable QUICK config — every variable dialled to the lowest\n",
|
| 173 |
+
"# value that still produces all plots + meaningful accuracy comparison.\n",
|
| 174 |
+
"# Approx wall time on a Colab T4: QUICK ~5-7 min, FULL ~15-22 min.\n",
|
| 175 |
"\n",
|
| 176 |
"# Co-evolution loop\n",
|
| 177 |
+
"N_ROUNDS = 3 if QUICK_MODE else 6 # >=3 so the ladder visits >=2 rungs\n",
|
| 178 |
"GRPO_STEPS_PER_ROUND = 4 if QUICK_MODE else 20\n",
|
| 179 |
"ES_STEPS_PER_ROUND = 2 if QUICK_MODE else 6\n",
|
| 180 |
"ES_POPULATION = 3 if QUICK_MODE else 6 # ES needs >=3 for ranked weights\n",
|
| 181 |
"ES_SIGMA = 0.25 # exploration std for ES\n",
|
| 182 |
"ES_LR = 0.4 # ES update rate\n",
|
| 183 |
"\n",
|
| 184 |
+
"# Defender / GRPO\n",
|
| 185 |
"PROMPT_DATASET_SIZE = 16 if QUICK_MODE else 96\n",
|
| 186 |
"GRPO_NUM_GENERATIONS = 4 if QUICK_MODE else 6 # >=2 for group-relative advantage\n",
|
| 187 |
+
"# K=3 multi-step rollout: with the per-round prompt refresh (Fix B) the env's\n",
|
| 188 |
+
"# adversary config matches the obs the prompt was generated from, so K\n",
|
| 189 |
+
"# subsequent deterministic steps are well-defined. K>1 here reduces single-\n",
|
| 190 |
+
"# step reward variance and trains the model to pick actions that are also\n",
|
| 191 |
+
"# robust to the immediate downstream consequences (chargebacks, anti-gaming\n",
|
| 192 |
+
"# alerts) which matter at episode-eval time. Don't push K higher in QUICK\n",
|
| 193 |
+
"# (each generation costs K env round-trips).\n",
|
| 194 |
+
"ROLLOUT_STEPS_PER_REWARD = 3 if QUICK_MODE else 4\n",
|
| 195 |
"\n",
|
| 196 |
"# Final frozen-holdout eval\n",
|
| 197 |
"EVAL_EPISODES = 2 if QUICK_MODE else 4\n",
|
|
|
|
| 202 |
"COEVO_EVAL_EPISODES = 1 if QUICK_MODE else 2\n",
|
| 203 |
"COEVO_EVAL_STEPS = 6 if QUICK_MODE else 12\n",
|
| 204 |
"\n",
|
| 205 |
+
"# Token budgets (bumped after diagnosing prompt right-truncation dropping the\n",
|
| 206 |
+
"# schema instruction, and completion truncation cutting valid JSON mid-string).\n",
|
| 207 |
+
"DEF_MAX_PROMPT_TOKENS = 1024 if QUICK_MODE else 1536\n",
|
| 208 |
+
"DEF_MAX_NEW_TOKENS = 64 if QUICK_MODE else 96\n",
|
| 209 |
+
"\n",
|
| 210 |
+
"MODEL_ID = 'unsloth/phi-3-mini-4k-instruct-bnb-4bit'\n",
|
| 211 |
+
"MAX_SEQ_LEN = 2048 # ample for prompt + completion in both modes (phi-3 supports 4k)\n",
|
| 212 |
"LOAD_IN_4BIT = True\n",
|
| 213 |
"\n",
|
| 214 |
+
"# Disjoint seed range for training prompts so it never collides with eval seeds\n",
|
| 215 |
+
"# (10_000+ for fraud-vs-defender, 20_000+ for quick eval). The PROMPT_BASE_SEED\n",
|
| 216 |
+
"# is offset per round so each round's prompt set is fresh under the new adversary.\n",
|
| 217 |
+
"PROMPT_BASE_SEED = 1_000_000\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"# ── Curriculum LADDER (PFSP-style league of fraud rungs) ─────────────\n",
|
| 220 |
+
"# Each rung is an anchor (intensity, noise_boost, pattern_rate) for the fraud\n",
|
| 221 |
+
"# agent. The defender starts at rung 0 (easy fraud) and climbs as rounds\n",
|
| 222 |
+
"# progress. ES still explores LOCALLY around each rung's anchor, so within a\n",
|
| 223 |
+
"# rung fraud gets harder against the current defender, then promotes. This\n",
|
| 224 |
+
"# is the curriculum-learning analogue of Fictitious-Self-Play: by keeping\n",
|
| 225 |
+
"# the *anchor* explicit, defender doesn't catastrophically forget early\n",
|
| 226 |
+
"# attack regimes when ES drifts the adversary too far. A snapshot of each\n",
|
| 227 |
+
"# settled fraud-θ is saved into the LeagueLadder pool (cell 16), and a\n",
|
| 228 |
+
"# fraction of ES evals are done against a sampled past rung to prevent\n",
|
| 229 |
+
"# the defender from being \"tutored\" by an unrealistically easy current rung.\n",
|
| 230 |
+
"LADDER_RUNGS = [\n",
|
| 231 |
+
" {'intensity': 1.0, 'noise_boost': 0.05, 'pattern_rate': 0.15}, # rung 0: easy\n",
|
| 232 |
+
" {'intensity': 1.3, 'noise_boost': 0.18, 'pattern_rate': 0.35}, # rung 1: medium\n",
|
| 233 |
+
" {'intensity': 1.7, 'noise_boost': 0.32, 'pattern_rate': 0.55}, # rung 2: hard\n",
|
| 234 |
+
"]\n",
|
| 235 |
+
"LEAGUE_PAST_SAMPLE_PROB = 0.3 # P(ES eval against a past rung instead of current)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"# ── OPTIONAL: dual-LoRA fraud LLM (truly two-LLM self-play) ──────────\n",
|
| 238 |
+
"# When True, a SECOND LoRA on the same Phi-3 base is trained to PROPOSE\n",
|
| 239 |
+
"# adversary parameters (intensity / noise_boost / pattern_rate) via GRPO,\n",
|
| 240 |
+
"# replacing the parametric ES fraud agent inside the co-training loop.\n",
|
| 241 |
+
"# Default OFF so QUICK_MODE stays fast (2x base-model VRAM and ~1.5x wall\n",
|
| 242 |
+
"# time when ON). Both LoRAs share the same MODEL_ID.\n",
|
| 243 |
+
"USE_LLM_FRAUD = False\n",
|
| 244 |
+
"FRAUD_GRPO_STEPS_PER_ROUND = 2 if QUICK_MODE else 8\n",
|
| 245 |
+
"FRAUD_PROMPT_DATASET_SIZE = 8 if QUICK_MODE else 32\n",
|
| 246 |
+
"FRAUD_GRPO_NUM_GENERATIONS = 3 if QUICK_MODE else 4\n",
|
| 247 |
+
"FRAUD_MAX_PROMPT_TOKENS = 512 if QUICK_MODE else 768\n",
|
| 248 |
+
"FRAUD_MAX_NEW_TOKENS = 48\n",
|
| 249 |
+
"FRAUD_LORA_R = 8 # smaller than defender (smaller search space)\n",
|
| 250 |
+
"\n",
|
| 251 |
"os.makedirs('artifacts', exist_ok=True)\n",
|
| 252 |
"random.seed(SEED)\n",
|
| 253 |
"np.random.seed(SEED)\n",
|
|
|
|
| 258 |
" '| pop =', ES_POPULATION,\n",
|
| 259 |
" '| K-rollout =', ROLLOUT_STEPS_PER_REWARD,\n",
|
| 260 |
" '| eval =', f'{EVAL_EPISODES}x{EVAL_STEPS_PER_EPISODE}',\n",
|
| 261 |
+
" '| LADDER rungs =', len(LADDER_RUNGS),\n",
|
| 262 |
+
" '| USE_LLM_FRAUD =', USE_LLM_FRAUD,\n",
|
| 263 |
" '| MODEL_ID =', MODEL_ID)"
|
| 264 |
]
|
| 265 |
},
|
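The config cell above says the training-prompt base seed is offset per round so each round's prompt set stays fresh and never collides with the 10_000+/20_000+ eval seed ranges. One plausible offset scheme, stated as an assumption (the co-training loop later in the notebook fixes the exact rule):

```python
# Hypothetical sketch of a per-round prompt seed range: round r draws
# [base + r*n, base + (r+1)*n), far above the 10_000+/20_000+ eval seeds.
def round_prompt_base_seed(base_seed, rnd, n_prompts):
    return base_seed + rnd * n_prompts

# e.g. base_seed=1_000_000, n_prompts=96: round 2 uses seeds 1_000_192 .. 1_000_287
```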
|
|
|
| 359 |
" return None\n",
|
| 360 |
"\n",
|
| 361 |
"def rollout_reward(action, seed, difficulty=DIFFICULTY, k=ROLLOUT_STEPS_PER_REWARD):\n",
|
| 362 |
+
" \"\"\"Score `action` on the *exact* obs that `seed` reproduces.\n",
|
| 363 |
+
"\n",
|
| 364 |
+
" Critical: `seed` MUST come from PROMPT_TO_SEED (set up in cell 12) so that\n",
|
| 365 |
+
" env_reset_seeded(seed) regenerates the SAME transaction whose obs is in the\n",
|
| 366 |
+
" prompt. The first env_step then scores the action on THAT obs — the only\n",
|
| 367 |
+
" way GRPO's reward can be correlated with the prompt the model saw.\n",
|
| 368 |
+
"\n",
|
| 369 |
+
" K=1 is the semantically correct default. K>1 averages across SUBSEQUENT\n",
|
| 370 |
+
" transactions whose optimal action differs, which dilutes the signal. The\n",
|
| 371 |
+
" parameter is kept for backward compat / variance experimentation only.\"\"\"\n",
|
| 372 |
" env_reset_seeded(seed, difficulty)\n",
|
| 373 |
" rewards = []\n",
|
| 374 |
" for _ in range(int(k)):\n",
|
|
|
|
| 439 |
{
|
| 440 |
"cell_type": "code",
|
| 441 |
"execution_count": null,
|
| 442 |
+
"id": "0b9f60c5",
|
| 443 |
"metadata": {},
|
| 444 |
"outputs": [],
|
| 445 |
"source": [
|
| 446 |
+
"def collect_prompts(n=PROMPT_DATASET_SIZE, difficulty=DIFFICULTY,\n",
|
| 447 |
+
" base_seed=PROMPT_BASE_SEED):\n",
|
| 448 |
+
" \"\"\"Collect (seed, prompt, obs) triples using *deterministic* seeded resets.\n",
|
| 449 |
+
"\n",
|
| 450 |
+
" Each prompt i is generated by `env_reset_seeded(seed=base_seed+i)`, so the\n",
|
| 451 |
+
" same call later in `rollout_reward` reproduces the EXACT same obs. This is\n",
|
| 452 |
+
" what makes GRPO's reward correlated with the prompt — without it, the env\n",
|
| 453 |
+
" is reset to an unrelated state and the gradient is essentially noise.\n",
|
| 454 |
+
" \"\"\"\n",
|
| 455 |
+
" prompts, obs_list, seeds = [], [], []\n",
|
| 456 |
+
" for i in range(int(n)):\n",
|
| 457 |
+
" s = int(base_seed + i)\n",
|
| 458 |
+
" obs = env_reset_seeded(seed=s, difficulty=difficulty)\n",
|
| 459 |
" prompts.append(make_prompt(obs))\n",
|
| 460 |
+
" obs_list.append(copy.deepcopy(obs))\n",
|
| 461 |
+
" seeds.append(s)\n",
|
| 462 |
+
" return prompts, obs_list, seeds\n",
|
| 463 |
+
"\n",
|
| 464 |
+
"prompts, prompt_obs, prompt_seeds = collect_prompts()\n",
|
| 465 |
+
"\n",
|
| 466 |
+
"# ── prompt → seed lookup (keyed on the obs JSON, NOT the full prompt string) ──\n",
|
| 467 |
+
"# We key on the obs JSON only, so even if TRL wraps the prompt in a chat\n",
|
| 468 |
+
"# template or alters whitespace, the lookup still hits.\n",
|
| 469 |
+
"import re as _re\n",
|
| 470 |
+
"_OBS_JSON_RE = _re.compile(\n",
|
| 471 |
+
" r'SmartPayEnv observation:\\n(\\{.*?\\})\\nReturn one action JSON',\n",
|
| 472 |
+
" _re.DOTALL,\n",
|
| 473 |
+
")\n",
|
| 474 |
"\n",
|
| 475 |
+
"def _obs_key(prompt_text):\n",
|
| 476 |
+
" m = _OBS_JSON_RE.search(prompt_text or '')\n",
|
| 477 |
+
" return m.group(1) if m else (prompt_text or '')\n",
|
| 478 |
+
"\n",
|
| 479 |
+
"PROMPT_TO_SEED = {_obs_key(p): s for p, s in zip(prompts, prompt_seeds)}\n",
|
| 480 |
+
"PROMPT_TO_OBS = {_obs_key(p): o for p, o in zip(prompts, prompt_obs)}\n",
|
| 481 |
+
"\n",
|
| 482 |
+
"print('Prompts collected:', len(prompts),\n",
|
| 483 |
+
" '| obs cached:', len(prompt_obs),\n",
|
| 484 |
+
" '| seed lookup entries:', len(PROMPT_TO_SEED))\n",
|
| 485 |
+
"print('Example prompt:\\n', prompts[0][:300], '...')\n",
|
| 486 |
+
"\n",
|
| 487 |
+
"# Sanity: round-trip the first prompt through the env to confirm the seeded\n",
|
| 488 |
+
"# reset really does reproduce the obs in the prompt.\n",
|
| 489 |
+
"_check_obs = env_reset_seeded(seed=prompt_seeds[0], difficulty=DIFFICULTY)\n",
|
| 490 |
+
"_orig = prompt_obs[0]\n",
|
| 491 |
+
"_match_keys = ['amount', 'merchant_category', 'observed_fraud_risk',\n",
|
| 492 |
+
" 'time_of_day', 'transaction_velocity']\n",
|
| 493 |
+
"_ok = all(_check_obs.get(k) == _orig.get(k) for k in _match_keys)\n",
|
| 494 |
+
"print(f' seed→obs reproducibility check on {_match_keys}: '\n",
|
| 495 |
+
" f'{\"OK\" if _ok else \"MISMATCH (alignment fix will not help!)\"}')"
|
| 496 |
]
|
| 497 |
},
|
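The `_OBS_JSON_RE` lookup above only works if the prompt embeds the obs JSON between the two anchor phrases. A sketch of the prompt shape that regex assumes (the real `make_prompt` is defined in an earlier cell; the trailing schema wording here is illustrative):

```python
import json

# Sketch of the prompt layout assumed by _OBS_JSON_RE: the obs JSON sits on
# its own line between "SmartPayEnv observation:" and "Return one action JSON".
def make_prompt_sketch(obs):
    return (
        'SmartPayEnv observation:\n'
        + json.dumps(obs, sort_keys=True)
        + '\nReturn one action JSON with keys gateway, fraud_decision, retry_strategy.'
    )

# _obs_key() recovers exactly the json.dumps(...) block, so chat-template
# wrapping or whitespace changes around the prompt never break PROMPT_TO_SEED.
```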
| 498 |
{
|
|
|
|
| 505 |
{
|
| 506 |
"cell_type": "code",
|
| 507 |
"execution_count": null,
|
| 508 |
+
"id": "89f1d935",
|
| 509 |
"metadata": {},
|
| 510 |
"outputs": [],
|
| 511 |
"source": [
|
|
|
|
| 558 |
" fd = 0\n",
|
| 559 |
" return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n",
|
| 560 |
"\n",
|
| 561 |
+
"# ── FIX A — Reset env adversary to NEUTRAL before measuring baselines ──\n",
|
| 562 |
+
"# The HF Space is a long-running server: previous runs leave the adversary\n",
|
| 563 |
+
"# at hard settings (e.g. intensity=1.8, noise=0.4 from a finished co-evolution\n",
|
| 564 |
+
"# loop), which silently penalises the heuristic baseline of any subsequent\n",
|
| 565 |
+
"# run and makes the bar chart misleading. We pin the adversary to a defined\n",
|
| 566 |
+
"# neutral state here so baselines are reproducible across runs and directly\n",
|
| 567 |
+
"# comparable with `trained_eval_neutral` later.\n",
|
| 568 |
+
"print('[FIX A] Resetting adversary to neutral before baseline eval...')\n",
|
| 569 |
+
"env_configure_adversary(intensity=1.0, noise_boost=0.05, pattern_rate=0.2, strategy='mixed')\n",
|
| 570 |
+
"\n",
|
| 571 |
"baseline_random = eval_policy(random_policy)\n",
|
| 572 |
"baseline_heuristic = eval_policy(heuristic_policy)\n",
|
| 573 |
"print('Random baseline:', baseline_random['mean_reward'], baseline_random['bucket_means'])\n",
|
|
|
|
| 673 |
" 'best_fraud_fitness': float(np.max(fitnesses)),\n",
|
| 674 |
" }\n",
|
| 675 |
"\n",
|
| 676 |
+
"class LeagueLadder:\n",
|
| 677 |
+
" \"\"\"A pool of past fraud-θ snapshots, one per settled rung.\n",
|
| 678 |
+
"\n",
|
| 679 |
+
" Inspired by AlphaStar's PFSP league. We use the league for **two**\n",
|
| 680 |
+
" correctly-typed purposes:\n",
|
| 681 |
+
"\n",
|
| 682 |
+
" 1. **Defender-side rehearsal** (during prompt refresh): with probability\n",
|
| 683 |
+
" `LEAGUE_PAST_SAMPLE_PROB` we collect this round's prompts under a\n",
|
| 684 |
+
" sampled PAST rung instead of the current rung. This forces the\n",
|
| 685 |
+
" defender's GRPO gradient to occasionally include earlier attack\n",
|
| 686 |
+
" regimes — preventing catastrophic forgetting as the ladder climbs.\n",
|
| 687 |
+
"\n",
|
| 688 |
+
" 2. **Final robustness telemetry**: at the end of training we measure the\n",
|
| 689 |
+
" trained defender against EVERY rung in the league. A robust policy\n",
|
| 690 |
+
" scores well on all rungs; an over-fit one only scores well on the\n",
|
| 691 |
+
" last. This is plotted in cell 22.\n",
|
| 692 |
+
"\n",
|
| 693 |
+
" NOTE: We deliberately do NOT mix past rungs into the fraud-ES gradient.\n",
|
| 694 |
+
" Doing so credits the candidate-θ perturbation with fitness measured\n",
|
| 695 |
+
" against an unrelated past θ, which adds noise to the ES estimate\n",
|
| 696 |
+
" instead of useful signal. Defender rehearsal is the correct place.\n",
|
| 697 |
+
" \"\"\"\n",
|
| 698 |
+
" def __init__(self):\n",
|
| 699 |
+
" self.rungs = [] # list of {'name': str, 'theta': dict}\n",
|
| 700 |
+
" def add(self, name, theta):\n",
|
| 701 |
+
" self.rungs.append({'name': str(name), 'theta': dict(theta)})\n",
|
| 702 |
+
" def sample_past(self):\n",
|
| 703 |
+
" \"\"\"Uniformly sample a strictly-past rung. League is updated *after*\n",
|
| 704 |
+
" GRPO at the end of each round, so at prompt-refresh time the league\n",
|
| 705 |
+
" already contains only past rounds — no exclusion needed. Returns\n",
|
| 706 |
+
" None if the league is empty (round 1).\"\"\"\n",
|
| 707 |
+
" if not self.rungs:\n",
|
| 708 |
+
" return None\n",
|
| 709 |
+
" return dict(random.choice(self.rungs)['theta'])\n",
|
| 710 |
+
" def __len__(self):\n",
|
| 711 |
+
" return len(self.rungs)\n",
|
| 712 |
+
"\n",
|
| 713 |
+
"league = LeagueLadder()\n",
|
| 714 |
+
"\n",
|
| 715 |
"fraud_agent = FraudPolicy()\n",
|
| 716 |
"fraud_agent.apply()\n",
|
| 717 |
+
"print('Fraud agent initialised with theta =', fraud_agent.theta)\n",
|
| 718 |
+
"print(f'League ladder ready (rungs configured: {len(LADDER_RUNGS)}, '\n",
|
| 719 |
+
" f'past-rehearsal prob: {LEAGUE_PAST_SAMPLE_PROB})')"
|
| 720 |
]
|
| 721 |
},
|
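A short usage sketch of the rehearsal share described in the `LeagueLadder` docstring, assuming the notebook's `league`, `LADDER_RUNGS`, `LEAGUE_PAST_SAMPLE_PROB`, `env_configure_adversary`, `fraud_agent` and `collect_prompts` are in scope (the loop variable names are assumptions; the real loop appears further down):

```python
import random

rung_idx, rnd = 0, 0                                   # assumed loop variables
rung = LADDER_RUNGS[rung_idx]                          # current rung's anchor θ
if len(league) and random.random() < LEAGUE_PAST_SAMPLE_PROB:
    theta = league.sample_past()                       # refresh prompts under a PAST rung (rehearsal)
else:
    theta = dict(rung)                                 # ... otherwise under the current rung
env_configure_adversary(**theta, strategy='mixed')
# prompts, prompt_obs, prompt_seeds = collect_prompts(base_seed=...)   # Fix B refresh

# After the defender's GRPO phase the settled θ is snapshotted for later rehearsal:
# league.add(name=f'rung{rung_idx}_round{rnd}', theta=fraud_agent.theta)
```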
| 722 |
{
|
|
|
|
| 724 |
"id": "5efe6c56",
|
| 725 |
"metadata": {},
|
| 726 |
"source": [
|
| 727 |
+
"## 8. SFT warm-start → Ladder Co-evolution (GRPO defender ⇄ ES fraud + League)\n",
|
| 728 |
+
"\n",
|
| 729 |
+
"GRPO from a *cold* base model gives a flat reward curve: the policy doesn't yet\n",
|
| 730 |
+
"emit valid action JSON, so all completions in a group earn nearly the same\n",
|
| 731 |
+
"reward → zero group-relative advantage → zero gradient (loss collapses to ~1e-6).\n",
|
| 732 |
+
"\n",
|
| 733 |
+
"Even after SFT solves that, pure ES on the fraud agent introduces a *second*\n",
|
| 734 |
+
"failure mode: fraud-θ drifts arbitrarily, the defender catastrophically forgets\n",
|
| 735 |
+
"how to handle earlier attack regimes, and the eval bar chart shows the trained\n",
|
| 736 |
+
"LLM losing to baselines on the hardest risk bucket. We solve this with a\n",
|
| 737 |
+
"**ladder + league** wrapped around the two-stage training.\n",
|
| 738 |
+
"\n",
|
| 739 |
+
"**Stage 1: SFT warm-start (heuristic imitation)**\n",
|
| 740 |
+
"Label each cached prompt with the *heuristic* action (`risk_bucket → Block /\n",
|
| 741 |
+
"3DS / Allow + best gateway`) and run a short SFT pass. After this the model:\n",
|
| 742 |
+
"- emits parseable JSON ~100% of the time,\n",
|
| 743 |
+
"- already beats random,\n",
|
| 744 |
+
"- gives GRPO a *non-degenerate* starting policy with reward variance.\n",
|
| 745 |
+
"\n",
|
| 746 |
+
"**Stage 2: Ladder co-evolution (per round)**\n",
|
| 747 |
+
"1. **Pick rung.** `_rung_for_round(rnd)` selects a `LADDER_RUNGS` anchor\n",
|
| 748 |
+
" (easy / medium / hard). On rung change, fraud-θ is reset to that anchor —\n",
|
| 749 |
+
" ES then explores LOCALLY around it instead of drifting arbitrarily.\n",
|
| 750 |
+
"2. **Refresh prompts (Fix B).** Re-collect the prompt set under the *current*\n",
|
| 751 |
+
" adversary so prompt-obs and reward-obs match exactly inside this round's\n",
|
| 752 |
+
" GRPO. Without this, prompts made under rung k-1 are silently scored under\n",
|
| 753 |
+
" rung k (different intensity/noise → different obs from the same seed) and\n",
|
| 754 |
+
" the GRPO gradient is misaligned.\n",
|
| 755 |
+
"3. **Defender phase (GRPO).** `GRPO_STEPS_PER_ROUND` gradient steps. Reward\n",
|
| 756 |
+
" for each completion is a **K-step rollout** with a **shared seed** across\n",
|
| 757 |
+
" the whole group → clean group-relative advantage.\n",
|
| 758 |
+
"4. **Snapshot to league.** Save fraud-θ for this rung into `LeagueLadder`.\n",
|
| 759 |
+
"5. **Fraud phase (ES + PFSP).** ES updates push fraud-θ toward perturbations\n",
|
| 760 |
+
" that *lower* defender reward — but with prob `LEAGUE_PAST_SAMPLE_PROB` a\n",
|
| 761 |
+
" candidate is evaluated against a sampled past rung instead of the current\n",
|
| 762 |
+
" one, preventing over-fit to the latest anchor.\n",
|
| 763 |
"\n",
|
| 764 |
"Reward signal flow (per defender generation):\n",
|
| 765 |
"```\n",
|
| 766 |
+
"group_seed = PROMPT_TO_SEED[obs_in_prompt] # round-local cached seed\n",
|
| 767 |
"for completion in group:\n",
|
| 768 |
" action = parse_action(completion)\n",
|
| 769 |
+
" /reset_seeded(group_seed) # reproduces THE EXACT obs in the prompt\n",
|
| 770 |
+
" reward = mean( /step(action) for k in K ) # K=3 deterministic rollout\n",
|
| 771 |
"```\n",
|
| 772 |
+
"All `num_generations` completions of one prompt share `group_seed`, so the env\n",
|
| 773 |
+
"is reset to the *same* starting obs for every completion — exactly the obs the\n",
|
| 774 |
+
"model saw in its prompt. The only thing varying inside a group is the action,\n",
|
| 775 |
+
"exactly what GRPO needs for a clean group-relative advantage.\n",
|
| 776 |
+
"\n",
|
| 777 |
+
"**Why prompt refresh + ladder anchors are critical:** previously prompts were\n",
|
| 778 |
+
"collected ONCE before the loop, but ES then changed the adversary every round.\n",
|
| 779 |
+
"`env_reset_seeded(seed)` produces a different obs once `_adv_intensity` /\n",
|
| 780 |
+
"`_adv_noise_boost` change, so the obs inside the prompt and the obs the action\n",
|
| 781 |
+
"was scored against drifted apart. Refreshing prompts each round + anchoring\n",
|
| 782 |
+
"fraud to a discrete rung kills both the alignment bug AND the ES-drift\n",
|
| 783 |
+
"forgetting problem at once.\n",
|
| 784 |
+
"\n",
|
| 785 |
+
"**Token budgets** are sized so that:\n",
|
| 786 |
+
"- The schema instruction at the END of the prompt is never truncated\n",
|
| 787 |
+
" (`tokenizer.truncation_side='left'` drops the legend at the front instead).\n",
|
| 788 |
+
"- The completion JSON fits comfortably even if the model writes a short\n",
|
| 789 |
+
" prose prefix.\n",
|
| 790 |
+
"\n",
|
| 791 |
+
"No `/simulate` is used anywhere. No `bf16` (T4 has no bf16 support; Unsloth\n",
|
| 792 |
+
"auto-picks fp16 for the 4-bit base + LoRA).\n",
|
| 793 |
+
"\n",
|
| 794 |
+
"### Optional: dual-LoRA fraud LLM (`USE_LLM_FRAUD = True`)\n",
|
| 795 |
+
"\n",
|
| 796 |
+
"When the flag is on, a SECOND LoRA on the same Phi-3 base is trained alongside\n",
|
| 797 |
+
"the defender. Its prompt summarises the current matchup (rung + current θ +\n",
|
| 798 |
+
"last defender reward) and it must emit a JSON proposal of (intensity,\n",
|
| 799 |
+
"noise_boost, pattern_rate). Reward = `1 - defender_reward` evaluated under the\n",
|
| 800 |
+
"proposed θ, so GRPO's group-relative advantage rewards proposals the current\n",
|
| 801 |
+
"defender is weakest against.\n",
|
| 802 |
+
"\n",
|
| 803 |
+
"Per-round flow when enabled:\n",
|
| 804 |
+
"```\n",
|
| 805 |
+
"fraud_llm.grpo_step(rung_idx)\n",
|
| 806 |
+
" -> build N prompts, all sharing the same match-summary\n",
|
| 807 |
+
" -> GRPO group of FRAUD_GRPO_NUM_GENERATIONS samples per prompt\n",
|
| 808 |
+
" -> reward each sample by pushing it as adversary θ + quick_defender_eval\n",
|
| 809 |
+
" -> after burst: greedy-decode best θ, push to env, sync into fraud_agent.theta\n",
|
| 810 |
+
"```\n",
|
| 811 |
+
"Downstream code (league snapshots, exploitability gap, eval) is identical —\n",
|
| 812 |
+
"the LLM-proposed θ flows through the SAME `fraud_agent.theta` channel that\n",
|
| 813 |
+
"ES used to write to."
|
| 814 |
]
|
| 815 |
},
|
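The dual-LoRA recipe above rewards the fraud LoRA with `1 - defender_reward` under its proposed θ. A minimal sketch of such a GRPO reward function, assuming a hypothetical `parse_theta` helper that extracts the 3-float JSON from a completion (the notebook's actual fraud-side reward function may differ):

```python
# Hypothetical sketch — only relevant when USE_LLM_FRAUD=True. parse_theta is
# an assumed helper; env_configure_adversary / quick_defender_eval are the
# notebook's own functions.
def fraud_reward_fn(completions, **kwargs):
    rewards = []
    for comp in completions:
        theta = parse_theta(comp)   # -> {'intensity': ..., 'noise_boost': ..., 'pattern_rate': ...}
        env_configure_adversary(**theta, strategy='mixed')
        defender_reward = quick_defender_eval()
        rewards.append(1.0 - float(defender_reward))   # strong proposal = weak defender
    return rewards
```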
| 816 |
{
|
|
|
|
| 822 |
"source": [
|
| 823 |
"from unsloth import FastLanguageModel\n",
|
| 824 |
"from datasets import Dataset\n",
|
| 825 |
+
"from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer\n",
|
| 826 |
"import hashlib, torch\n",
|
| 827 |
"\n",
|
| 828 |
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
|
|
|
| 831 |
" dtype=None,\n",
|
| 832 |
" load_in_4bit=LOAD_IN_4BIT,\n",
|
| 833 |
")\n",
|
| 834 |
+
"# Phi-3 uses fused projections (qkv_proj, gate_up_proj) — different module\n",
|
| 835 |
+
"# names than Qwen/Llama. We list both Phi-3 names and the standard names\n",
|
| 836 |
+
"# so the same cell works if MODEL_ID is later swapped back.\n",
|
| 837 |
+
"_PHI3_MODULES = ['qkv_proj', 'o_proj', 'gate_up_proj', 'down_proj']\n",
|
| 838 |
+
"_QWEN_MODULES = ['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj']\n",
|
| 839 |
+
"_target_modules = _PHI3_MODULES if 'phi-3' in MODEL_ID.lower() else _QWEN_MODULES\n",
|
| 840 |
+
"print(f'LoRA target_modules ({MODEL_ID}): {_target_modules}')\n",
|
| 841 |
"model = FastLanguageModel.get_peft_model(\n",
|
| 842 |
" model,\n",
|
| 843 |
" r=16,\n",
|
| 844 |
+
" target_modules=_target_modules,\n",
|
| 845 |
" lora_alpha=32,\n",
|
| 846 |
" lora_dropout=0.0,\n",
|
| 847 |
" bias='none',\n",
|
|
|
|
| 850 |
")\n",
|
| 851 |
"if tokenizer.pad_token is None:\n",
|
| 852 |
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 853 |
+
"# CRITICAL: left-truncate so if the prompt overflows, we drop the LEGEND\n",
|
| 854 |
+
"# at the front and keep the schema instruction at the END. Without this,\n",
|
| 855 |
+
"# right-truncation silently drops \"Return one action JSON...\" and the model\n",
|
| 856 |
+
"# emits prose -> parse_action falls back -> zero advantage in the GRPO group.\n",
|
| 857 |
+
"tokenizer.truncation_side = 'left'\n",
|
| 858 |
+
"\n",
|
| 859 |
+
"# ── Optional dual-LoRA fraud LLM ──────────────────────────────────────\n",
|
| 860 |
+
"# When USE_LLM_FRAUD=True we load a SECOND base-model + LoRA dedicated to\n",
|
| 861 |
+
"# the fraud agent. Same MODEL_ID, separate weights/adapter so the two\n",
|
| 862 |
+
"# policies don't interfere. The fraud LoRA is smaller (FRAUD_LORA_R) since\n",
|
| 863 |
+
"# the fraud action space is just a 3-float JSON.\n",
|
| 864 |
+
"fraud_model = None\n",
|
| 865 |
+
"fraud_tokenizer = None\n",
|
| 866 |
+
"if USE_LLM_FRAUD:\n",
|
| 867 |
+
" print(f'\\n[USE_LLM_FRAUD=True] loading SECOND base+LoRA for the fraud agent...')\n",
|
| 868 |
+
" fraud_model, fraud_tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 869 |
+
" model_name=MODEL_ID,\n",
|
| 870 |
+
" max_seq_length=MAX_SEQ_LEN,\n",
|
| 871 |
+
" dtype=None,\n",
|
| 872 |
+
" load_in_4bit=LOAD_IN_4BIT,\n",
|
| 873 |
+
" )\n",
|
| 874 |
+
" fraud_model = FastLanguageModel.get_peft_model(\n",
|
| 875 |
+
" fraud_model,\n",
|
| 876 |
+
" r=FRAUD_LORA_R,\n",
|
| 877 |
+
" target_modules=_target_modules,\n",
|
| 878 |
+
" lora_alpha=2 * FRAUD_LORA_R,\n",
|
| 879 |
+
" lora_dropout=0.0,\n",
|
| 880 |
+
" bias='none',\n",
|
| 881 |
+
" use_gradient_checkpointing='unsloth',\n",
|
| 882 |
+
" random_state=SEED + 1,\n",
|
| 883 |
+
" )\n",
|
| 884 |
+
" if fraud_tokenizer.pad_token is None:\n",
|
| 885 |
+
" fraud_tokenizer.pad_token = fraud_tokenizer.eos_token\n",
|
| 886 |
+
" fraud_tokenizer.truncation_side = 'left'\n",
|
| 887 |
+
" print(f' fraud-LLM ready (LoRA r={FRAUD_LORA_R}, separate from defender)')\n",
|
| 888 |
"\n",
|
| 889 |
"ds = Dataset.from_list([{'prompt': p} for p in prompts])\n",
|
| 890 |
"print(ds)\n",
|
| 891 |
"\n",
|
| 892 |
+
"# Token budgets (used by both SFT and GRPO below). Centralised in cell 6.\n",
|
| 893 |
+
"_DEF_MAX_PROMPT = DEF_MAX_PROMPT_TOKENS\n",
|
| 894 |
+
"_DEF_MAX_NEW = DEF_MAX_NEW_TOKENS\n",
|
| 895 |
+
"\n",
|
| 896 |
+
"# ── Stage 1: SFT warm-start on heuristic-labeled actions ──────────────\n",
|
| 897 |
+
"# Without this, GRPO sees ~zero advantage between completions (all of them\n",
|
| 898 |
+
"# fail to emit valid JSON) and the loss collapses to ~1e-6 with a flat\n",
|
| 899 |
+
"# reward curve. SFT teaches the FORMAT + the basic risk→action prior so\n",
|
| 900 |
+
"# GRPO has actual variance to optimise.\n",
|
| 901 |
+
"\n",
|
| 902 |
+
"SFT_STEPS = 20 if QUICK_MODE else 80\n",
|
| 903 |
+
"SFT_LR = 2e-4\n",
|
| 904 |
+
"\n",
|
| 905 |
+
"def _heuristic_completion(obs):\n",
|
| 906 |
+
" \"\"\"Expert label = heuristic policy action, serialised as compact JSON.\"\"\"\n",
|
| 907 |
+
" a = heuristic_policy(obs)\n",
|
| 908 |
+
" return json.dumps(a)\n",
|
| 909 |
+
"\n",
|
| 910 |
+
"# Build (prompt, completion) pairs. SFTTrainer concatenates them and trains\n",
|
| 911 |
+
"# the LM to predict completion tokens given prompt.\n",
|
| 912 |
+
"sft_records = [\n",
|
| 913 |
+
" {'prompt': p, 'completion': _heuristic_completion(o)}\n",
|
| 914 |
+
" for p, o in zip(prompts, prompt_obs)\n",
|
| 915 |
+
"]\n",
|
| 916 |
+
"sft_ds = Dataset.from_list(sft_records)\n",
|
| 917 |
+
"print('SFT dataset:', sft_ds, '| sample completion:', sft_records[0]['completion'])\n",
|
| 918 |
+
"\n",
|
| 919 |
+
"sft_cfg = SFTConfig(\n",
|
| 920 |
+
" output_dir='outputs/theme4_sft_warmstart',\n",
|
| 921 |
+
" per_device_train_batch_size=2,\n",
|
| 922 |
+
" gradient_accumulation_steps=2,\n",
|
| 923 |
+
" max_steps=SFT_STEPS,\n",
|
| 924 |
+
" learning_rate=SFT_LR,\n",
|
| 925 |
+
" logging_steps=2,\n",
|
| 926 |
+
" save_strategy='no',\n",
|
| 927 |
+
" report_to=[],\n",
|
| 928 |
+
" # bf16 intentionally NOT set: T4 GPUs (the Colab default) don't support\n",
|
| 929 |
+
" # bf16 and Unsloth handles dtype internally for the 4-bit base + fp16\n",
|
| 930 |
+
" # LoRA. Letting the trainer auto-pick avoids \"bf16 unsupported\" crashes.\n",
|
| 931 |
+
" max_length=_DEF_MAX_PROMPT + _DEF_MAX_NEW + 32,\n",
|
| 932 |
+
" packing=False,\n",
|
| 933 |
+
" # Newer TRL defaults `padding_free=True`, which then refuses to enforce\n",
|
| 934 |
+
" # `max_length` unless packing is on. We don't want packing (it'd glue\n",
|
| 935 |
+
" # different (prompt, heuristic_completion) pairs together and confuse\n",
|
| 936 |
+
" # `completion_only_loss=True`), so disable padding-free explicitly.\n",
|
| 937 |
+
" padding_free=False,\n",
|
| 938 |
+
" completion_only_loss=True, # don't waste loss on prompt tokens\n",
|
| 939 |
+
")\n",
|
| 940 |
+
"sft_trainer = SFTTrainer(\n",
|
| 941 |
+
" model=model,\n",
|
| 942 |
+
" args=sft_cfg,\n",
|
| 943 |
+
" train_dataset=sft_ds,\n",
|
| 944 |
+
" processing_class=tokenizer,\n",
|
| 945 |
+
")\n",
|
| 946 |
+
"print(f'\\n=== SFT warm-start: {SFT_STEPS} steps on {len(sft_ds)} (prompt, heuristic_action) pairs ===')\n",
|
| 947 |
+
"sft_trainer.train()\n",
|
| 948 |
+
"sft_loss_history = [h.get('loss') for h in sft_trainer.state.log_history if 'loss' in h]\n",
|
| 949 |
+
"print('SFT done. loss curve:', sft_loss_history)\n",
|
| 950 |
+
"\n",
|
| 951 |
"# ── Reward fn: same-seed group + multi-step rollout ───────────────────\n",
|
| 952 |
"_REWARD_DEBUG = {'calls': 0}\n",
|
| 953 |
"\n",
|
|
|
|
| 961 |
" return str(comp)\n",
|
| 962 |
"\n",
|
| 963 |
"def _seed_for_prompt(prompt_text):\n",
|
| 964 |
+
" \"\"\"Look up the seed used to generate this prompt's obs (cell 12). When\n",
|
| 965 |
+
" found, env_reset_seeded(seed) reproduces the EXACT obs in the prompt, so\n",
|
| 966 |
+
" the reward is for the action-on-prompt's-obs (the only meaningful signal).\n",
|
| 967 |
+
"\n",
|
| 968 |
+
" Falls back to a hash for unseen prompts (e.g. evaluation), but during\n",
|
| 969 |
+
" GRPO training every prompt should hit the cache.\"\"\"\n",
|
| 970 |
+
" key = _obs_key(prompt_text or '')\n",
|
| 971 |
+
" s = PROMPT_TO_SEED.get(key)\n",
|
| 972 |
+
" if s is not None:\n",
|
| 973 |
+
" return int(s)\n",
|
| 974 |
+
" h = hashlib.md5((prompt_text or '').encode('utf-8')).hexdigest()\n",
|
| 975 |
" return int(h[:8], 16) & 0x7FFFFFFF\n",
|
| 976 |
"\n",
|
| 977 |
"def reward_fn(completions, prompts=None, **kwargs):\n",
|
| 978 |
+
" \"\"\"For each completion: parse action, score it on the PROMPT'S obs by\n",
|
| 979 |
+
" resetting the env to the cached seed for that prompt. All completions in\n",
|
| 980 |
+
" a GRPO group share the same prompt -> same seed -> same starting obs ->\n",
|
| 981 |
+
" only the action varies -> clean group-relative advantage.\n",
|
| 982 |
+
"\n",
|
| 983 |
+
" LEAGUE-AWARE: if the prompt was collected under a *past* rung (rehearsal\n",
|
| 984 |
+
" share), we re-apply that past θ to the env BEFORE the rollout so the\n",
|
| 985 |
+
" obs reproduces exactly. We then restore the global current adversary\n",
|
| 986 |
+
" after the batch (handled by the surrounding loop).\"\"\"\n",
|
| 987 |
" rewards = []\n",
|
| 988 |
+
" parsed_actions = []\n",
|
| 989 |
+
" n_cache_hit = 0\n",
|
| 990 |
+
" n_past_rehearsal = 0\n",
|
| 991 |
" prompts = prompts or [None] * len(completions)\n",
|
| 992 |
+
" last_theta_applied = None\n",
|
| 993 |
" for prompt_text, comp in zip(prompts, completions):\n",
|
| 994 |
" text = _extract_text(comp)\n",
|
| 995 |
" action = parse_action(text)\n",
|
| 996 |
+
" parsed_actions.append(action)\n",
|
| 997 |
+
" key = _obs_key(prompt_text or '')\n",
|
| 998 |
" seed = _seed_for_prompt(prompt_text or text)\n",
|
| 999 |
+
" if key in PROMPT_TO_SEED:\n",
|
| 1000 |
+
" n_cache_hit += 1\n",
|
| 1001 |
+
" # Re-apply the adversary the prompt was made under (only if it differs\n",
|
| 1002 |
+
" # from what we last applied — avoids spamming the env API).\n",
|
| 1003 |
+
" prompt_theta = PROMPT_TO_THETA.get(key)\n",
|
| 1004 |
+
" if prompt_theta is not None and prompt_theta != last_theta_applied:\n",
|
| 1005 |
+
" env_configure_adversary(**prompt_theta, strategy='mixed')\n",
|
| 1006 |
+
" last_theta_applied = prompt_theta\n",
|
| 1007 |
+
" if prompt_theta != _CURRENT_ROUND_THETA.get('theta'):\n",
|
| 1008 |
+
" n_past_rehearsal += 1\n",
|
| 1009 |
" try:\n",
|
| 1010 |
" r = rollout_reward(action, seed=seed, difficulty=DIFFICULTY,\n",
|
| 1011 |
" k=ROLLOUT_STEPS_PER_REWARD)\n",
|
|
|
|
| 1013 |
" print('reward_fn error:', repr(e))\n",
|
| 1014 |
" r = 0.0\n",
|
| 1015 |
" rewards.append(float(r))\n",
|
| 1016 |
+
" # Restore current round's adversary after the batch so ES + quick eval\n",
|
| 1017 |
+
" # next called sees the canonical state.\n",
|
| 1018 |
+
" cur = _CURRENT_ROUND_THETA.get('theta')\n",
|
| 1019 |
+
" if cur is not None and cur != last_theta_applied:\n",
|
| 1020 |
+
" env_configure_adversary(**cur, strategy='mixed')\n",
|
| 1021 |
" _REWARD_DEBUG['calls'] += 1\n",
|
| 1022 |
" if _REWARD_DEBUG['calls'] <= 3:\n",
|
| 1023 |
+
" n_unique_actions = len({tuple(sorted(a.items())) for a in parsed_actions})\n",
|
| 1024 |
+
" n_unique_rewards = len({round(r, 4) for r in rewards})\n",
|
| 1025 |
+
" print(f\"[reward_fn batch {_REWARD_DEBUG['calls']}] \"\n",
|
| 1026 |
+
" f\"cache_hits={n_cache_hit}/{len(completions)} \"\n",
|
| 1027 |
+
" f\"past_rehearsal_reapplies={n_past_rehearsal} \"\n",
|
| 1028 |
+
" f\"unique_actions={n_unique_actions} \"\n",
|
| 1029 |
+
" f\"unique_rewards={n_unique_rewards} \"\n",
|
| 1030 |
+
" f\"reward_std={float(np.std(rewards)):.4f} \"\n",
|
| 1031 |
+
" f\"sample={rewards[:6]}\")\n",
|
| 1032 |
" return rewards\n",
|
| 1033 |
"\n",
|
| 1034 |
+
"# Tracks the round's \"current\" θ so reward_fn can restore it after a\n",
|
| 1035 |
+
"# rehearsal-sample reapply. Populated by the loop below.\n",
|
| 1036 |
+
"_CURRENT_ROUND_THETA = {'theta': None}\n",
|
| 1037 |
+
"\n",
|
| 1038 |
"# ── Defender policy fn (used inside ES eval) ──────────────────────────\n",
|
| 1039 |
+
"# Token budgets are big enough to (a) NOT truncate the schema instruction at\n",
|
| 1040 |
+
"# the end of the prompt and (b) safely fit a JSON action even if the model\n",
|
| 1041 |
+
"# writes a short prose prefix. With tokenizer.truncation_side='left' set\n",
|
| 1042 |
+
"# above, any overflow drops the legend at the front (lowest-value tokens),\n",
|
| 1043 |
+
"# never the schema instruction at the end.\n",
|
| 1044 |
"\n",
|
| 1045 |
"@torch.no_grad()\n",
|
| 1046 |
"def _defender_action(obs):\n",
|
|
|
|
| 1057 |
" FastLanguageModel.for_training(model)\n",
|
| 1058 |
" return parse_action(text)\n",
|
| 1059 |
"\n",
|
| 1060 |
+
"# ── Post-SFT sanity: the warm-started model should now agree with the\n",
|
| 1061 |
+
"# heuristic on most prompts. If it doesn't, GRPO will still help, but\n",
|
| 1062 |
+
"# this is the cheapest signal that SFT actually moved the policy.\n",
|
| 1063 |
+
"_warm_match = 0\n",
|
| 1064 |
+
"_warm_n = min(8, len(prompt_obs))\n",
|
| 1065 |
+
"for _o in prompt_obs[:_warm_n]:\n",
|
| 1066 |
+
" _a_model = _defender_action(_o)\n",
|
| 1067 |
+
" _a_heur = heuristic_policy(_o)\n",
|
| 1068 |
+
" if _a_model == _a_heur:\n",
|
| 1069 |
+
" _warm_match += 1\n",
|
| 1070 |
+
"print(f' SFT sanity: model matches heuristic on {_warm_match}/{_warm_n} sample obs')\n",
|
| 1071 |
+
"\n",
|
| 1072 |
"# ── GRPO config (per-round) ───────────────────────────────────────────\n",
|
| 1073 |
"def _make_grpo_cfg(max_steps):\n",
|
| 1074 |
" return GRPOConfig(\n",
|
|
|
|
| 1080 |
" gradient_accumulation_steps=2,\n",
|
| 1081 |
" max_steps=int(max_steps),\n",
|
| 1082 |
" logging_steps=1,\n",
|
| 1083 |
+
" learning_rate=5e-6, # lower than 1e-5 to keep close to SFT prior\n",
|
| 1084 |
" save_strategy='no',\n",
|
| 1085 |
" report_to=[],\n",
|
| 1086 |
+
" # bf16 intentionally NOT set — T4 has no bf16 support; Unsloth picks\n",
|
| 1087 |
+
" # the right dtype automatically based on the loaded 4-bit base model.\n",
|
| 1088 |
+
" temperature=1.1, # slight bump so post-SFT logits explore\n",
|
| 1089 |
+
" beta=0.04, # stronger KL: don't drift from SFT'd policy\n",
|
| 1090 |
" )\n",
|
| 1091 |
"\n",
|
| 1092 |
"# ── Co-training loop ──────────────────────────────────────────────────\n",
|
|
|
|
| 1096 |
"fraud_theta_history = [dict(fraud_agent.theta)]\n",
|
| 1097 |
"loss_history_all = []\n",
|
| 1098 |
"reward_log_all = []\n",
|
| 1099 |
+
"ladder_round_rung = [] # which ladder rung each round trained against\n",
|
| 1100 |
"\n",
|
| 1101 |
"# Quick eval helper — tiny by design (called 3x per round: once after defender\n",
|
| 1102 |
"# phase, twice for the exploitability gap). Uses the same COEVO_* knobs.\n",
|
|
|
|
| 1113 |
" obs = env_reset_seeded(seed=20_000 + ep, difficulty=DIFFICULTY)\n",
|
| 1114 |
" return float(np.mean(rs)) if rs else 0.0\n",
|
| 1115 |
"\n",
|
| 1116 |
+
"def _refresh_prompts_for_round(rnd_idx, current_theta):\n",
|
| 1117 |
+
" \"\"\"FIX B + League rehearsal — re-collect prompts so prompt-obs and\n",
|
| 1118 |
+
" reward-obs match exactly inside this round's GRPO.\n",
|
| 1119 |
+
"\n",
|
| 1120 |
+
" LADDER + LEAGUE TWIST: a fraction `LEAGUE_PAST_SAMPLE_PROB` of prompts\n",
|
| 1121 |
+
" are collected under a *sampled past rung* instead of the current rung.\n",
|
| 1122 |
+
" Crucially, the env's adversary is restored to the CURRENT rung after\n",
|
| 1123 |
+
" refresh — but the prompts collected under the past rung carry an obs\n",
|
| 1124 |
+
" that wouldn't exist under the current adversary. To keep alignment\n",
|
| 1125 |
+
" perfect, we ONLY use the past rung for prompts whose REWARD will also\n",
|
| 1126 |
+
" be computed under that rung. We accomplish this by:\n",
|
| 1127 |
+
" (a) splitting the prompt set into 'current' and 'past' shards,\n",
|
| 1128 |
+
" (b) computing all 'current' prompts first, then ES-time-temporarily\n",
|
| 1129 |
+
" applying the past rung to compute 'past' prompts,\n",
|
| 1130 |
+
" (c) restoring the current rung at the end, and\n",
|
| 1131 |
+
" (d) tagging each prompt's seed with the adversary it was made under,\n",
|
| 1132 |
+
" so reward_fn can re-apply that adversary before scoring.\n",
|
| 1133 |
+
"\n",
|
| 1134 |
+
" For QUICK_MODE (3 rounds) the past pool only fills from round 2 onward,\n",
|
| 1135 |
+
" so round 0 always uses 100% current rung.\n",
|
| 1136 |
+
"\n",
|
| 1137 |
+
" Returns: (Dataset, prompts_list, obs_list).\n",
|
| 1138 |
+
" \"\"\"\n",
|
| 1139 |
+
" base = PROMPT_BASE_SEED + rnd_idx * PROMPT_DATASET_SIZE * 13\n",
|
| 1140 |
+
"\n",
|
| 1141 |
+
" # Decide how many prompts come from a past rung (rehearsal share).\n",
|
| 1142 |
+
" n_past = 0\n",
|
| 1143 |
+
" past_theta = None\n",
|
| 1144 |
+
" if len(league) >= 1:\n",
|
| 1145 |
+
" past_theta = league.sample_past()\n",
|
| 1146 |
+
" if past_theta is not None:\n",
|
| 1147 |
+
" n_past = int(round(PROMPT_DATASET_SIZE * LEAGUE_PAST_SAMPLE_PROB))\n",
|
| 1148 |
+
" n_current = PROMPT_DATASET_SIZE - n_past\n",
|
| 1149 |
+
"\n",
|
| 1150 |
+
" # Phase 1 — current rung prompts\n",
|
| 1151 |
+
" env_configure_adversary(**current_theta, strategy='mixed')\n",
|
| 1152 |
+
" cur_prompts, cur_obs, cur_seeds = collect_prompts(n=n_current, base_seed=base)\n",
|
| 1153 |
+
" cur_theta_per_seed = {s: dict(current_theta) for s in cur_seeds}\n",
|
| 1154 |
+
"\n",
|
| 1155 |
+
" # Phase 2 — past rung rehearsal prompts (if any)\n",
|
| 1156 |
+
" past_prompts, past_obs, past_seeds = [], [], []\n",
|
| 1157 |
+
" past_theta_per_seed = {}\n",
|
| 1158 |
+
" if n_past > 0 and past_theta is not None:\n",
|
| 1159 |
+
" env_configure_adversary(**past_theta, strategy='mixed')\n",
|
| 1160 |
+
" past_prompts, past_obs, past_seeds = collect_prompts(\n",
|
| 1161 |
+
" n=n_past, base_seed=base + 7919 # disjoint sub-range\n",
|
| 1162 |
+
" )\n",
|
| 1163 |
+
" past_theta_per_seed = {s: dict(past_theta) for s in past_seeds}\n",
|
| 1164 |
+
"\n",
|
| 1165 |
+
" # Restore current rung as the env's \"default\" — reward_fn will re-apply\n",
|
| 1166 |
+
" # the per-seed θ before each rollout (see PROMPT_TO_THETA below).\n",
|
| 1167 |
+
" env_configure_adversary(**current_theta, strategy='mixed')\n",
|
| 1168 |
+
"\n",
|
| 1169 |
+
" # Combine\n",
|
| 1170 |
+
" new_prompts = cur_prompts + past_prompts\n",
|
| 1171 |
+
" new_obs = cur_obs + past_obs\n",
|
| 1172 |
+
" new_seeds = cur_seeds + past_seeds\n",
|
| 1173 |
+
" new_theta_per_seed = {**cur_theta_per_seed, **past_theta_per_seed}\n",
|
| 1174 |
+
"\n",
|
| 1175 |
+
" PROMPT_TO_SEED.clear()\n",
|
| 1176 |
+
" PROMPT_TO_SEED.update({_obs_key(p): s for p, s in zip(new_prompts, new_seeds)})\n",
|
| 1177 |
+
" PROMPT_TO_OBS.clear()\n",
|
| 1178 |
+
" PROMPT_TO_OBS.update({_obs_key(p): o for p, o in zip(new_prompts, new_obs)})\n",
|
| 1179 |
+
" PROMPT_TO_THETA.clear()\n",
|
| 1180 |
+
" PROMPT_TO_THETA.update({_obs_key(p): new_theta_per_seed[s]\n",
|
| 1181 |
+
" for p, s in zip(new_prompts, new_seeds)})\n",
|
| 1182 |
+
"\n",
|
| 1183 |
+
" print(f' [FIX B + league] refreshed {len(new_prompts)} prompts: '\n",
|
| 1184 |
+
" f'{n_current} current rung + {n_past} past rung (rehearsal)')\n",
|
| 1185 |
+
" return Dataset.from_list([{'prompt': p} for p in new_prompts]), new_prompts, new_obs\n",
|
| 1186 |
+
"\n",
|
| 1187 |
+
"# ── Per-prompt theta lookup so reward_fn can re-apply the adversary the\n",
|
| 1188 |
+
"# prompt was made under (essential for league rehearsal to stay aligned).\n",
|
| 1189 |
+
"PROMPT_TO_THETA = {}\n",
|
| 1190 |
+
"\n",
|
| 1191 |
+
"def _rung_for_round(rnd_idx):\n",
|
| 1192 |
+
" \"\"\"Distribute ladder rungs evenly across rounds. With N_ROUNDS=3 + 3 rungs\n",
|
| 1193 |
+
" we get rounds [0,1,2] -> rungs [0,1,2]. With N_ROUNDS=6 + 3 rungs we get\n",
|
| 1194 |
+
" rounds [0,1,2,3,4,5] -> rungs [0,0,1,1,2,2].\"\"\"\n",
|
| 1195 |
+
" return min(rnd_idx * len(LADDER_RUNGS) // max(N_ROUNDS, 1), len(LADDER_RUNGS) - 1)\n",
|
| 1196 |
+
"\n",
|
| 1197 |
+
"# ── OPTIONAL: dual-LoRA fraud LLM policy ─────────────────────────────\n",
|
| 1198 |
+
"# When USE_LLM_FRAUD=True, this replaces FraudPolicy.es_step inside the\n",
|
| 1199 |
+
"# co-training loop. It is a SECOND LoRA on the same Phi-3 base, trained\n",
|
| 1200 |
+
"# with TRL GRPO to OUTPUT adversary-parameter JSON. Reward = 1 - defender_reward\n",
|
| 1201 |
+
"# under the proposed θ, so the GRPO group-relative advantage rewards the\n",
|
| 1202 |
+
"# fraud LLM for proposing thetas the current defender is weakest against.\n",
|
| 1203 |
+
"#\n",
|
| 1204 |
+
"# Why this is the right structural upgrade (vs. e.g. fraud LLM emitting\n",
|
| 1205 |
+
"# raw transaction JSON): it reuses the existing /configure_adversary +\n",
|
| 1206 |
+
"# quick_defender_eval pipeline, so we don't need any new env endpoints —\n",
|
| 1207 |
+
"# the fraud LLM's \"action\" is exactly the same dict that ES manipulates.\n",
|
| 1208 |
+
"\n",
|
| 1209 |
+
"_FRAUD_KEYS = ('intensity', 'noise_boost', 'pattern_rate')\n",
|
| 1210 |
+
"\n",
|
| 1211 |
+
"def _fraud_summary_text(rung_idx, current_theta, last_def_score):\n",
|
| 1212 |
+
" \"\"\"Compact, model-friendly summary of the current matchup that the fraud\n",
|
| 1213 |
+
" LLM conditions on. Kept short so the prompt stays under FRAUD_MAX_PROMPT_TOKENS.\"\"\"\n",
|
| 1214 |
+
" return (\n",
|
| 1215 |
+
" f'rung={rung_idx}'\n",
|
| 1216 |
+
" f' | current_theta={ {k: round(current_theta[k], 3) for k in _FRAUD_KEYS} }'\n",
|
| 1217 |
+
" f' | last_defender_reward={last_def_score:.3f}'\n",
|
| 1218 |
+
" f' | bounds={ {k: list(FRAUD_PARAM_BOUNDS[k]) for k in _FRAUD_KEYS} }'\n",
|
| 1219 |
+
" )\n",
|
| 1220 |
+
"\n",
|
| 1221 |
+
"def make_fraud_prompt(summary):\n",
|
| 1222 |
+
" return (\n",
|
| 1223 |
+
" 'You design adversary parameters for a payments env. The defender LLM\\n'\n",
|
| 1224 |
+
" 'is an LLM-based fraud detector. Your goal: pick (intensity, noise_boost,\\n'\n",
|
| 1225 |
+
" 'pattern_rate) so the defender\\'s reward is MINIMISED while staying inside\\n'\n",
|
| 1226 |
+
" 'the bounds. Higher intensity = harder fraud, higher noise_boost = stealthier\\n'\n",
|
| 1227 |
+
" 'risk score, higher pattern_rate = more bursty attacks.\\n'\n",
|
| 1228 |
+
" f'Match summary: {summary}\\n'\n",
|
| 1229 |
+
" 'Return ONE JSON: {\"intensity\": <float>, \"noise_boost\": <float>, \"pattern_rate\": <float>}.'\n",
|
| 1230 |
+
" )\n",
|
| 1231 |
+
"\n",
|
| 1232 |
+
"_FRAUD_JSON_RE = re.compile(r'\\{[^{}]*\\}')\n",
|
| 1233 |
+
"\n",
|
| 1234 |
+
"def parse_fraud_theta(text, default_theta):\n",
|
| 1235 |
+
" \"\"\"Extract {intensity, noise_boost, pattern_rate} JSON, fall back to the\n",
|
| 1236 |
+
" given default + clip to bounds. Same defensive pattern as parse_action.\"\"\"\n",
|
| 1237 |
+
" m = _FRAUD_JSON_RE.search(text or '')\n",
|
| 1238 |
+
" if not m:\n",
|
| 1239 |
+
" return _clip_theta(dict(default_theta))\n",
|
| 1240 |
+
" try:\n",
|
| 1241 |
+
" raw = json.loads(m.group(0))\n",
|
| 1242 |
+
" out = dict(default_theta)\n",
|
| 1243 |
+
" for k in _FRAUD_KEYS:\n",
|
| 1244 |
+
" if k in raw:\n",
|
| 1245 |
+
" out[k] = float(raw[k])\n",
|
| 1246 |
+
" return _clip_theta(out)\n",
|
| 1247 |
+
" except Exception:\n",
|
| 1248 |
+
" return _clip_theta(dict(default_theta))\n",
|
| 1249 |
+
"\n",
|
| 1250 |
+
"class FraudLLMPolicy:\n",
|
| 1251 |
+
" \"\"\"Dual-LoRA fraud agent: an LLM that proposes adversary θ via GRPO.\n",
|
| 1252 |
+
" Replaces FraudPolicy.es_step when USE_LLM_FRAUD=True.\"\"\"\n",
|
| 1253 |
+
" def __init__(self, fmodel, ftokenizer, defender_fn, current_theta_fn):\n",
|
| 1254 |
+
" self.model = fmodel\n",
|
| 1255 |
+
" self.tokenizer = ftokenizer\n",
|
| 1256 |
+
" self.defender_fn = defender_fn\n",
|
| 1257 |
+
" self.current_theta_fn = current_theta_fn # ()->dict, latest θ\n",
|
| 1258 |
+
" self.last_def_score = 0.5\n",
|
| 1259 |
+
" self.loss_history = []\n",
|
| 1260 |
+
" self.reward_history = []\n",
|
| 1261 |
+
" self.theta_history = []\n",
|
| 1262 |
+
"\n",
|
| 1263 |
+
" @torch.no_grad()\n",
|
| 1264 |
+
" def _generate_one(self, summary):\n",
|
| 1265 |
+
" FastLanguageModel.for_inference(self.model)\n",
|
| 1266 |
+
" device = next(self.model.parameters()).device\n",
|
| 1267 |
+
" prompt = make_fraud_prompt(summary)\n",
|
| 1268 |
+
" inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True,\n",
|
| 1269 |
+
" max_length=FRAUD_MAX_PROMPT_TOKENS).to(device)\n",
|
| 1270 |
+
" out = self.model.generate(\n",
|
| 1271 |
+
" **inputs, max_new_tokens=FRAUD_MAX_NEW_TOKENS, do_sample=False,\n",
|
| 1272 |
+
" pad_token_id=self.tokenizer.pad_token_id,\n",
|
| 1273 |
+
" )\n",
|
| 1274 |
+
" text = self.tokenizer.decode(out[0][inputs['input_ids'].shape[1]:],\n",
|
| 1275 |
+
" skip_special_tokens=True)\n",
|
| 1276 |
+
" FastLanguageModel.for_training(self.model)\n",
|
| 1277 |
+
" return parse_fraud_theta(text, self.current_theta_fn())\n",
|
| 1278 |
+
"\n",
|
| 1279 |
+
" def grpo_step(self, rung_idx):\n",
|
| 1280 |
+
" \"\"\"One GRPO burst: build a tiny prompt set conditioned on the current\n",
|
| 1281 |
+
" match summary, train fraud LoRA to output θ with reward = 1 - defender_reward.\"\"\"\n",
|
| 1282 |
+
" cur_theta = self.current_theta_fn()\n",
|
| 1283 |
+
" # All prompts in the burst share the same summary (it doesn't change\n",
|
| 1284 |
+
" # within a single ES-replacement step). num_generations supplies the\n",
|
| 1285 |
+
" # group-relative variance via sampling, exactly like defender GRPO.\n",
|
| 1286 |
+
" summary = _fraud_summary_text(rung_idx, cur_theta, self.last_def_score)\n",
|
| 1287 |
+
" prompt = make_fraud_prompt(summary)\n",
|
| 1288 |
+
" ds_fraud = Dataset.from_list(\n",
|
| 1289 |
+
" [{'prompt': prompt} for _ in range(FRAUD_PROMPT_DATASET_SIZE)]\n",
|
| 1290 |
+
" )\n",
|
| 1291 |
+
"\n",
|
| 1292 |
+
" def fraud_reward_fn(completions, prompts=None, **_):\n",
|
| 1293 |
+
" rewards = []\n",
|
| 1294 |
+
" for comp in completions:\n",
|
| 1295 |
+
" text = (comp if isinstance(comp, str)\n",
|
| 1296 |
+
" else (comp[0].get('content','') if isinstance(comp, list)\n",
|
| 1297 |
+
" else comp.get('content','')))\n",
|
| 1298 |
+
" proposed = parse_fraud_theta(text, cur_theta)\n",
|
| 1299 |
+
" # Push proposal to env, measure defender reward under it.\n",
|
| 1300 |
+
" env_configure_adversary(**proposed, strategy='mixed')\n",
|
| 1301 |
+
" def_score = quick_defender_eval()\n",
|
| 1302 |
+
" rewards.append(float(1.0 - def_score)) # fraud wants low def_reward\n",
|
| 1303 |
+
" # Restore current θ so the OUTER loop's next call sees canonical state.\n",
|
| 1304 |
+
" env_configure_adversary(**cur_theta, strategy='mixed')\n",
|
| 1305 |
+
" return rewards\n",
|
| 1306 |
+
"\n",
|
| 1307 |
+
" cfg = GRPOConfig(\n",
|
| 1308 |
+
" output_dir='outputs/theme4_fraud_grpo',\n",
|
| 1309 |
+
" num_generations=FRAUD_GRPO_NUM_GENERATIONS,\n",
|
| 1310 |
+
" max_prompt_length=FRAUD_MAX_PROMPT_TOKENS,\n",
|
| 1311 |
+
" max_completion_length=FRAUD_MAX_NEW_TOKENS,\n",
|
| 1312 |
+
" per_device_train_batch_size=1,\n",
|
| 1313 |
+
" gradient_accumulation_steps=2,\n",
|
| 1314 |
+
" max_steps=int(FRAUD_GRPO_STEPS_PER_ROUND),\n",
|
| 1315 |
+
" logging_steps=1,\n",
|
| 1316 |
+
" learning_rate=5e-6,\n",
|
| 1317 |
+
" save_strategy='no',\n",
|
| 1318 |
+
" report_to=[],\n",
|
| 1319 |
+
" temperature=1.1,\n",
|
| 1320 |
+
" beta=0.04,\n",
|
| 1321 |
+
" )\n",
|
| 1322 |
+
" trainer = GRPOTrainer(\n",
|
| 1323 |
+
" model=self.model, args=cfg, train_dataset=ds_fraud,\n",
|
| 1324 |
+
" processing_class=self.tokenizer, reward_funcs=[fraud_reward_fn],\n",
|
| 1325 |
+
" )\n",
|
| 1326 |
+
" trainer.train()\n",
|
| 1327 |
+
" self.loss_history.extend(\n",
|
| 1328 |
+
" [h.get('loss') for h in trainer.state.log_history if 'loss' in h]\n",
|
| 1329 |
+
" )\n",
|
| 1330 |
+
" self.reward_history.extend(\n",
|
| 1331 |
+
" [h.get('reward') for h in trainer.state.log_history if 'reward' in h]\n",
|
| 1332 |
+
" )\n",
|
| 1333 |
+
"\n",
|
| 1334 |
+
" # Greedy generation = the LoRA's \"best guess\" θ after this burst.\n",
|
| 1335 |
+
" new_theta = self._generate_one(summary)\n",
|
| 1336 |
+
" self.theta_history.append(dict(new_theta))\n",
|
| 1337 |
+
" env_configure_adversary(**new_theta, strategy='mixed')\n",
|
| 1338 |
+
" # Refresh last-defender-score under the chosen θ (used in the NEXT\n",
|
| 1339 |
+
" # round's summary) so the fraud LLM gets a calibrated signal.\n",
|
| 1340 |
+
" self.last_def_score = float(quick_defender_eval())\n",
|
| 1341 |
+
" return {'theta': new_theta, 'def_reward_under_new_theta': self.last_def_score}\n",
|
| 1342 |
+
"\n",
|
| 1343 |
+
"# Instantiate fraud LLM policy ONCE if enabled. Defender_fn is set later\n",
|
| 1344 |
+
"# (closures capture the latest defender LoRA each call automatically).\n",
|
| 1345 |
+
"fraud_llm = None\n",
|
| 1346 |
+
"if USE_LLM_FRAUD and fraud_model is not None:\n",
|
| 1347 |
+
" fraud_llm = FraudLLMPolicy(\n",
|
| 1348 |
+
" fmodel=fraud_model,\n",
|
| 1349 |
+
" ftokenizer=fraud_tokenizer,\n",
|
| 1350 |
+
" defender_fn=_defender_action,\n",
|
| 1351 |
+
" current_theta_fn=lambda: dict(fraud_agent.theta),\n",
|
| 1352 |
+
" )\n",
|
| 1353 |
+
" print(f'[USE_LLM_FRAUD] FraudLLMPolicy ready '\n",
|
| 1354 |
+
" f'(GRPO steps/round={FRAUD_GRPO_STEPS_PER_ROUND}, '\n",
|
| 1355 |
+
" f'num_generations={FRAUD_GRPO_NUM_GENERATIONS})')\n",
|
| 1356 |
"\n",
|
| 1357 |
"for rnd in range(N_ROUNDS):\n",
|
| 1358 |
+
" rung_idx = _rung_for_round(rnd)\n",
|
| 1359 |
+
" rung_anchor = LADDER_RUNGS[rung_idx]\n",
|
| 1360 |
+
" ladder_round_rung.append(rung_idx)\n",
|
| 1361 |
+
" print(f'\\n=== Round {rnd+1}/{N_ROUNDS} | LADDER RUNG {rung_idx} ({rung_anchor}) ===')\n",
|
| 1362 |
+
"\n",
|
| 1363 |
+
" # Anchor the fraud agent at this rung's defaults at the START of the round\n",
|
| 1364 |
+
" # (only on rung CHANGE — within a rung, ES keeps drifting locally).\n",
|
| 1365 |
+
" if rnd == 0 or rung_idx != _rung_for_round(rnd - 1):\n",
|
| 1366 |
+
" fraud_agent.theta = dict(rung_anchor)\n",
|
| 1367 |
+
" fraud_agent.history.append(dict(fraud_agent.theta))\n",
|
| 1368 |
+
" fraud_theta_history.append(dict(fraud_agent.theta))\n",
|
| 1369 |
+
" print(f' ladder anchor applied: θ <- {fraud_agent.theta}')\n",
|
| 1370 |
+
" fraud_agent.apply()\n",
|
| 1371 |
+
" print(f' current fraud θ: {fraud_agent.theta}')\n",
|
| 1372 |
+
"\n",
|
| 1373 |
+
" # Track current-round θ so reward_fn knows what to restore between\n",
|
| 1374 |
+
" # rehearsal-sample reapplies.\n",
|
| 1375 |
+
" _CURRENT_ROUND_THETA['theta'] = dict(fraud_agent.theta)\n",
|
| 1376 |
+
"\n",
|
| 1377 |
+
" # FIX B + LEAGUE rehearsal — refresh prompts under the CURRENT adversary\n",
|
| 1378 |
+
" # (and a `LEAGUE_PAST_SAMPLE_PROB` share under a sampled past rung, with\n",
|
| 1379 |
+
" # per-prompt θ recorded so reward_fn can re-apply it correctly).\n",
|
| 1380 |
+
" ds_round, prompts_round, prompt_obs_round = _refresh_prompts_for_round(\n",
|
| 1381 |
+
" rnd, current_theta=fraud_agent.theta\n",
|
| 1382 |
+
" )\n",
|
| 1383 |
"\n",
|
| 1384 |
+
" # Phase A: defender GRPO on this round's freshly-aligned prompts.\n",
|
| 1385 |
" cfg = _make_grpo_cfg(max_steps=GRPO_STEPS_PER_ROUND)\n",
|
| 1386 |
" trainer = GRPOTrainer(\n",
|
| 1387 |
+
" model=model, args=cfg, train_dataset=ds_round,\n",
|
| 1388 |
" processing_class=tokenizer, reward_funcs=[reward_fn],\n",
|
| 1389 |
" )\n",
|
| 1390 |
" trainer.train()\n",
|
|
|
|
| 1393 |
" loss_history_all.extend(rnd_loss)\n",
|
| 1394 |
" reward_log_all.extend(rnd_rew)\n",
|
| 1395 |
"\n",
|
| 1396 |
+
" # Make sure env is back at current rung after GRPO before quick_eval.\n",
|
| 1397 |
+
" fraud_agent.apply()\n",
|
| 1398 |
" def_score = quick_defender_eval()\n",
|
| 1399 |
" defender_round_rewards.append(def_score)\n",
|
| 1400 |
" print(f' defender mean reward (round {rnd+1}): {def_score:.4f}')\n",
|
| 1401 |
"\n",
|
| 1402 |
+
" # Snapshot settled fraud-θ at this rung into the league (used by next\n",
|
| 1403 |
+
" # round's prompt rehearsal share).\n",
|
| 1404 |
+
" league.add(name=f'round{rnd+1}_rung{rung_idx}', theta=fraud_agent.theta)\n",
|
| 1405 |
+
" print(f' league snapshot taken: now {len(league)} rung(s) in pool')\n",
|
| 1406 |
+
"\n",
|
| 1407 |
+
" # Phase B: fraud update vs current defender.\n",
|
| 1408 |
+
" # USE_LLM_FRAUD=False (default) -> parametric ES on FraudPolicy\n",
|
| 1409 |
+
" # USE_LLM_FRAUD=True -> GRPO on the fraud LoRA (FraudLLMPolicy)\n",
|
| 1410 |
+
" # In both cases the resulting θ is pushed to the env via /configure_adversary\n",
|
| 1411 |
+
" # and `fraud_agent.theta` is kept in sync so downstream code (snapshots,\n",
|
| 1412 |
+
" # exploitability gap, eval) remains identical.\n",
|
| 1413 |
+
" if rnd < N_ROUNDS - 1:\n",
|
| 1414 |
" round_fraud_fits = []\n",
|
| 1415 |
+
" if USE_LLM_FRAUD and fraud_llm is not None:\n",
|
| 1416 |
+
" # Fraud LLM does ONE GRPO burst per round (FRAUD_GRPO_STEPS_PER_ROUND\n",
|
| 1417 |
+
" # steps inside it). Mirror θ back into fraud_agent so later code\n",
|
| 1418 |
+
" # (which still queries fraud_agent.theta) sees the new value.\n",
|
| 1419 |
+
" print(f' [USE_LLM_FRAUD] fraud LoRA GRPO step...')\n",
|
| 1420 |
+
" info = fraud_llm.grpo_step(rung_idx=rung_idx)\n",
|
| 1421 |
+
" new_theta = info['theta']\n",
|
| 1422 |
+
" fraud_agent.theta = dict(new_theta)\n",
|
| 1423 |
+
" fraud_agent.history.append(dict(fraud_agent.theta))\n",
|
| 1424 |
+
" round_fraud_fits.append(1.0 - info['def_reward_under_new_theta'])\n",
|
| 1425 |
+
" print(f' proposed θ={new_theta} | def_reward={info[\"def_reward_under_new_theta\"]:.3f}')\n",
|
| 1426 |
+
" else:\n",
|
| 1427 |
+
" for es in range(ES_STEPS_PER_ROUND):\n",
|
| 1428 |
+
" info = fraud_agent.es_step(_defender_action)\n",
|
| 1429 |
+
" round_fraud_fits.append(info['mean_fraud_fitness'])\n",
|
| 1430 |
+
" print(f' ES step {es+1}/{ES_STEPS_PER_ROUND}: mean_fitness={info[\"mean_fraud_fitness\"]:.3f}'\n",
|
| 1431 |
+
" f' best={info[\"best_fraud_fitness\"]:.3f} theta={info[\"theta\"]}')\n",
|
| 1432 |
" fraud_round_fitness.append(float(np.mean(round_fraud_fits)) if round_fraud_fits else 0.0)\n",
|
| 1433 |
" fraud_theta_history.append(dict(fraud_agent.theta))\n",
|
| 1434 |
"\n",
|
| 1435 |
" # Exploitability gap: how much WORSE the defender does against trained\n",
|
| 1436 |
+
" # fraud vs. against neutral fraud.\n",
|
| 1437 |
" env_configure_adversary(intensity=1.0, noise_boost=0.05, pattern_rate=0.2, strategy='mixed')\n",
|
| 1438 |
" baseline_def = quick_defender_eval()\n",
|
| 1439 |
+
" fraud_agent.apply()\n",
|
| 1440 |
" adv_def = quick_defender_eval()\n",
|
| 1441 |
" gap = float(baseline_def - adv_def)\n",
|
| 1442 |
" exploitability_log.append(gap)\n",
|
| 1443 |
" print(f' exploitability gap: baseline_def={baseline_def:.3f} vs adv_def={adv_def:.3f} -> gap={gap:.3f}')\n",
|
| 1444 |
"\n",
|
| 1445 |
+
"# ── Final league robustness telemetry ────────────────────────────────\n",
|
| 1446 |
+
"# Measure the trained defender against EVERY rung that was snapshotted.\n",
|
| 1447 |
+
"# A robust policy (good ladder-curriculum) scores well across rungs;\n",
|
| 1448 |
+
"# an over-fit one only scores well on the last. This is plotted in cell 22.\n",
|
| 1449 |
+
"print('\\n[league] measuring trained defender vs each league rung...')\n",
|
| 1450 |
+
"league_eval_rewards = []\n",
|
| 1451 |
+
"for rung in league.rungs:\n",
|
| 1452 |
+
" env_configure_adversary(**rung['theta'], strategy='mixed')\n",
|
| 1453 |
+
" score = quick_defender_eval()\n",
|
| 1454 |
+
" league_eval_rewards.append({'name': rung['name'], 'theta': rung['theta'],\n",
|
| 1455 |
+
" 'defender_reward': float(score)})\n",
|
| 1456 |
+
" print(f\" {rung['name']}: defender_reward={score:.3f} θ={rung['theta']}\")\n",
|
| 1457 |
+
"\n",
|
| 1458 |
+
"# Restore co-evolved fraud at the end so cell 20's trained_eval starts there.\n",
|
| 1459 |
+
"fraud_agent.apply()\n",
|
| 1460 |
+
"\n",
|
| 1461 |
"print('\\nCo-training finished.')\n",
|
| 1462 |
+
"print(' ladder rung schedule :', ladder_round_rung)\n",
|
| 1463 |
+
"print(' league pool size :', len(league),\n",
|
| 1464 |
+
" '|', [r['name'] for r in league.rungs])\n",
|
| 1465 |
"print(' defender_round_rewards:', defender_round_rewards)\n",
|
| 1466 |
+
"print(' fraud_round_fitness :', fraud_round_fitness)\n",
|
| 1467 |
+
"print(' exploitability_log :', exploitability_log)\n",
|
| 1468 |
"\n",
|
| 1469 |
"# Aliases for downstream cells\n",
|
| 1470 |
"loss_history = loss_history_all\n",
|
|
|
|
| 1534 |
"source": [
|
| 1535 |
"import matplotlib.pyplot as plt\n",
|
| 1536 |
"\n",
|
| 1537 |
+
"# 0. SFT warm-start loss\n",
|
| 1538 |
+
"if sft_loss_history:\n",
|
| 1539 |
+
" plt.figure(figsize=(8,4))\n",
|
| 1540 |
+
" plt.plot(sft_loss_history, marker='o', color='#a48', label='SFT loss')\n",
|
| 1541 |
+
" plt.xlabel('Logging step')\n",
|
| 1542 |
+
" plt.ylabel('Loss')\n",
|
| 1543 |
+
" plt.title('Stage 1 — SFT warm-start (heuristic imitation)')\n",
|
| 1544 |
+
" plt.legend()\n",
|
| 1545 |
+
" plt.tight_layout()\n",
|
| 1546 |
+
" plt.savefig('artifacts/sft_loss_curve.png', dpi=140)\n",
|
| 1547 |
+
" plt.show()\n",
|
| 1548 |
+
"\n",
|
| 1549 |
"# 1. GRPO training reward (across all rounds)\n",
|
| 1550 |
"if reward_log:\n",
|
| 1551 |
" plt.figure(figsize=(8,4))\n",
|
| 1552 |
" plt.plot(reward_log, label='GRPO mean reward per logging step')\n",
|
| 1553 |
" plt.xlabel('Logging step (across all defender rounds)')\n",
|
| 1554 |
" plt.ylabel('Reward')\n",
|
| 1555 |
+
" plt.title('Stage 2 — GRPO defender training reward')\n",
|
| 1556 |
" plt.legend()\n",
|
| 1557 |
" plt.tight_layout()\n",
|
| 1558 |
" plt.savefig('artifacts/grpo_reward_curve.png', dpi=140)\n",
|
|
|
|
| 1570 |
" plt.savefig('artifacts/grpo_training_loss.png', dpi=140)\n",
|
| 1571 |
" plt.show()\n",
|
| 1572 |
"\n",
|
| 1573 |
+
"# 2b. (Optional) Fraud-LLM GRPO loss + reward — only when USE_LLM_FRAUD=True\n",
|
| 1574 |
+
"if USE_LLM_FRAUD and fraud_llm is not None and fraud_llm.loss_history:\n",
|
| 1575 |
+
" fig, ax1 = plt.subplots(figsize=(8,4))\n",
|
| 1576 |
+
" ax1.plot(fraud_llm.loss_history, color='#c44', label='Fraud-LoRA GRPO loss')\n",
|
| 1577 |
+
" ax1.set_xlabel('Logging step (across all fraud rounds)')\n",
|
| 1578 |
+
" ax1.set_ylabel('Loss', color='#c44')\n",
|
| 1579 |
+
" if fraud_llm.reward_history:\n",
|
| 1580 |
+
" ax2 = ax1.twinx()\n",
|
| 1581 |
+
" ax2.plot(fraud_llm.reward_history, color='#48a',\n",
|
| 1582 |
+
" label='Fraud-LoRA GRPO reward (1 - def_reward)')\n",
|
| 1583 |
+
" ax2.set_ylabel('Reward', color='#48a')\n",
|
| 1584 |
+
" plt.title('Stage 2 — Fraud LoRA GRPO (dual-LLM mode)')\n",
|
| 1585 |
+
" fig.tight_layout()\n",
|
| 1586 |
+
" plt.savefig('artifacts/fraud_llm_grpo_curves.png', dpi=140)\n",
|
| 1587 |
+
" plt.show()\n",
|
| 1588 |
+
"\n",
|
| 1589 |
"# 3. Co-evolution: defender reward vs fraud fitness per round\n",
|
| 1590 |
"rounds_x = np.arange(1, len(defender_round_rewards) + 1)\n",
|
| 1591 |
"fig, ax1 = plt.subplots(figsize=(8,4))\n",
|
|
|
|
| 1628 |
" plt.savefig('artifacts/fraud_theta_trajectory.png', dpi=140)\n",
|
| 1629 |
" plt.show()\n",
|
| 1630 |
"\n",
|
| 1631 |
+
"# 6. Before vs After ── FIX D ──\n",
|
| 1632 |
+
"# Now shows FOUR bars so the comparison is fair AND informative:\n",
|
| 1633 |
+
"# * Random / Heuristic — baselines, eval'd vs neutral fraud (Fix A)\n",
|
| 1634 |
+
"# * Trained LLM (vs Neutral) — apples-to-apples with baselines (PRIMARY)\n",
|
| 1635 |
+
"# * Trained LLM (vs Co-Evo) — robustness against the hardest fraud seen\n",
|
| 1636 |
+
"labels = ['Random\\n(neutral)', 'Heuristic\\n(neutral)',\n",
|
| 1637 |
+
" 'Trained LLM\\n(neutral)', 'Trained LLM\\n(co-evolved)']\n",
|
| 1638 |
+
"values = [\n",
|
| 1639 |
+
" baseline_random['mean_reward'],\n",
|
| 1640 |
+
" baseline_heuristic['mean_reward'],\n",
|
| 1641 |
+
" trained_eval_neutral['mean_reward'],\n",
|
| 1642 |
+
" trained_eval['mean_reward'],\n",
|
| 1643 |
+
"]\n",
|
| 1644 |
+
"colors = ['#bbb','#88c','#4a8','#268']\n",
|
| 1645 |
+
"plt.figure(figsize=(8.5, 4.5))\n",
|
| 1646 |
+
"bars = plt.bar(labels, values, color=colors)\n",
|
| 1647 |
"for b, v in zip(bars, values):\n",
|
| 1648 |
" plt.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.3f}', ha='center')\n",
|
| 1649 |
"plt.ylabel('Mean reward (frozen holdout)')\n",
|
| 1650 |
+
"plt.title('Before vs After Training (SFT + GRPO ladder co-evolution)')\n",
|
| 1651 |
"plt.tight_layout()\n",
|
| 1652 |
"plt.savefig('artifacts/before_after_rewards.png', dpi=140)\n",
|
| 1653 |
"plt.show()\n",
|
| 1654 |
"\n",
|
| 1655 |
+
"# 7a. Trained defender vs each LEAGUE rung (ladder robustness)\n",
|
| 1656 |
+
"# A \"good\" ladder run shows the trained defender scoring at-or-above the\n",
|
| 1657 |
+
"# heuristic baseline across ALL rungs (not just the latest). A spike on the\n",
|
| 1658 |
+
"# last rung only would be evidence of catastrophic forgetting.\n",
|
| 1659 |
+
"if league_eval_rewards:\n",
|
| 1660 |
+
" rung_names = [r['name'] for r in league_eval_rewards]\n",
|
| 1661 |
+
" rung_rewards = [r['defender_reward'] for r in league_eval_rewards]\n",
|
| 1662 |
+
" plt.figure(figsize=(8.5, 4))\n",
|
| 1663 |
+
" bars = plt.bar(rung_names, rung_rewards, color='#4a8')\n",
|
| 1664 |
+
" for b, v in zip(bars, rung_rewards):\n",
|
| 1665 |
+
" plt.text(b.get_x()+b.get_width()/2, v+0.005, f'{v:.3f}',\n",
|
| 1666 |
+
" ha='center', fontsize=9)\n",
|
| 1667 |
+
" plt.axhline(baseline_heuristic['mean_reward'], color='#88c',\n",
|
| 1668 |
+
" linestyle='--', label=f\"Heuristic (neutral): {baseline_heuristic['mean_reward']:.3f}\")\n",
|
| 1669 |
+
" plt.axhline(baseline_random['mean_reward'], color='#aaa',\n",
|
| 1670 |
+
" linestyle=':', label=f\"Random (neutral): {baseline_random['mean_reward']:.3f}\")\n",
|
| 1671 |
+
" plt.xticks(rotation=20, ha='right', fontsize=8)\n",
|
| 1672 |
+
" plt.ylabel('Trained defender mean reward')\n",
|
| 1673 |
+
" plt.title('Ladder robustness: Trained LLM vs each league rung')\n",
|
| 1674 |
+
" plt.legend(fontsize=8)\n",
|
| 1675 |
+
" plt.tight_layout()\n",
|
| 1676 |
+
" plt.savefig('artifacts/league_robustness.png', dpi=140)\n",
|
| 1677 |
+
" plt.show()\n",
|
| 1678 |
+
"\n",
|
| 1679 |
+
"# 7. Per risk-bucket ── FIX D ──\n",
|
| 1680 |
+
"# Same 4-way comparison broken out by Low / Medium / High risk so you can\n",
|
| 1681 |
+
"# see if the trained model lifts performance in the hard buckets where\n",
|
| 1682 |
+
"# heuristic + random give up.\n",
|
| 1683 |
"buckets = ['low', 'medium', 'high']\n",
|
| 1684 |
+
"rand_b = [baseline_random['bucket_means'][b] for b in buckets]\n",
|
| 1685 |
+
"heur_b = [baseline_heuristic['bucket_means'][b] for b in buckets]\n",
|
| 1686 |
+
"trN_b = [trained_eval_neutral['bucket_means'][b] for b in buckets]\n",
|
| 1687 |
+
"trC_b = [trained_eval['bucket_means'][b] for b in buckets]\n",
|
| 1688 |
"x = np.arange(len(buckets))\n",
|
| 1689 |
+
"w = 0.20\n",
|
| 1690 |
+
"plt.figure(figsize=(9.5, 4.5))\n",
|
| 1691 |
+
"plt.bar(x - 1.5*w, rand_b, width=w, label='Random (neutral)', color='#bbb')\n",
|
| 1692 |
+
"plt.bar(x - 0.5*w, heur_b, width=w, label='Heuristic (neutral)', color='#88c')\n",
|
| 1693 |
+
"plt.bar(x + 0.5*w, trN_b, width=w, label='Trained LLM (neutral)', color='#4a8')\n",
|
| 1694 |
+
"plt.bar(x + 1.5*w, trC_b, width=w, label='Trained LLM (co-evolved)', color='#268')\n",
|
| 1695 |
"plt.xticks(x, [b.title()+' Risk' for b in buckets])\n",
|
| 1696 |
"plt.ylabel('Mean reward')\n",
|
| 1697 |
"plt.title('Per Risk-Bucket Reward (frozen holdout)')\n",
|
| 1698 |
+
"plt.legend(loc='best', fontsize=8)\n",
|
| 1699 |
"plt.tight_layout()\n",
|
| 1700 |
"plt.savefig('artifacts/per_bucket_rewards.png', dpi=140)\n",
|
| 1701 |
"plt.show()\n",
|
|
|
|
| 1705 |
" 'model_id': MODEL_ID,\n",
|
| 1706 |
" 'quick_mode': QUICK_MODE,\n",
|
| 1707 |
" 'prompts_used': len(prompts),\n",
|
| 1708 |
+
" 'training_recipe': 'SFT(heuristic-imitation) -> ladder GRPO(rung-curriculum) ⇄ ES fraud (PFSP league)',\n",
|
| 1709 |
+
" 'sft_steps': SFT_STEPS,\n",
|
| 1710 |
+
" 'sft_lr': SFT_LR,\n",
|
| 1711 |
+
" 'sft_loss_history': sft_loss_history,\n",
|
| 1712 |
" 'grpo_num_generations': GRPO_NUM_GENERATIONS,\n",
|
| 1713 |
" 'rollout_steps_per_reward': ROLLOUT_STEPS_PER_REWARD,\n",
|
| 1714 |
" 'n_rounds': N_ROUNDS,\n",
|
| 1715 |
" 'grpo_steps_per_round': GRPO_STEPS_PER_ROUND,\n",
|
| 1716 |
" 'es_steps_per_round': ES_STEPS_PER_ROUND,\n",
|
| 1717 |
" 'es_population': ES_POPULATION,\n",
|
| 1718 |
+
" 'ladder_rungs': LADDER_RUNGS,\n",
|
| 1719 |
+
" 'ladder_round_rung': ladder_round_rung,\n",
|
| 1720 |
+
" 'league_pool': [r['name'] for r in league.rungs],\n",
|
| 1721 |
+
" 'league_past_sample_prob': LEAGUE_PAST_SAMPLE_PROB,\n",
|
| 1722 |
+
" 'league_eval_rewards': league_eval_rewards,\n",
|
| 1723 |
+
" 'use_llm_fraud': USE_LLM_FRAUD,\n",
|
| 1724 |
+
" 'fraud_llm_grpo_loss_history': (fraud_llm.loss_history if (USE_LLM_FRAUD and fraud_llm is not None) else []),\n",
|
| 1725 |
+
" 'fraud_llm_grpo_reward_history': (fraud_llm.reward_history if (USE_LLM_FRAUD and fraud_llm is not None) else []),\n",
|
| 1726 |
+
" 'fraud_llm_theta_history': (fraud_llm.theta_history if (USE_LLM_FRAUD and fraud_llm is not None) else []),\n",
|
| 1727 |
" 'baseline_random_mean_reward': baseline_random['mean_reward'],\n",
|
| 1728 |
" 'baseline_heuristic_mean_reward': baseline_heuristic['mean_reward'],\n",
|
| 1729 |
+
" 'trained_mean_reward_neutral_fraud': trained_eval_neutral['mean_reward'],\n",
|
| 1730 |
+
" 'trained_mean_reward_coevolved_fraud': trained_eval['mean_reward'],\n",
|
| 1731 |
+
" 'reward_gain_vs_random': trained_eval_neutral['mean_reward'] - baseline_random['mean_reward'],\n",
|
| 1732 |
+
" 'reward_gain_vs_heuristic': trained_eval_neutral['mean_reward'] - baseline_heuristic['mean_reward'],\n",
|
| 1733 |
" 'per_bucket': {\n",
|
| 1734 |
+
" 'random': baseline_random['bucket_means'],\n",
|
| 1735 |
+
" 'heuristic': baseline_heuristic['bucket_means'],\n",
|
| 1736 |
+
" 'trained_neutral': trained_eval_neutral['bucket_means'],\n",
|
| 1737 |
+
" 'trained_coevolved': trained_eval['bucket_means'],\n",
|
| 1738 |
" },\n",
|
| 1739 |
" 'defender_round_rewards': defender_round_rewards,\n",
|
| 1740 |
" 'fraud_round_fitness': fraud_round_fitness,\n",
|
|
|
|
| 1744 |
" 'grpo_reward_curve': reward_log,\n",
|
| 1745 |
" 'grpo_loss_history': loss_history,\n",
|
| 1746 |
" 'eval_per_episode': {\n",
|
| 1747 |
+
" 'random': baseline_random['per_episode_mean'],\n",
|
| 1748 |
+
" 'heuristic': baseline_heuristic['per_episode_mean'],\n",
|
| 1749 |
+
" 'trained_neutral': trained_eval_neutral['per_episode_mean'],\n",
|
| 1750 |
+
" 'trained_coevolved': trained_eval['per_episode_mean'],\n",
|
| 1751 |
" },\n",
|
| 1752 |
"}\n",
|
| 1753 |
"with open('artifacts/run_summary.json', 'w', encoding='utf-8') as f:\n",
|
server/SmartPayEnv_environment.py
CHANGED
|
@@ -504,8 +504,14 @@ class SmartpayenvEnvironment(Environment):
|
|
| 504 |
base_reward = (0.4 * route_score) + (0.4 * fs) + (0.2 * rs)
|
| 505 |
|
| 506 |
# League-style regret: penalize underperforming against moving challenger.
|
| 507 |
challenger_regret = max(0.0, self._state.challenger_skill - base_reward)
|
| 508 |
-
regret_penalty = 0.35 * challenger_regret
|
| 509 |
|
| 510 |
# Anti-gaming check: repeatedly overusing manual review without quality gains.
|
| 511 |
gaming_penalty = 0.0
|
|
@@ -513,8 +519,13 @@ class SmartpayenvEnvironment(Environment):
|
|
| 513 |
self._state.anti_gaming_alerts += 1
|
| 514 |
gaming_penalty = min(0.12, 0.02 * self._state.anti_gaming_alerts)
|
| 515 |
|
| 516 |
-
# Curriculum bonus: reward robust performance
|
| 517 |
-
|
| 518 |
|
| 519 |
# Norm punishment for delayed liabilities + self-improvement terms.
|
| 520 |
final_reward = base_reward - (cb_amt / 150.0) - regret_penalty - gaming_penalty + robustness_bonus
|
|
|
|
| 504 |
base_reward = (0.4 * route_score) + (0.4 * fs) + (0.2 * rs)
|
| 505 |
|
| 506 |
# League-style regret: penalize underperforming against moving challenger.
|
| 507 |
+
# NOTE: coefficient was 0.35 — too crushing as a learning signal. A fresh
|
| 508 |
+
# GRPO policy with base_reward=0.3 would lose ~0.12 here, while a strong
|
| 509 |
+
# policy with base_reward=0.7 lost almost nothing. That's the wrong slope:
|
| 510 |
+
# it punished bad policies more than good ones, suppressing the gradient
|
| 511 |
+
# at the very start of training. 0.15 keeps the league-style pressure but
|
| 512 |
+
# leaves enough reward range for early learning.
|
| 513 |
challenger_regret = max(0.0, self._state.challenger_skill - base_reward)
|
| 514 |
+
regret_penalty = 0.15 * challenger_regret
|
| 515 |
|
| 516 |
# Anti-gaming check: repeatedly overusing manual review without quality gains.
|
| 517 |
gaming_penalty = 0.0
|
|
|
|
| 519 |
self._state.anti_gaming_alerts += 1
|
| 520 |
gaming_penalty = min(0.12, 0.02 * self._state.anti_gaming_alerts)
|
| 521 |
|
| 522 |
+
# Curriculum bonus: reward robust performance.
|
| 523 |
+
# NOTE: was `0.06 * curriculum_level * ...` which is exactly 0.0 until the
|
| 524 |
+
# self-improvement loop has already lifted curriculum_level above 0 —
|
| 525 |
+
# a chicken-and-egg that gave bad policies no upside signal at all. The
|
| 526 |
+
# `(1.0 + curriculum_level)` factor activates the bonus from step 1
|
| 527 |
+
# (worth +0.10 * (base-0.5) immediately) and *grows* with curriculum.
|
| 528 |
+
robustness_bonus = 0.10 * (1.0 + self._state.curriculum_level) * max(0.0, base_reward - 0.5)
|
| 529 |
|
| 530 |
# Norm punishment for delayed liabilities + self-improvement terms.
|
| 531 |
final_reward = base_reward - (cb_amt / 150.0) - regret_penalty - gaming_penalty + robustness_bonus
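A small numeric sketch of the two shaping changes above. challenger_skill=0.65 and curriculum_level=0 are assumed purely for illustration, and the old bonus's trailing factor is assumed to have matched the new max(0, base_reward - 0.5) term.

def shaping_terms(base_reward, challenger_skill=0.65, curriculum_level=0.0):
    regret = max(0.0, challenger_skill - base_reward)
    old_penalty = 0.35 * regret                                                  # previous coefficient
    new_penalty = 0.15 * regret                                                  # current coefficient
    old_bonus = 0.06 * curriculum_level * max(0.0, base_reward - 0.5)            # zero until curriculum rises
    new_bonus = 0.10 * (1.0 + curriculum_level) * max(0.0, base_reward - 0.5)    # active from step 1
    return old_penalty, new_penalty, old_bonus, new_bonus

for b in (0.3, 0.7):   # fresh policy vs strong policy
    print(b, shaping_terms(b))
# base=0.3 -> old penalty 0.1225 vs new 0.0525 (more reward range left for early learning)
# base=0.7 -> no regret penalty; old bonus 0.0 at curriculum_level=0 vs new bonus 0.02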
|