rtferraz committed on
Commit 6c51e5f · verified · 1 Parent(s): a6a8b11

feat: add v3 notebook (.ipynb) — ready for Vertex AI Workbench

Files changed (1)
  1. notebooks/grpo_vertex_v3.ipynb +1517 -0
notebooks/grpo_vertex_v3.ipynb ADDED
@@ -0,0 +1,1517 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {
6
+ "display_name": "Python 3 (ipykernel)",
7
+ "language": "python",
8
+ "name": "python3"
9
+ },
10
+ "language_info": {
11
+ "name": "python",
12
+ "version": "3.10.0",
13
+ "mimetype": "text/x-python",
14
+ "file_extension": ".py"
15
+ }
16
+ },
17
+ "cells": [
18
+ {
19
+ "cell_type": "markdown",
20
+ "metadata": {},
21
+ "source": [
22
+ "# Tucano2 Commerce — GRPO Training v3 (Vertex AI Workbench / L4)\n",
23
+ "\n",
24
+ "**v3 changes over v2 — grounded in published research:**\n",
25
+ "\n",
26
+ "| Change | v2 Value | v3 Value | Paper Reference |\n",
27
+ "|--------|----------|----------|----------------|\n",
28
+ "| Temperature | 0.8 | **1.0** | Skywork-OR1 (2505.22312) §4: τ=1.0 gives 5-8% better results, delays entropy collapse |\n",
29
+ "| Completion length | 2048 | **4096** | Dr. GRPO (2503.20783) §3.1: length bias inflates wrong answers → ceiling hit blocks learning |\n",
30
+ "| Num generations | 8 | **4** | VRAM tradeoff: 4×4096 ≈ 8×2048. MC-GRPO (2601.22582): G=4 works with noise mitigation |\n",
31
+ "| Learning rate | 5e-7 | **2e-6** | Dr. GRPO Appendix G: LR=1e-6; Reasoning-SQL: LR=1e-6. v2 clip_ratio=0 → room to push 2-4× |\n",
32
+ "| β (KL penalty) | implicit | **0.0** | Dr. GRPO §3.2: β=0 optimal for rule-based rewards |\n",
33
+ "| Training data | 300 | **ALL (~1400)** | Skywork-OR1 §3.1: small prompt sets → model memorizes → entropy collapse |\n",
34
+ "| Reward functions | single composite | **staged (format→partial→task)** | Reasoning-SQL (2503.23157) §3.2: format rewards converge first, enable task learning |\n",
35
+ "| Zero-advantage groups | included | **filtered with noise injection** | Skywork-OR1 §3.1: zero-std groups destabilize training |\n",
36
+ "| Entropy monitoring | none | **EntropyMonitorCallback** | Skywork-OR1 §4: early detection prevents collapse |\n",
37
+ "| Early stopping patience | 10 | **15** | More runway for longer completions |\n",
38
+ "| Save total limit | 3 | **5** | Keep more checkpoints — v2 lost the best one |\n",
39
+ "| Eval temperature | 0.7 | **0.1** | Deterministic eval = less noisy signal |\n",
40
+ "| General reasoning mix | none | **30% (optional)** | Cocktail Effect (2410.01109): multi-task mix boosts domain performance 2-15% |\n",
41
+ "\n",
42
+ "**Prerequisites:**\n",
43
+ "- Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`\n",
44
+ "- Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`\n",
45
+ "- **NEW:** Optional `data/pairs/general_reasoning.jsonl` for 30% general data mix\n",
46
+ "\n",
47
+ "**Hardware:** L4 (24GB), PyTorch kernel, bf16 supported\n",
48
+ "\n",
49
+ "---\n",
50
+ "\n",
51
+ "## Cell 1: Dependencies\n",
52
+ "\n",
53
+ "Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "# Cell 1a — Nuke everything ML-related\n",
63
+ "!pip uninstall -y torch torchvision torchaudio \\\n",
64
+ " unsloth unsloth-zoo \\\n",
65
+ " trl transformers peft accelerate \\\n",
66
+ " bitsandbytes vllm vllm-flash-attn \\\n",
67
+ " datasets tokenizers safetensors huggingface-hub \\\n",
68
+ " wandb xformers triton \\\n",
69
+ " cuda-bindings cuda-python \\\n",
70
+ " sentencepiece protobuf \\\n",
71
+ " 2>/dev/null"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "# Cell 1b — Kill any stragglers\n",
81
+ "!pip freeze | grep -iE \"torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate\" | xargs pip uninstall -y 2>/dev/null"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "# Cell 1c — Purge cache\n",
91
+ "!pip cache purge"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "markdown",
96
+ "metadata": {},
97
+ "source": [
98
+ "**⚠️ Restart kernel again**, then:"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "metadata": {},
105
+ "outputs": [],
106
+ "source": [
107
+ "# Cell 1d — Clean install, correct order\n",
108
+ "!pip install \"unsloth\""
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "# Cell 1e — Pin TRL (Unsloth may pull a different version)\n",
118
+ "!pip install \"trl==0.24.0\" --no-deps"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "# Cell 1f — Extra deps\n",
128
+ "!pip install \"rich\" \"wandb\""
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "metadata": {},
134
+ "source": [
135
+ "---\n",
136
+ "\n",
137
+ "## Cell 2: Hello World — GPU + Unsloth Verification"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "import torch\n",
147
+ "\n",
148
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
149
+ "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
150
+ "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
151
+ "print(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n",
152
+ "\n",
153
+ "from unsloth import FastLanguageModel\n",
154
+ "print(\"\\n✓ Unsloth loaded successfully\")\n",
155
+ "\n",
156
+ "import trl\n",
157
+ "print(f\"✓ TRL version: {trl.__version__}\")\n",
158
+ "\n",
159
+ "import transformers\n",
160
+ "print(f\"✓ Transformers version: {transformers.__version__}\")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {},
166
+ "source": [
167
+ "---\n",
168
+ "\n",
169
+ "## Cell 3: Config + Constants"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "import os\n",
179
+ "os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\n",
180
+ "\n",
181
+ "import json\n",
182
+ "import re\n",
183
+ "import time\n",
184
+ "import random\n",
185
+ "import gc\n",
186
+ "from pathlib import Path\n",
187
+ "\n",
188
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
189
+ "# v3 CONFIG — Every change is annotated with paper reference\n",
190
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
191
+ "\n",
192
+ "MODEL_ID = \"Polygl0t/Tucano2-qwen-3.7B-Think\"\n",
193
+ "MAX_SEQ_LENGTH = 8192 # v3: increased from 4096 — model supports 32k, we need room for 4096 completion + prompt\n",
194
+ "\n",
195
+ "# ── Paths ─────────────────────────────────────────────────────────────────────\n",
196
+ "DATA_DIR = Path(\"/home/jupyter/tucano2/data\")\n",
197
+ "MODELS_DIR = Path(\"/home/jupyter/tucano2/models\")\n",
198
+ "SFT_ADAPTER_DIR = MODELS_DIR / \"tucano2-commerce-sft\"\n",
199
+ "GRPO_ADAPTER_DIR = MODELS_DIR / \"tucano2-commerce-grpo-v3\" # v3: separate dir from v2\n",
200
+ "CHECKPOINT_DIR = GRPO_ADAPTER_DIR / \"checkpoints\"\n",
201
+ "\n",
202
+ "# ── Training data ─────────────────────────────────────────────────────────────\n",
203
+ "GRPO_PROMPTS = None # v3: None = use ALL available prompts (was 300 subset in v2)\n",
204
+ "GENERAL_MIX_RATIO = 0.0 # v3: set to 0.3 if general_reasoning.jsonl exists (Cocktail Effect paper)\n",
205
+ "\n",
206
+ "# ── Valid enums for reward scoring (unchanged from v2) ────────────────────────\n",
207
+ "VALID_SENTIMENTS = {\"positive\", \"negative\", \"neutral\"}\n",
208
+ "VALID_CATEGORIES = {\n",
209
+ " \"delivery_delay\", \"product_quality\", \"product_not_received\",\n",
210
+ " \"wrong_product\", \"seller_communication\", \"app_issue\",\n",
211
+ " \"price_value\", \"other\", \"none\",\n",
212
+ "}\n",
213
+ "VALID_CHURN = {\"low\", \"medium\", \"high\"}\n",
214
+ "VALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\n",
215
+ "EXTRACTION_FIELDS = [\n",
216
+ " \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n",
217
+ " \"product_issue\", \"seller_issue\", \"main_complaint\",\n",
218
+ " \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n",
219
+ "]\n",
220
+ "\n",
221
+ "SYSTEM_PT = (\n",
222
+ " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
223
+ " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\"\n",
224
+ ")\n",
225
+ "\n",
226
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
227
+ "# TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)\n",
228
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
229
+ "\n",
230
+ "# ── Core GRPO params ──────────────────────────────────────────────────────────\n",
231
+ "BATCH_SIZE = 4\n",
232
+ "GRAD_ACCUM = 1 # v3: reduced from 2. Effective batch = 4×1 = 4 (was 8)\n",
233
+ " # With G=4: steps = prompts × 4 / 4 = prompts per epoch\n",
234
+ "NUM_GENERATIONS = 4 # v3: reduced from 8 — VRAM tradeoff for longer completions\n",
235
+ " # MC-GRPO (2601.22582): G=4 works if noise is mitigated\n",
236
+ "SCALE_REWARDS = False # Dr. GRPO (2503.20783): remove std normalization bias\n",
237
+ "\n",
238
+ "# ── v3 CRITICAL FIXES ────────────────────────────────────────────────────────\n",
239
+ "\n",
240
+ "# FIX 1: Temperature — prevent entropy collapse\n",
241
+ "# v2 had 0.8. All published GRPO papers use 1.0.\n",
242
+ "# Skywork-OR1 (2505.22312) ablation: τ=1.0 vs τ=0.6 → 5-8% better test performance\n",
243
+ "TEMPERATURE = 1.0\n",
244
+ "\n",
245
+ "# FIX 2: Completion length — remove the ceiling\n",
246
+ "# v2: every single completion hit 2048 ceiling. Model couldn't finish reasoning.\n",
247
+ "# Dr. GRPO (2503.20783) §3.1: GRPO length bias inflates wrong answers → the ceiling kills the gradient\n",
248
+ "MAX_COMPLETION_LENGTH = 4096\n",
249
+ "\n",
250
+ "# FIX 3: Learning rate — more aggressive\n",
251
+ "# v2: clip_ratio=0 on all steps → updates were too small to matter\n",
252
+ "# Dr. GRPO Appendix G: LR=1e-6 (constant). Reasoning-SQL: LR=1e-6 with cosine.\n",
253
+ "# We go 2× since v2 showed zero clipping (model can absorb stronger push)\n",
254
+ "LEARNING_RATE = 2e-6\n",
255
+ "\n",
256
+ "# FIX 4: β = 0 (no KL penalty)\n",
257
+ "# Dr. GRPO (2503.20783) §3.2: KL penalty is unnecessary for rule-based rewards\n",
258
+ "# v2 used implicit KL through default β — we explicitly disable it\n",
259
+ "BETA = 0.0\n",
260
+ "\n",
261
+ "# ── Training schedule ─────────────────────────────────────────────────────────\n",
262
+ "NUM_EPOCHS = 1\n",
263
+ "MAX_STEPS = 500 # v3: increased for expanded data; early stopping will halt if needed\n",
264
+ " # With ~1400 prompts × 4 gen / (4 batch × 1 accum) = 1400 steps/epoch\n",
265
+ " # MAX_STEPS=500 < 1 epoch — early stopping or manual extension\n",
266
+ "\n",
267
+ "# ── Checkpoint + Eval + Early-Stop ────────────────────────────────────────────\n",
268
+ "EVAL_SPLIT_RATIO = 0.15\n",
269
+ "EVAL_STEPS = 10\n",
270
+ "EARLY_STOPPING_PATIENCE = 15 # v3: increased from 10 — gives 150 steps of runway\n",
271
+ "EARLY_STOPPING_DELTA = 0.005 # v3: reduced from 0.01 — more sensitive to small gains\n",
272
+ "SAVE_STEPS = 10 # v3: more frequent (was 15) — never lose best checkpoint again\n",
273
+ "SAVE_TOTAL_LIMIT = 5 # v3: keep more checkpoints (was 3 — lost best in v2)\n",
274
+ "WANDB_PROJECT = \"tucano2-commerce\"\n",
275
+ "\n",
276
+ "# ── Eval callback ─────────────────────────────────────────────────────────────\n",
277
+ "EVAL_MAX_SAMPLES = 5\n",
278
+ "EVAL_MAX_TOKENS = 4096 # v3: match training max_completion_length (was 2048)\n",
279
+ "EVAL_TEMPERATURE = 0.1 # v3: deterministic eval for less noisy signal (was 0.7)\n",
280
+ "\n",
281
+ "# ── Backend ───────────────────────────────────────────────────────────────────\n",
282
+ "USE_VLLM = False\n",
283
+ "\n",
284
+ "# ── v3: Zero-advantage noise injection ────────────────────────────────────────\n",
285
+ "# Skywork-OR1 (2505.22312) §3.1: zero-std groups destabilize GRPO training\n",
286
+ "# When all G completions get identical rewards, the advantage is undefined.\n",
287
+ "# Noise injection breaks ties without corrupting the signal.\n",
288
+ "ZERO_ADV_NOISE_STD = 0.005 # Small gaussian noise added to zero-variance groups\n",
289
+ "\n",
290
+ "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
291
+ "\n",
292
+ "# ── Version assertion ─────────────────────────────────────────────────────────\n",
293
+ "import trl as _trl\n",
294
+ "assert _trl.__version__ == \"0.24.0\", (\n",
295
+ " f\"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\\n\"\n",
296
+ " \"Verify that GRPOTrainer._generate() still exists before proceeding.\"\n",
297
+ ")\n",
298
+ "\n",
299
+ "print(\"✓ v3 Config loaded\")\n",
300
+ "print(f\" SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})\")\n",
301
+ "print(f\" Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})\")\n",
302
+ "print(f\" Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}\")\n",
303
+ "print(f\" GRPO: G={NUM_GENERATIONS}, temp={TEMPERATURE}, LR={LEARNING_RATE}, β={BETA}\")\n",
304
+ "print(f\" Completion: max={MAX_COMPLETION_LENGTH} (v2 was 2048)\")\n",
305
+ "print(f\" ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}\")\n",
306
+ "print(f\"✓ TRL {_trl.__version__} verified\")\n",
307
+ "\n",
308
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
309
+ "# v3 VRAM BUDGET (L4 24GB)\n",
310
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
311
+ "# Model (NF4): ~3.5 GB\n",
312
+ "# KV Cache (8192 seq): ~3.0 GB\n",
313
+ "# Activations: ~4.0 GB\n",
314
+ "# Optimizer states: ~3.0 GB\n",
315
+ "# Generations (4×4096): ~8.0 GB\n",
316
+ "# ─────────────────────────────────\n",
317
+ "# Estimated total: ~21.5 GB\n",
318
+ "# Headroom: ~2.5 GB\n",
319
+ "#\n",
320
+ "# If OOM: reduce MAX_COMPLETION_LENGTH to 3072 first, then 2560.\n",
321
+ "# Do NOT reduce NUM_GENERATIONS below 4 — GRPO needs variance.\n",
322
+ "# ══════════════════════════════════════════════════════════════════════════════"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {},
328
+ "source": [
329
+ "---\n",
330
+ "\n",
331
+ "## Cell 4: Load SFT Adapter"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "print(\"Loading SFT adapter...\")\n",
341
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
342
+ " model_name=str(SFT_ADAPTER_DIR),\n",
343
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
344
+ " load_in_4bit=True,\n",
345
+ " dtype=None,\n",
346
+ ")\n",
347
+ "\n",
348
+ "if tokenizer.pad_token is None:\n",
349
+ " tokenizer.pad_token = tokenizer.eos_token\n",
350
+ "\n",
351
+ "# Load chat template from base model (SFT adapter doesn't save it)\n",
352
+ "from transformers import AutoTokenizer\n",
353
+ "base_tok = AutoTokenizer.from_pretrained(MODEL_ID)\n",
354
+ "tokenizer.chat_template = base_tok.chat_template\n",
355
+ "del base_tok\n",
356
+ "\n",
357
+ "# v2: Force KV cache — Unsloth patching may reset this\n",
358
+ "model.config.use_cache = True\n",
359
+ "model.generation_config.use_cache = True\n",
360
+ "\n",
361
+ "print(f\"✓ Model loaded on {model.device}\")\n",
362
+ "print(f\" use_cache: {model.config.use_cache}\")\n",
363
+ "print(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n",
364
+ "print(f\" Chat template: {tokenizer.chat_template[:50]}...\")"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "markdown",
369
+ "metadata": {},
370
+ "source": [
371
+ "---\n",
372
+ "\n",
373
+ "## Cell 5: Single Inference Test\n",
374
+ "\n",
375
+ "**Gate:** Does the model close `</think>` and produce an answer within 4096 tokens?"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": null,
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "FastLanguageModel.for_inference(model)\n",
385
+ "\n",
386
+ "test_msgs = [\n",
387
+ " {\"role\": \"system\", \"content\": SYSTEM_PT},\n",
388
+ " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
389
+ "]\n",
390
+ "text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)\n",
391
+ "inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
392
+ "\n",
393
+ "t0 = time.time()\n",
394
+ "outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
395
+ "elapsed = time.time() - t0\n",
396
+ "\n",
397
+ "response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
398
+ "gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
399
+ "\n",
400
+ "print(f\"Generation time: {elapsed:.1f}s ({gen_tokens} tokens, {gen_tokens/elapsed:.1f} tok/s)\")\n",
401
+ "print(f\"Response length: {len(response)} chars, {gen_tokens} tokens\")\n",
402
+ "print(f\"Hit ceiling: {gen_tokens >= MAX_COMPLETION_LENGTH}\") # v3: should NOT hit ceiling with 4096\n",
403
+ "print(f\"closed_think: {'</think>' in response}\")\n",
404
+ "print(f\"\\n{'='*60}\")\n",
405
+ "print(response[:800])"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "markdown",
410
+ "metadata": {},
411
+ "source": [
412
+ "---\n",
413
+ "\n",
414
+ "## Cell 5b: KV Cache Diagnostic"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": null,
420
+ "metadata": {},
421
+ "outputs": [],
422
+ "source": [
423
+ "import time\n",
424
+ "FastLanguageModel.for_inference(model)\n",
425
+ "\n",
426
+ "_kv_msgs = [{\"role\": \"system\", \"content\": SYSTEM_PT},\n",
427
+ " {\"role\": \"user\", \"content\": \"Qual a categoria de reclamação mais frequente?\"}]\n",
428
+ "_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n",
429
+ "_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n",
430
+ "\n",
431
+ "_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\n",
432
+ "with torch.no_grad():\n",
433
+ " for _step in range(50):\n",
434
+ " _t0 = time.time()\n",
435
+ " seq_len = _generated.shape[1]\n",
436
+ " if _past is None:\n",
437
+ " _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n",
438
+ " else:\n",
439
+ " _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n",
440
+ " _out = model(\n",
441
+ " input_ids=_generated[:, -1:] if _past is not None else _generated,\n",
442
+ " position_ids=_position_ids,\n",
443
+ " attention_mask=torch.ones(1, seq_len, device=model.device),\n",
444
+ " past_key_values=_past,\n",
445
+ " use_cache=True,\n",
446
+ " return_dict=True,\n",
447
+ " )\n",
448
+ " _past = _out.past_key_values\n",
449
+ " _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n",
450
+ " _generated = torch.cat([_generated, _next], dim=1)\n",
451
+ " _token_times.append(time.time() - _t0)\n",
452
+ "\n",
453
+ "_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\n",
454
+ "print(f\"First 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\n",
455
+ "print(f\"Last 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\n",
456
+ "print(f\"Ratio last/first: {_ratio:.1f}x\")\n",
457
+ "if _ratio < 3:\n",
458
+ " print(\"✓ KV cache is working correctly\")\n",
459
+ "elif _ratio < 6:\n",
460
+ " print(\"⚠ KV cache may be degraded — check model.config.use_cache\")\n",
461
+ "else:\n",
462
+ " print(\"✗ KV cache BROKEN — GRPO generation will be catastrophically slow.\")\n",
463
+ "\n",
464
+ "del _past, _generated, _kv_inputs, _token_times, _out\n",
465
+ "gc.collect()\n",
466
+ "if torch.cuda.is_available(): torch.cuda.empty_cache()"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "markdown",
471
+ "metadata": {},
472
+ "source": [
473
+ "---\n",
474
+ "\n",
475
+ "## Cell 6: Reward Functions v3\n",
476
+ "\n",
477
+ "**v3 changes:**\n",
478
+ "- Staged reward design: format → partial content → full task (Reasoning-SQL, 2503.23157)\n",
479
+ "- Zero-advantage noise injection (Skywork-OR1, 2505.22312)\n",
480
+ "- Extraction reward redesigned for completion-length-friendly scoring"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "metadata": {},
487
+ "outputs": [],
488
+ "source": [
489
+ "def strip_think(text: str) -> str:\n",
490
+ " \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n",
491
+ " return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n",
492
+ "\n",
493
+ "\n",
494
+ "def has_think_block(text: str) -> bool:\n",
495
+ " \"\"\"Check if text contains a non-empty <think> block.\"\"\"\n",
496
+ " return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n",
497
+ "\n",
498
+ "\n",
499
+ "def _classify_task_type(prompt_text: str) -> str:\n",
500
+ " \"\"\"Classify prompt into task type by keywords.\"\"\"\n",
501
+ " p = prompt_text.lower()\n",
502
+ " if \"retorne um objeto json\" in p or \"extraia dados\" in p:\n",
503
+ " return \"extraction\"\n",
504
+ " elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
505
+ " return \"push\"\n",
506
+ " elif \"perfil do cliente\" in p:\n",
507
+ " return \"insights\"\n",
508
+ " else:\n",
509
+ " return \"sql_qa\"\n",
510
+ "\n",
511
+ "\n",
512
+ "def _json_similarity(text: str) -> float:\n",
513
+ " \"\"\"Rough heuristic: how JSON-like is this text? 0.0 to 1.0.\"\"\"\n",
514
+ " text = text.strip()\n",
515
+ " if not text:\n",
516
+ " return 0.0\n",
517
+ " score = 0.0\n",
518
+ " if text.startswith(\"{\") and text.endswith(\"}\"):\n",
519
+ " score += 0.5\n",
520
+ " if '\"' in text:\n",
521
+ " score += 0.2\n",
522
+ " if \":\" in text:\n",
523
+ " score += 0.2\n",
524
+ " if \",\" in text:\n",
525
+ " score += 0.1\n",
526
+ " return min(score, 1.0)\n",
527
+ "\n",
528
+ "\n",
529
+ "def _string_similarity(a: str, b: str) -> float:\n",
530
+ " \"\"\"Simple Jaccard-like similarity for short strings. 0.0 to 1.0.\"\"\"\n",
531
+ " if not a or not b:\n",
532
+ " return 0.0\n",
533
+ " a_set = set(a.split())\n",
534
+ " b_set = set(b.split())\n",
535
+ " intersection = len(a_set & b_set)\n",
536
+ " union = len(a_set | b_set)\n",
537
+ " return intersection / union if union > 0 else 0.0\n",
538
+ "\n",
539
+ "\n",
540
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
541
+ "# v3 STAGED REWARD DESIGN\n",
542
+ "# Reference: Reasoning-SQL (2503.23157) §3.2\n",
543
+ "#\n",
544
+ "# Each reward function scores THREE stages independently:\n",
545
+ "# Stage 1 — FORMAT (0.0–0.2): Is the output well-structured?\n",
546
+ "# Stage 2 — PARTIAL (0.0–0.3): Are some content elements correct?\n",
547
+ "# Stage 3 — TASK (0.0–0.5): Is the full task completed correctly?\n",
548
+ "#\n",
549
+ "# Format rewards converge first (easy to learn), which stabilizes training\n",
550
+ "# and enables the model to then learn harder task-specific skills.\n",
551
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
552
+ "\n",
553
+ "\n",
554
+ "def reward_extraction(completion: str) -> float:\n",
555
+ " \"\"\"Staged reward for structured extraction (max 1.0).\"\"\"\n",
556
+ " answer = strip_think(completion)\n",
557
+ "\n",
558
+ " # ── Stage 1: FORMAT (max 0.2) ─────────────────────────────────────────────\n",
559
+ " r_format = 0.0\n",
560
+ " if has_think_block(completion):\n",
561
+ " r_format += 0.1 # Used reasoning\n",
562
+ "\n",
563
+ " try:\n",
564
+ " data = json.loads(answer)\n",
565
+ " if isinstance(data, dict):\n",
566
+ " r_format += 0.1 # Valid JSON object\n",
567
+ " except (json.JSONDecodeError, TypeError):\n",
568
+ " r_format += 0.05 * _json_similarity(answer)\n",
569
+ " return min(r_format, 0.2)\n",
570
+ "\n",
571
+ " if not isinstance(data, dict):\n",
572
+ " return min(r_format, 0.2)\n",
573
+ "\n",
574
+ " # ── Stage 2: PARTIAL CONTENT (max 0.3) ────────────────────────────────────\n",
575
+ " r_partial = 0.0\n",
576
+ "\n",
577
+ " present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n",
578
+ " r_partial += 0.15 * (present / len(EXTRACTION_FIELDS))\n",
579
+ "\n",
580
+ " type_checks = 0\n",
581
+ " type_total = 0\n",
582
+ " for field in EXTRACTION_FIELDS:\n",
583
+ " if field not in data:\n",
584
+ " continue\n",
585
+ " type_total += 1\n",
586
+ " val = data[field]\n",
587
+ " if field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
588
+ " if isinstance(val, bool):\n",
589
+ " type_checks += 1\n",
590
+ " elif field in (\"sentiment_score\",):\n",
591
+ " if isinstance(val, (int, float)):\n",
592
+ " type_checks += 1\n",
593
+ " elif field in (\"main_complaint\", \"sentiment\", \"complaint_category\", \"churn_risk\", \"repeat_intent\"):\n",
594
+ " if isinstance(val, str):\n",
595
+ " type_checks += 1\n",
596
+ " if type_total > 0:\n",
597
+ " r_partial += 0.15 * (type_checks / type_total)\n",
598
+ "\n",
599
+ " # ── Stage 3: FULL TASK (max 0.5) ─────────────────────────────────────────\n",
600
+ " r_task = 0.0\n",
601
+ " cat_checks = 0\n",
602
+ " cat_total = 0\n",
603
+ "\n",
604
+ " checks = [\n",
605
+ " (\"sentiment\", lambda v: v in VALID_SENTIMENTS),\n",
606
+ " (\"complaint_category\", lambda v: v in VALID_CATEGORIES),\n",
607
+ " (\"churn_risk\", lambda v: v in VALID_CHURN),\n",
608
+ " (\"repeat_intent\", lambda v: v in VALID_REPEAT),\n",
609
+ " (\"sentiment_score\", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),\n",
610
+ " ]\n",
611
+ " for field, validator in checks:\n",
612
+ " cat_total += 1\n",
613
+ " if field in data and validator(data[field]):\n",
614
+ " cat_checks += 1\n",
615
+ "\n",
616
+ " for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
617
+ " cat_total += 1\n",
618
+ " if bool_field in data and isinstance(data[bool_field], bool):\n",
619
+ " cat_checks += 1\n",
620
+ "\n",
621
+ " if cat_total > 0:\n",
622
+ " r_task += 0.35 * (cat_checks / cat_total)\n",
623
+ "\n",
624
+ " if \"main_complaint\" in data and isinstance(data[\"main_complaint\"], str):\n",
625
+ " complaint = data[\"main_complaint\"].strip()\n",
626
+ " if len(complaint) > 10:\n",
627
+          "        r_task += 0.15\n",
+          "\n",
+          "    return min(r_format + r_partial + r_task, 1.0)\n",
+          "\n",
+          "\n",
+          "def reward_sql_qa(completion: str) -> float:\n",
+          "    \"\"\"Staged reward for SQL Q&A (max 1.0).\"\"\"\n",
+          "    answer = strip_think(completion)\n",
+          "\n",
+          "    # ── Stage 1: FORMAT (max 0.2)\n",
+          "    r_format = 0.0\n",
+          "    if has_think_block(completion):\n",
+          "        r_format += 0.1\n",
+          "    if \"```\" in answer or re.search(r\"SELECT|FROM\", answer, re.IGNORECASE):\n",
+          "        r_format += 0.1\n",
+          "\n",
+          "    # ── Stage 2: PARTIAL (max 0.3)\n",
+          "    r_partial = 0.0\n",
+          "    sql_keywords = r\"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN|HAVING\"\n",
+          "    matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))\n",
+          "    r_partial += min(0.15, 0.03 * matches)\n",
+          "    numbers = re.findall(r\"\\d+(?:[.,]\\d+)?\", answer)\n",
+          "    r_partial += min(0.15, 0.03 * len(numbers))\n",
+          "\n",
+          "    # ── Stage 3: TASK (max 0.5)\n",
+          "    r_task = 0.0\n",
+          "    length = len(answer)\n",
+          "    if 50 <= length <= 600:\n",
+          "        r_task += 0.25\n",
+          "    elif length > 0:\n",
+          "        r_task += 0.25 * max(0, 1 - abs(length - 325) / 275)\n",
+          "    explanation_markers = [\"para \", \"porque\", \"resultado\", \"mostra\", \"indica\", \"análise\"]\n",
+          "    expl_matches = sum(1 for w in explanation_markers if w in answer.lower())\n",
+          "    r_task += min(0.25, 0.05 * expl_matches)\n",
+          "\n",
+          "    return min(r_format + r_partial + r_task, 1.0)\n",
+          "\n",
+          "\n",
+          "def reward_insights(completion: str) -> float:\n",
+          "    \"\"\"Staged reward for insights (max 1.0).\"\"\"\n",
+          "    answer = strip_think(completion)\n",
+          "\n",
+          "    # ── Stage 1: FORMAT (max 0.2)\n",
+          "    r_format = 0.0\n",
+          "    if has_think_block(completion):\n",
+          "        r_format += 0.1\n",
+          "    structure_marks = len(re.findall(r\"^[-•*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n",
+          "    r_format += min(0.1, 0.02 * structure_marks)\n",
+          "\n",
+          "    # ── Stage 2: PARTIAL (max 0.3)\n",
+          "    r_partial = 0.0\n",
+          "    length = len(answer)\n",
+          "    if 100 <= length <= 1200:\n",
+          "        r_partial += 0.15\n",
+          "    elif length > 0:\n",
+          "        r_partial += 0.15 * max(0, 1 - abs(length - 650) / 550)\n",
+          "    pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua|cliente|produto\", answer, re.IGNORECASE)\n",
+          "    r_partial += min(0.15, 0.01 * len(pt_markers))\n",
+          "\n",
+          "    # ── Stage 3: TASK (max 0.5)\n",
+          "    r_task = 0.0\n",
+          "    action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n",
+          "                    \"priorizar\", \"investir\", \"otimizar\", \"estratégi\", \"suger\",\n",
+          "                    \"consider\", \"ação\", \"plano\"]\n",
+          "    matches = sum(1 for w in action_words if w in answer.lower())\n",
+          "    r_task += min(0.3, 0.06 * matches)\n",
+          "    data_refs = len(re.findall(r\"\\d+%|R\\$\\s*\\d|média|percentual|comparad|taxa\", answer, re.IGNORECASE))\n",
+          "    r_task += min(0.2, 0.04 * data_refs)\n",
+          "\n",
+          "    return min(r_format + r_partial + r_task, 1.0)\n",
+          "\n",
+          "\n",
+          "def reward_push(completion: str) -> float:\n",
+          "    \"\"\"Staged reward for push notifications (max 1.0).\"\"\"\n",
+          "    answer = strip_think(completion)\n",
+          "    if not answer:\n",
+          "        return 0.0\n",
+          "\n",
+          "    # ── Stage 1: FORMAT (max 0.2)\n",
+          "    r_format = 0.0\n",
+          "    if has_think_block(completion):\n",
+          "        r_format += 0.05\n",
+          "    length = len(answer)\n",
+          "    if length <= 160:\n",
+          "        r_format += 0.15\n",
+          "    elif length <= 300:\n",
+          "        r_format += 0.1\n",
+          "    else:\n",
+          "        r_format += 0.05\n",
+          "\n",
+          "    # ── Stage 2: PARTIAL (max 0.3)\n",
+          "    r_partial = 0.0\n",
+          "    pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua\", answer, re.IGNORECASE)\n",
+          "    r_partial += min(0.15, 0.02 * len(pt_markers))\n",
+          "    if re.search(r\"[!?]|[\\U0001F600-\\U0001F64F]|[\\U0001F300-\\U0001F5FF]\", answer):\n",
+          "        r_partial += 0.05\n",
+          "    if len(answer.split()) >= 5:\n",
+          "        r_partial += 0.1\n",
+          "\n",
+          "    # ── Stage 3: TASK (max 0.5)\n",
+          "    r_task = 0.0\n",
+          "    if length <= 120:\n",
+          "        r_task += 0.25\n",
+          "    else:\n",
+          "        r_task += 0.25 * max(0, 1 - (length - 120) / 120)\n",
+          "    generic_phrases = [\n",
+          "        \"olá! como podemos ajudar\", \"obrigado pela sua compra\",\n",
+          "        \"seu pedido foi confirmado\", \"agradecemos sua preferência\",\n",
+          "    ]\n",
+          "    max_similarity = max(_string_similarity(answer.lower(), g) for g in generic_phrases)\n",
+          "    r_task += 0.25 * (1 - max_similarity)\n",
+          "\n",
+          "    return min(r_format + r_partial + r_task, 1.0)\n",
+          "\n",
+          "\n",
+          "def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:\n",
+          "    \"\"\"\n",
+          "    Master reward function v3: dispatches by task type + zero-advantage noise.\n",
+          "    \"\"\"\n",
+          "    rewards = []\n",
+          "    for completion, prompt in zip(completions, prompts):\n",
+          "        if isinstance(completion, list):\n",
+          "            comp_text = completion[-1][\"content\"] if completion else \"\"\n",
+          "        else:\n",
+          "            comp_text = str(completion)\n",
+          "\n",
+          "        if isinstance(prompt, list):\n",
+          "            prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
+          "        else:\n",
+          "            prompt_text = str(prompt)\n",
+          "\n",
+          "        task = _classify_task_type(prompt_text)\n",
+          "\n",
+          "        if task == \"extraction\":\n",
+          "            rewards.append(reward_extraction(comp_text))\n",
+          "        elif task == \"sql_qa\":\n",
+          "            rewards.append(reward_sql_qa(comp_text))\n",
+          "        elif task == \"insights\":\n",
+          "            rewards.append(reward_insights(comp_text))\n",
+          "        elif task == \"push\":\n",
+          "            rewards.append(reward_push(comp_text))\n",
+          "        else:\n",
+          "            r = 0.15 if has_think_block(comp_text) else 0.0\n",
+          "            r += 0.2 if comp_text.strip() else 0.0\n",
+          "            rewards.append(r)\n",
+          "\n",
+          "    # ── v3: Zero-advantage noise injection ────────────────────────────────────\n",
+          "    if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:\n",
+          "        for i in range(0, len(rewards), NUM_GENERATIONS):\n",
+          "            group = rewards[i:i+NUM_GENERATIONS]\n",
+          "            if len(group) < 2:\n",
+          "                continue\n",
+          "            if max(group) - min(group) < 0.001:\n",
+          "                for j in range(i, min(i+NUM_GENERATIONS, len(rewards))):\n",
+          "                    rewards[j] += random.gauss(0, ZERO_ADV_NOISE_STD)\n",
+          "\n",
+          "    return rewards\n",
+          "\n",
+          "\n",
+          "print(\"✓ v3 Reward functions defined (staged: format → partial → task)\")"
+        ]
+      },
+      {
+        "cell_type": "markdown",
+        "metadata": {},
+        "source": [
+          "---\n",
+          "\n",
+          "## Cell 7: Reward Calibration\n",
+          "\n",
+          "**Gate:** Verify reward variance > 0. Compare v3 scoring to v2 calibration (mean=0.38)."
+        ]
+      },
+      {
+        "cell_type": "code",
+        "execution_count": null,
+        "metadata": {},
+        "outputs": [],
+        "source": [
+          "train_path = DATA_DIR / \"pairs\" / \"train.jsonl\"\n",
+          "\n",
+          "by_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n",
+          "with open(train_path) as f:\n",
+          "    for line in f:\n",
+          "        row = json.loads(line)\n",
+          "        convs = row[\"conversations\"]\n",
+          "        prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
+          "        if not prompt_msgs:\n",
+          "            continue\n",
+          "        user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n",
+          "        task = _classify_task_type(user_text)\n",
+          "        by_type[task].append(prompt_msgs)\n",
+          "\n",
+          "print(f\"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}\")\n",
+          "\n",
+          "rng = random.Random(42)\n",
+          "cal_samples = []\n",
+          "for task_type in [\"extraction\", \"extraction\", \"sql_qa\", \"sql_qa\", \"insights\", \"insights\", \"push\", \"push\"]:\n",
+          "    cal_samples.append(rng.choice(by_type[task_type]))\n",
+          "\n",
+          "FastLanguageModel.for_inference(model)\n",
+          "print(f\"\\nReward calibration v3 ({len(cal_samples)} samples):\")\n",
+          "print(\"-\" * 70)\n",
+          "\n",
+          "cal_rewards = []\n",
+          "for i, msgs in enumerate(cal_samples):\n",
+          "    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
+          "    inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
+          "    outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
+          "    response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
+          "    gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
+          "\n",
+          "    r = commerce_reward_fn([response], [text])[0]\n",
+          "    cal_rewards.append(r)\n",
+          "    hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
+          "    has_answer = \"</think>\" in response\n",
+          "    answer_preview = strip_think(response)[:100] if has_answer else \"[stuck in <think>]\"\n",
+          "    task = _classify_task_type(text)\n",
+          "    print(f\"  [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'⚠️ HIT' if hit_ceiling else 'ok':6s} | {answer_preview}\")\n",
+          "\n",
+          "print(f\"\\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}\")\n",
+          "print(\"v2 calibration was: Mean=0.38, Min=0.02, Max=0.70\")\n",
+          "print(f\"Variance > 0: {len(set(cal_rewards)) > 1}\")"
+        ]
+      },
+      {
+        "cell_type": "markdown",
+        "metadata": {},
+        "source": [
+          "---\n",
+          "\n",
+          "## Cell 8: Dataset Preparation v3"
+        ]
+      },
+      {
+        "cell_type": "code",
+        "execution_count": null,
+        "metadata": {},
+        "outputs": [],
+        "source": [
+          "from datasets import Dataset\n",
+          "\n",
+          "def prepare_grpo_datasets_v3(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO,\n",
+          "                             general_mix=GENERAL_MIX_RATIO, seed=42):\n",
+          "    rng = random.Random(seed)\n",
+          "\n",
+          "    train_pools = {}\n",
+          "    eval_records = []\n",
+          "    for task, pool in by_type.items():\n",
+          "        shuffled = pool.copy()\n",
+          "        rng.shuffle(shuffled)\n",
+          "        n_eval = max(1, int(len(shuffled) * eval_ratio))\n",
+          "        eval_records.extend(shuffled[:n_eval])\n",
+          "        train_pools[task] = shuffled[n_eval:]\n",
+          "\n",
+          "    if n_prompts is None:\n",
+          "        train_records = []\n",
+          "        for task, pool in train_pools.items():\n",
+          "            train_records.extend(pool)\n",
+          "        rng.shuffle(train_records)\n",
+          "    else:\n",
+          "        targets = {\n",
+          "            \"extraction\": int(n_prompts * 0.4),\n",
+          "            \"sql_qa\": int(n_prompts * 0.4),\n",
+          "            \"insights\": int(n_prompts * 0.1),\n",
+          "            \"push\": int(n_prompts * 0.1),\n",
+          "        }\n",
+          "        train_records = []\n",
+          "        for task, target_n in targets.items():\n",
+          "            pool = train_pools[task]\n",
+          "            n = min(target_n, len(pool))\n",
+          "            train_records.extend(rng.sample(pool, n))\n",
+          "        rng.shuffle(train_records)\n",
+          "\n",
+          "    general_path = DATA_DIR / \"pairs\" / \"general_reasoning.jsonl\"\n",
+          "    if general_mix > 0 and general_path.exists():\n",
+          "        general_records = []\n",
+          "        with open(general_path) as f:\n",
+          "            for line in f:\n",
+          "                row = json.loads(line)\n",
+          "                convs = row[\"conversations\"]\n",
+          "                prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
+          "                if prompt_msgs:\n",
+          "                    general_records.append(prompt_msgs)\n",
+          "        n_general = int(len(train_records) * general_mix / (1 - general_mix))\n",
+          "        n_general = min(n_general, len(general_records))\n",
+          "        if n_general > 0:\n",
+          "            train_records.extend(rng.sample(general_records, n_general))\n",
+          "            rng.shuffle(train_records)\n",
+          "            print(f\"  Cocktail Effect: added {n_general} general reasoning samples ({general_mix:.0%} mix)\")\n",
+          "    elif general_mix > 0:\n",
+          "        print(\"  ⚠️ general_reasoning.jsonl not found — skipping mix\")\n",
+          "\n",
+          "    task_dist = {}\n",
+          "    for record in train_records:\n",
+          "        user_text = \" \".join(m[\"content\"] for m in record if m[\"role\"] == \"user\")\n",
+          "        task = _classify_task_type(user_text)\n",
+          "        task_dist[task] = task_dist.get(task, 0) + 1\n",
+          "\n",
+          "    n_domain = len(train_records)\n",
+          "    steps_per_epoch = n_domain * NUM_GENERATIONS // (BATCH_SIZE * GRAD_ACCUM)\n",
+          "\n",
+          "    print(f\"v3 Dataset split (eval_ratio={eval_ratio}):\")\n",
+          "    print(f\"  train : {n_domain} prompts\")\n",
+          "    print(f\"  eval  : {len(eval_records)} prompts\")\n",
+          "    print(f\"  distribution: {', '.join(f'{k}={v}' for k, v in sorted(task_dist.items()))}\")\n",
+          "    print(f\"  steps/epoch: {n_domain} × {NUM_GENERATIONS} / ({BATCH_SIZE} × {GRAD_ACCUM}) = {steps_per_epoch}\")\n",
+          "    print(f\"  MAX_STEPS={MAX_STEPS} → {'< 1 epoch' if MAX_STEPS < steps_per_epoch else f'{MAX_STEPS/steps_per_epoch:.1f} epochs'}\")\n",
+          "\n",
+          "    train_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in train_records])\n",
+          "    eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n",
+          "    return train_ds, eval_ds\n",
+          "\n",
+          "\n",
+          "train_dataset, eval_dataset = prepare_grpo_datasets_v3()\n",
+          "dataset = train_dataset\n",
+          "print(f\"\\n✓ v3 Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}\")"
+        ]
+      },
+      {
+        "cell_type": "markdown",
+        "metadata": {},
+        "source": [
+          "---\n",
+          "\n",
+          "## Cell 9: Smoke Test\n",
+          "\n",
+          "**Gate:** Runs 1 step without OOM at the new completion length (4096)."
+        ]
+      },
+      {
+        "cell_type": "code",
+        "execution_count": null,
+        "metadata": {},
+        "outputs": [],
+        "source": [
+          "from trl import GRPOConfig, GRPOTrainer\n",
+          "\n",
+          "FastLanguageModel.for_training(model)\n",
+          "\n",
+          "smoke_config = GRPOConfig(\n",
+          "    output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n",
+          "    num_generations=NUM_GENERATIONS,\n",
+          "    scale_rewards=SCALE_REWARDS,\n",
+          "    max_completion_length=MAX_COMPLETION_LENGTH,\n",
+          "    max_steps=1,\n",
+          "    num_train_epochs=1,\n",
+          "    temperature=TEMPERATURE,\n",
+          "    per_device_train_batch_size=BATCH_SIZE,\n",
+          "    gradient_accumulation_steps=1,\n",
+          "    learning_rate=LEARNING_RATE,\n",
+          "    fp16=False,\n",
+          "    bf16=True,\n",
+          "    logging_steps=1,\n",
+          "    save_steps=999,\n",
+          "    report_to=\"none\",\n",
+          "    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
+          "    seed=42,\n",
+          "    remove_unused_columns=False,\n",
+          ")\n",
+          "\n",
+          "smoke_trainer = GRPOTrainer(\n",
+          "    model=model,\n",
+          "    reward_funcs=commerce_reward_fn,\n",
+          "    args=smoke_config,\n",
+          "    train_dataset=dataset,\n",
+          "    processing_class=tokenizer,\n",
+          ")\n",
+          "\n",
+          "t0 = time.time()\n",
+          "smoke_trainer.train()\n",
+          "step_time = time.time() - t0\n",
+          "\n",
+          "print(f\"\\n✓ Smoke test passed!\")\n",
+          "print(f\"  Step time (grad_accum=1): {step_time:.0f}s\")\n",
+          "print(f\"  Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s\")\n",
+          "print(f\"  VRAM peak: {torch.cuda.max_memory_allocated()/1e9:.1f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB\")\n",
+          "\n",
+          "vram_used = torch.cuda.max_memory_allocated() / 1e9\n",
+          "vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
+          "if vram_used > vram_total * 0.95:\n",
+          "    print(f\"\\n⚠️ VRAM at {vram_used/vram_total:.0%} — dangerously close to OOM\")\n",
+          "    print(\"  Option 1: Reduce MAX_COMPLETION_LENGTH to 3072\")\n",
+          "    print(\"  Option 2: Reduce BATCH_SIZE to 2 (increase GRAD_ACCUM to 2)\")\n",
+          "\n",
+          "del smoke_trainer\n",
+          "gc.collect(); torch.cuda.empty_cache()"
+        ]
+      },
+      {
+        "cell_type": "markdown",
+        "metadata": {},
+        "source": [
+          "---\n",
+          "\n",
+          "## Cell 10: Probe Run (3 steps)"
+        ]
+      },
+      {
+        "cell_type": "code",
+        "execution_count": null,
+        "metadata": {},
+        "outputs": [],
+        "source": [
+          "FastLanguageModel.for_training(model)\n",
+          "\n",
+          "probe_config = GRPOConfig(\n",
+          "    output_dir=str(CHECKPOINT_DIR / \"probe\"),\n",
+          "    num_generations=NUM_GENERATIONS,\n",
+          "    scale_rewards=SCALE_REWARDS,\n",
+          "    max_completion_length=MAX_COMPLETION_LENGTH,\n",
+          "    max_steps=3,\n",
+          "    temperature=TEMPERATURE,\n",
+          "    num_train_epochs=NUM_EPOCHS,\n",
+          "    per_device_train_batch_size=BATCH_SIZE,\n",
+          "    gradient_accumulation_steps=GRAD_ACCUM,\n",
+          "    learning_rate=LEARNING_RATE,\n",
+          "    warmup_ratio=0.1,\n",
+          "    lr_scheduler_type=\"cosine\",\n",
+          "    fp16=False,\n",
+          "    bf16=True,\n",
+          "    logging_steps=1,\n",
+          "    disable_tqdm=True,\n",
+          "    logging_first_step=True,\n",
+          "    save_steps=999,\n",
+          "    report_to=\"none\",\n",
+          "    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
+          "    seed=42,\n",
+          "    remove_unused_columns=False,\n",
+          ")\n",
+          "\n",
+          "probe_trainer = GRPOTrainer(\n",
+          "    model=model,\n",
+          "    reward_funcs=commerce_reward_fn,\n",
+          "    args=probe_config,\n",
+          "    train_dataset=dataset,\n",
+          "    processing_class=tokenizer,\n",
+          ")\n",
+          "\n",
+          "t0 = time.time()\n",
+          "result = probe_trainer.train()\n",
+          "elapsed = time.time() - t0\n",
+          "\n",
+          "print(f\"\\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)\")\n",
+          "print(f\"  Train loss: {result.training_loss:.6f}\")\n",
+          "print(f\"  Estimated full run ({MAX_STEPS} steps): {elapsed/3 * MAX_STEPS / 3600:.1f}h\")\n",
+          "\n",
+          "if abs(result.training_loss) < 1e-6:\n",
+          "    print(\"  ⚠️ Loss is near-zero — reward variance may be insufficient\")\n",
+          "else:\n",
+          "    print(\"  ✓ Loss is non-zero — GRPO has gradient signal\")\n",
+          "\n",
+          "del probe_trainer\n",
+          "gc.collect(); torch.cuda.empty_cache()"
+        ]
+      },
+      {
+        "cell_type": "markdown",
+        "metadata": {},
+        "source": [
+          "---\n",
+          "\n",
+          "## Cell 11: Full Training Run v3"
+        ]
+      },
+      {
+        "cell_type": "code",
+        "execution_count": null,
+        "metadata": {},
+        "outputs": [],
+        "source": [
+          "import wandb\n",
+          "\n",
+          "_wandb_key = os.environ.get(\"WANDB_API_KEY\", \"\").strip()\n",
+          "if not _wandb_key:\n",
+          "    raise EnvironmentError(\"WANDB_API_KEY is not set.\")\n",
+          "wandb.login(key=_wandb_key, relogin=True)\n",
+          "print(\"✓ W&B authenticated\")"
+        ]
+      },
+      {
+        "cell_type": "code",
+        "execution_count": null,
+        "metadata": {},
+        "outputs": [],
+        "source": [
+          "import shutil\n",
+          "import torch\n",
+          "from transformers import TrainerCallback\n",
+          "from trl import GRPOConfig, GRPOTrainer\n",
+          "\n",
+          "wandb.init(\n",
+          "    project=WANDB_PROJECT,\n",
+          "    name=f\"grpo-v3-l4-{time.strftime('%Y%m%d-%H%M')}\",\n",
+          "    config={\n",
+          "        \"model_id\": MODEL_ID,\n",
+          "        \"version\": \"v3\",\n",
+          "        \"temperature\": TEMPERATURE,\n",
+          "        \"max_completion_length\": MAX_COMPLETION_LENGTH,\n",
+          "        \"num_generations\": NUM_GENERATIONS,\n",
+          "        \"learning_rate\": LEARNING_RATE,\n",
+          "        \"beta\": BETA,\n",
+          "        \"batch_size\": BATCH_SIZE,\n",
+          "        \"grad_accum\": GRAD_ACCUM,\n",
+          "        \"max_steps\": MAX_STEPS,\n",
+          "        \"scale_rewards\": SCALE_REWARDS,\n",
+          "        \"save_steps\": SAVE_STEPS,\n",
+          "        \"eval_steps\": EVAL_STEPS,\n",
+          "        \"eval_max_samples\": EVAL_MAX_SAMPLES,\n",
+          "        \"eval_max_tokens\": EVAL_MAX_TOKENS,\n",
+          "        \"eval_temperature\": EVAL_TEMPERATURE,\n",
+          "        \"patience\": EARLY_STOPPING_PATIENCE,\n",
+          "        \"delta\": EARLY_STOPPING_DELTA,\n",
+          "        \"train_prompts\": len(train_dataset),\n",
+          "        \"eval_prompts\": len(eval_dataset),\n",
+          "        \"zero_adv_noise_std\": ZERO_ADV_NOISE_STD,\n",
+          "        \"general_mix_ratio\": GENERAL_MIX_RATIO,\n",
+          "        \"_ref_temperature\": \"Skywork-OR1 (2505.22312)\",\n",
+          "        \"_ref_completion_length\": \"Dr. GRPO (2503.20783)\",\n",
+          "        \"_ref_staged_rewards\": \"Reasoning-SQL (2503.23157)\",\n",
+          "        \"_ref_zero_adv\": \"Skywork-OR1 (2505.22312)\",\n",
+          "    },\n",
+          ")\n",
+          "print(f\"✓ W&B run: {wandb.run.url}\")\n",
+          "\n",
+          "FRESH = True\n",
+          "resume_from = None\n",
+          "if FRESH and CHECKPOINT_DIR.exists():\n",
+          "    print(\"FRESH: deleting old checkpoints...\")\n",
+          "    shutil.rmtree(CHECKPOINT_DIR)\n",
+          "elif CHECKPOINT_DIR.exists():\n",
+          "    checkpoints = sorted(\n",
+          "        [d for d in CHECKPOINT_DIR.iterdir()\n",
+          "         if d.is_dir() and d.name.startswith(\"checkpoint-\")],\n",
+          "        key=lambda d: int(d.name.split(\"-\")[-1]),\n",
+          "    )\n",
+          "    if checkpoints:\n",
+          "        resume_from = str(checkpoints[-1])\n",
+          "        print(f\"Resuming from: {resume_from}\")\n",
+          "\n",
+          "\n",
+          "class UnslothGRPOTrainer(GRPOTrainer):\n",
+          "    \"\"\"Wraps generation with Unsloth for_inference()/for_training().\"\"\"\n",
+          "    def _generate(self, prompts, images):\n",
+          "        FastLanguageModel.for_inference(self.model)\n",
+          "        try:\n",
+          "            result = super()._generate(prompts, images)\n",
+          "        finally:\n",
+          "            FastLanguageModel.for_training(self.model)\n",
+          "        return result\n",
+          "\n",
+          "\n",
+          "class EvalRewardCallback(TrainerCallback):\n",
+          "    \"\"\"v3: fixed-temperature eval, per-task reward breakdown, patience-based early stopping.\"\"\"\n",
+          "    def __init__(self, eval_records, reward_fn, patience=EARLY_STOPPING_PATIENCE,\n",
+          "                 delta=EARLY_STOPPING_DELTA):\n",
+          "        self.eval_records = eval_records\n",
+          "        self.reward_fn = reward_fn\n",
+          "        self.patience = patience\n",
+          "        self.delta = delta\n",
+          "        self.best_reward = -float(\"inf\")\n",
+          "        self.no_improve_count = 0\n",
+          "\n",
+          "    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n",
+          "        if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n",
+          "            return control\n",
+          "        tokenizer = processing_class\n",
+          "        if tokenizer is None:\n",
+          "            print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n",
+          "            return control\n",
+          "\n",
+          "        mean_reward, task_rewards = self._run_eval(model, tokenizer, args)\n",
+          "        improved = mean_reward > self.best_reward + self.delta\n",
+          "        status = \"↑ improved\" if improved else f\"↔ no gain ({self.no_improve_count + 1}/{self.patience})\"\n",
+          "\n",
+          "        log_dict = {\n",
+          "            \"eval/mean_reward\": mean_reward,\n",
+          "            \"eval/best_reward\": max(self.best_reward, mean_reward),\n",
+          "            \"eval/no_improve_count\": self.no_improve_count,\n",
+          "        }\n",
+          "        for task, rewards in task_rewards.items():\n",
+          "            if rewards:\n",
+          "                log_dict[f\"eval/{task}_reward\"] = sum(rewards) / len(rewards)\n",
+          "        wandb.log(log_dict, step=state.global_step)\n",
+          "\n",
+          "        print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n",
+          "        for task, rewards in task_rewards.items():\n",
+          "            if rewards:\n",
+          "                print(f\"  {task}: {sum(rewards)/len(rewards):.3f} (n={len(rewards)})\")\n",
+          "\n",
+          "        if improved:\n",
+          "            self.best_reward = mean_reward\n",
+          "            self.no_improve_count = 0\n",
+          "        else:\n",
+          "            self.no_improve_count += 1\n",
+          "            if self.no_improve_count >= self.patience:\n",
+          "                print(f\"[EarlyStopping] No improvement ≥ {self.delta} for {self.patience} consecutive evals. Halting.\")\n",
+          "                wandb.log({\"early_stop/step\": state.global_step}, step=state.global_step)\n",
+          "                control.should_training_stop = True\n",
+          "        return control\n",
+          "\n",
+          "    def _run_eval(self, model, tokenizer, args):\n",
+          "        FastLanguageModel.for_inference(model)\n",
+          "        rewards = []\n",
+          "        task_rewards = {}\n",
+          "        subset = self.eval_records[:EVAL_MAX_SAMPLES]\n",
+          "        for record in subset:\n",
+          "            msgs = record[\"prompt\"]\n",
+          "            text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
+          "            inputs = tokenizer(text, return_tensors=\"pt\", truncation=True,\n",
+          "                               max_length=args.max_prompt_length).to(model.device)\n",
+          "            with torch.no_grad():\n",
+          "                out = model.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS,\n",
+          "                                     temperature=EVAL_TEMPERATURE, do_sample=True)\n",
+          "            resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
+          "            r = self.reward_fn([resp], [text])[0]\n",
+          "            rewards.append(r)\n",
+          "            user_text = \" \".join(m.get(\"content\", \"\") for m in msgs if m.get(\"role\") == \"user\")\n",
+          "            task = _classify_task_type(user_text)\n",
+          "            task_rewards.setdefault(task, []).append(r)\n",
+          "        FastLanguageModel.for_training(model)\n",
+          "        mean = sum(rewards) / len(rewards) if rewards else 0.0\n",
+          "        return mean, task_rewards\n",
+          "\n",
+          "\n",
+          "class EntropyMonitorCallback(TrainerCallback):\n",
+          "    \"\"\"v3 NEW: Monitor entropy collapse indicators (Skywork-OR1 §4).\"\"\"\n",
+          "    def __init__(self):\n",
+          "        self.consecutive_ceiling_hits = 0\n",
+          "\n",
+          "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
+          "        if not logs:\n",
+          "            return\n",
+          "        step = state.global_step\n",
+          "        monitor = {}\n",
+          "        comp_len = logs.get(\"completion_length\", 0)\n",
+          "        if comp_len > 0:\n",
+          "            ratio = comp_len / MAX_COMPLETION_LENGTH\n",
+          "            monitor[\"monitor/completion_ratio\"] = ratio\n",
+          "            if ratio > 0.95:\n",
+          "                self.consecutive_ceiling_hits += 1\n",
+          "                if self.consecutive_ceiling_hits >= 3:\n",
+          "                    print(f\"⚠️ Step {step}: Completion ceiling hit {self.consecutive_ceiling_hits} consecutive times.\")\n",
+          "            else:\n",
+          "                self.consecutive_ceiling_hits = 0\n",
+          "        reward_std = logs.get(\"reward_std\", logs.get(\"rewards/commerce_reward_fn/std\"))\n",
+          "        if reward_std is not None:\n",
+          "            monitor[\"monitor/reward_std\"] = reward_std\n",
+          "            if reward_std < 0.01:\n",
+          "                print(f\"⚠️ Step {step}: reward_std={reward_std:.4f} — near-zero variance\")\n",
+          "        clip_high = logs.get(\"clip_ratio/high_mean\")\n",
+          "        clip_low = logs.get(\"clip_ratio/low_mean\")\n",
+          "        if clip_high is not None and clip_low is not None:\n",
+          "            total_clip = clip_high + abs(clip_low)\n",
+          "            monitor[\"monitor/total_clip_ratio\"] = total_clip\n",
+          "            if total_clip > 0.01 and step > 10:\n",
+          "                print(f\"✓ Step {step}: clip_ratio={total_clip:.3f} — policy is updating\")\n",
+          "        if monitor and wandb.run:\n",
+          "            wandb.log(monitor, step=step)\n",
+          "\n",
+          "\n",
+          "FastLanguageModel.for_training(model)\n",
+          "\n",
+          "grpo_config = GRPOConfig(\n",
+          "    output_dir=str(CHECKPOINT_DIR),\n",
+          "    num_generations=NUM_GENERATIONS,\n",
+          "    scale_rewards=SCALE_REWARDS,\n",
+          "    max_completion_length=MAX_COMPLETION_LENGTH,\n",
+          "    temperature=TEMPERATURE,\n",
+          "    max_steps=MAX_STEPS,\n",
+          "    num_train_epochs=NUM_EPOCHS,\n",
+          "    per_device_train_batch_size=BATCH_SIZE,\n",
+          "    gradient_accumulation_steps=GRAD_ACCUM,\n",
+          "    learning_rate=LEARNING_RATE,\n",
+          "    warmup_ratio=0.1,\n",
+          "    lr_scheduler_type=\"cosine\",\n",
+          "    fp16=False,\n",
+          "    bf16=True,\n",
+          "    logging_steps=1,\n",
+          "    logging_first_step=True,\n",
+          "    disable_tqdm=True,\n",
+          "    save_steps=SAVE_STEPS,\n",
+          "    save_total_limit=SAVE_TOTAL_LIMIT,\n",
+          "    save_only_model=True,\n",
+          "    eval_steps=EVAL_STEPS,\n",
+          "    report_to=\"wandb\",\n",
+          "    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
+          "    seed=42,\n",
+          "    remove_unused_columns=False,\n",
+          "    **({\"use_vllm\": True, \"vllm_mode\": \"colocate\",\n",
+          "        \"vllm_enable_sleep_mode\": True} if USE_VLLM else {}),\n",
+          ")\n",
+          "\n",
+          "eval_cb = EvalRewardCallback(eval_records=list(eval_dataset), reward_fn=commerce_reward_fn)\n",
+          "entropy_cb = EntropyMonitorCallback()\n",
+          "\n",
+          "TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer\n",
+          "trainer = TrainerClass(\n",
+          "    model=model,\n",
+          "    reward_funcs=commerce_reward_fn,\n",
+          "    args=grpo_config,\n",
+          "    train_dataset=train_dataset,\n",
+          "    processing_class=tokenizer,\n",
+          "    callbacks=[eval_cb, entropy_cb],\n",
+          ")\n",
+          "\n",
+          "print(f\"{'='*70}\")\n",
+          "print(\"GRPO v3 Training — Ready to Launch\")\n",
+          "print(f\"{'='*70}\")\n",
+          "print(f\"  Trainer: {TrainerClass.__name__}\")\n",
+          "print(f\"  Max steps: {MAX_STEPS}\")\n",
+          "print(f\"  Temperature: {TEMPERATURE} (v2 was 0.8)\")\n",
+          "print(f\"  Completion: {MAX_COMPLETION_LENGTH} tokens (v2 was 2048)\")\n",
+          "print(f\"  Generations: {NUM_GENERATIONS} per prompt (v2 was 8)\")\n",
+          "print(f\"  Learning rate: {LEARNING_RATE} (v2 was 5e-7)\")\n",
+          "print(f\"  Save every: {SAVE_STEPS} steps (keep {SAVE_TOTAL_LIMIT})\")\n",
+          "print(f\"  Eval every: {EVAL_STEPS} steps ({EVAL_MAX_SAMPLES} samples × {EVAL_MAX_TOKENS} tok)\")\n",
+          "print(f\"  Patience: {EARLY_STOPPING_PATIENCE} evals ({EARLY_STOPPING_PATIENCE * EVAL_STEPS} steps)\")\n",
+          "print(f\"  Resume: {resume_from is not None}\")\n",
+          "print(f\"{'='*70}\")\n",
+          "\n",
+          "t_start = time.time()\n",
+          "result = trainer.train(resume_from_checkpoint=resume_from)\n",
+          "elapsed = time.time() - t_start\n",
+          "\n",
+          "wandb.log({\n",
+          "    \"train/final_loss\": result.training_loss,\n",
+          "    \"train/duration_hours\": elapsed / 3600,\n",
+          "    \"train/total_steps\": result.global_step,\n",
+          "    \"eval/best_reward_final\": eval_cb.best_reward,\n",
+          "})\n",
+          "wandb.finish()\n",
+          "\n",
+          "print(f\"\\n{'='*70}\")\n",
+          "print(\"GRPO v3 Training Complete\")\n",
+          "print(f\"  Loss: {result.training_loss:.6f}\")\n",
+          "print(f\"  Steps: {result.global_step}\")\n",
+          "print(f\"  Duration: {elapsed/3600:.1f}h\")\n",
+          "print(f\"  Best eval R: {eval_cb.best_reward:.4f}\")\n",
+          "print(f\"  Trainer: {TrainerClass.__name__}\")\n",
+          "print(f\"{'='*70}\")"
+        ]
+      },
+ {
1371
+ "cell_type": "markdown",
1372
+ "metadata": {},
1373
+ "source": [
1374
+ "---\n",
1375
+ "\n",
1376
+ "## Cell 12: Save Adapter"
1377
+ ]
1378
+ },
1379
+ {
1380
+ "cell_type": "code",
1381
+ "execution_count": null,
1382
+ "metadata": {},
1383
+ "outputs": [],
1384
+ "source": [
1385
+ "GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\n",
1386
+ "model.save_pretrained(str(GRPO_ADAPTER_DIR))\n",
1387
+ "tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))\n",
1388
+ "\n",
1389
+ "summary = {\n",
1390
+ " \"model_id\": MODEL_ID,\n",
1391
+ " \"sft_adapter\": str(SFT_ADAPTER_DIR),\n",
1392
+ " \"method\": \"GRPO\",\n",
1393
+ " \"version\": \"v3\",\n",
1394
+ " \"train_loss\": result.training_loss,\n",
1395
+ " \"best_eval_reward\": eval_cb.best_reward,\n",
1396
+ " \"num_prompts\": len(train_dataset),\n",
1397
+ " \"num_generations\": NUM_GENERATIONS,\n",
1398
+ " \"scale_rewards\": SCALE_REWARDS,\n",
1399
+ " \"temperature\": TEMPERATURE,\n",
1400
+ " \"learning_rate\": LEARNING_RATE,\n",
1401
+ " \"beta\": BETA,\n",
1402
+ " \"max_completion_length\": MAX_COMPLETION_LENGTH,\n",
1403
+ " \"max_steps\": MAX_STEPS,\n",
1404
+ " \"actual_steps\": result.global_step,\n",
1405
+ " \"epochs\": NUM_EPOCHS,\n",
1406
+ " \"max_seq_length\": MAX_SEQ_LENGTH,\n",
1407
+ " \"duration_seconds\": round(elapsed),\n",
1408
+ " \"gpu\": \"L4\",\n",
1409
+ " \"platform\": \"vertex-ai-workbench\",\n",
1410
+ " \"v3_fixes\": [\n",
1411
+ " \"temperature=1.0 (Skywork-OR1)\",\n",
1412
+ " \"max_completion_length=4096 (Dr. GRPO)\",\n",
1413
+ " \"learning_rate=2e-6 (4x v2)\",\n",
1414
+ " \"beta=0.0 (Dr. GRPO)\",\n",
1415
+ " \"staged rewards (Reasoning-SQL)\",\n",
1416
+ " \"zero-advantage noise (Skywork-OR1)\",\n",
1417
+ " \"entropy monitoring callback\",\n",
1418
+ " ],\n",
1419
+ "}\n",
1420
+ "with open(GRPO_ADAPTER_DIR / \"training_summary.json\", \"w\") as f:\n",
1421
+ " json.dump(summary, f, indent=2)\n",
1422
+ "\n",
1423
+ "print(f\"✓ Adapter saved to {GRPO_ADAPTER_DIR}\")\n",
1424
+ "print(f\" Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir() if f.is_file()]}\")"
1425
+ ]
1426
+ },
1427
+ {
1428
+ "cell_type": "markdown",
1429
+ "metadata": {},
1430
+ "source": [
1431
+ "---\n",
1432
+ "\n",
1433
+ "## Cell 13: Validation"
1434
+ ]
1435
+ },
1436
+ {
1437
+ "cell_type": "code",
1438
+ "execution_count": null,
1439
+ "metadata": {},
1440
+ "outputs": [],
1441
+ "source": [
1442
+ "FastLanguageModel.for_inference(model)\n",
1443
+ "\n",
1444
+ "system_msg = {\"role\": \"system\", \"content\": SYSTEM_PT}\n",
1445
+ "\n",
1446
+ "test_prompts = [\n",
1447
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
+ " \"nota=2/5 | status=delivered\\ntítulo: decepcionado\\n\"\n",
+ " \"texto: Produto veio com defeito e o vendedor não respondeu.\\n\\n\"\n",
+ " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
+ " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
+ " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
+ " )},\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
+ " \"nota=5/5 | status=delivered\\ntítulo: adorei!\\n\"\n",
+ " \"texto: Entrega rápida e produto exatamente como descrito. Recomendo!\\n\\n\"\n",
+ " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
+ " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
+ " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
+ " )},\n",
+ " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
+ " {\"role\": \"user\", \"content\": \"Analise a retenção de clientes afetados por product_quality.\"},\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Perfil do cliente:\\n- Estado: MG\\n- Valor do pedido: R$150\\n\"\n",
+ " \"- Reclamação: produto com defeito\\n- Nota: 1.0/5\\n\\n\"\n",
+ " \"Este cliente deve receber uma notificação de reengajamento?\"\n",
+ " )},\n",
+ " {\"role\": \"user\", \"content\": \"Compare a satisfação de clientes em SP vs RJ.\"},\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Crie uma notificação push de reengajamento para um cliente em SP \"\n",
+ " \"que reclamou de atraso na entrega. Nota: 2/5.\"\n",
+ " )},\n",
+ "]\n",
+ "\n",
+ "print(\"=\" * 70)\n",
+ "print(\"GRPO v3 Validation\")\n",
+ "print(\"=\" * 70)\n",
+ "\n",
+ "v3_rewards = []\n",
+ "for i, prompt in enumerate(test_prompts):\n",
+ " messages = [system_msg, prompt]\n",
+ " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
+ "\n",
+ " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)\n",
+ " gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
+ " response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
+ "\n",
+ " reward = commerce_reward_fn([response], [text])[0]\n",
+ " v3_rewards.append(reward)\n",
+ " answer = strip_think(response)\n",
+ " task = _classify_task_type(prompt[\"content\"])\n",
+ " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
+ "\n",
+ " print(f\"\\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---\")\n",
+ " print(f\"Prompt: {prompt['content'][:80]}...\")\n",
+ " print(f\"Answer: {answer[:400]}\")\n",
+ "\n",
+ "print(f\"\\n{'='*70}\")\n",
+        "print(\"v3 Validation Summary\")\n",
+ "print(f\"{'='*70}\")\n",
+ "print(f\" Mean reward: {sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
+ "print(f\" Min: {min(v3_rewards):.3f}\")\n",
+ "print(f\" Max: {max(v3_rewards):.3f}\")\n",
+ "print()\n",
+ "print(f\" Comparison to baselines:\")\n",
+ "print(f\" SFT calibration (Cell 7): mean=0.38\")\n",
+ "print(f\" GRPO v2 validation: mean=0.54\")\n",
+ "print(f\" GRPO v3 validation: mean={sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
+ "v3_vs_v2 = (sum(v3_rewards)/len(v3_rewards) - 0.54) / 0.54 * 100\n",
+ "print(f\" v3 vs v2: {v3_vs_v2:+.1f}%\")"
+ ]
+ }
+ ]
+ }