rtferraz committed on
Commit c9b11b9 · verified · 1 Parent(s): a62f1dc

Upload grpo_vertex_v3.ipynb

Files changed (1)
  1. notebooks/grpo_vertex_v3.ipynb +1674 -0
notebooks/grpo_vertex_v3.ipynb ADDED
@@ -0,0 +1,1674 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Tucano2 Commerce — GRPO Training v3 (Vertex AI Workbench / L4)\n",
8
+ "\n",
9
+ "**v3 changes over v2 — grounded in published research:**\n",
10
+ "\n",
11
+ "| Change | v2 Value | v3 Value | Paper Reference |\n",
12
+ "|--------|----------|----------|----------------|\n",
13
+ "| Temperature | 0.8 | **1.0** | Skywork-OR1 (2505.22312) §4: τ=1.0 gives 5-8% better results, delays entropy collapse |\n",
14
+ "| Completion length | 2048 | **4096** | Dr. GRPO (2503.20783) §3.1: length bias inflates wrong answers → ceiling hit blocks learning |\n",
15
+ "| Num generations | 8 | **4** | VRAM tradeoff: 4×4096 ≈ 8×2048. MC-GRPO (2601.22582): G=4 works with noise mitigation |\n",
16
+ "| Learning rate | 5e-7 | **2e-6** | Dr. GRPO Appendix G: LR=1e-6; Reasoning-SQL: LR=1e-6. v2 clip_ratio=0 → room to push 2-4× |\n",
17
+ "| β (KL penalty) | implicit | **0.0** | Dr. GRPO §3.2: β=0 optimal for rule-based rewards |\n",
18
+ "| Training data | 300 | **ALL (~1400)** | Skywork-OR1 §3.1: small prompt sets → model memorizes → entropy collapse |\n",
19
+ "| Reward functions | single composite | **staged (format→partial→task)** | Reasoning-SQL (2503.23157) §3.2: format rewards converge first, enable task learning |\n",
20
+ "| Zero-advantage groups | included | **filtered with noise injection** | Skywork-OR1 §3.1: zero-std groups destabilize training |\n",
21
+ "| Entropy monitoring | none | **EntropyMonitorCallback** | Skywork-OR1 §4: early detection prevents collapse |\n",
22
+ "| Early stopping patience | 10 | **15** | More runway for longer completions |\n",
23
+ "| Save total limit | 3 | **5** | Keep more checkpoints — v2 lost the best one |\n",
24
+ "| Eval temperature | 0.7 | **0.1** | Deterministic eval = less noisy signal |\n",
25
+ "| General reasoning mix | none | **30% (optional)** | Cocktail Effect (2410.01109): multi-task mix boosts domain performance 2-15% |\n",
26
+ "\n",
27
+ "**Prerequisites:**\n",
28
+ "- Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`\n",
29
+ "- Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`\n",
30
+ "- **NEW:** Optional `data/pairs/general_reasoning.jsonl` for 30% general data mix\n",
31
+ "\n",
32
+ "**Hardware:** L4 (24GB), PyTorch kernel, bf16 supported\n",
33
+ "\n",
34
+ "---\n",
35
+ "\n",
36
+ "## Cell 1: Dependencies\n",
37
+ "\n",
38
+ "Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# Cell 1a — Nuke everything ML-related\n",
48
+ "!pip uninstall -y torch torchvision torchaudio \\\n",
49
+ " unsloth unsloth-zoo \\\n",
50
+ " trl transformers peft accelerate \\\n",
51
+ " bitsandbytes vllm vllm-flash-attn \\\n",
52
+ " datasets tokenizers safetensors huggingface-hub \\\n",
53
+ " wandb xformers triton \\\n",
54
+ " cuda-bindings cuda-python \\\n",
55
+ " sentencepiece protobuf \\\n",
56
+ " 2>/dev/null"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# Cell 1b — Kill any stragglers\n",
66
+ "!pip freeze | grep -iE \"torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate\" | xargs pip uninstall -y 2>/dev/null"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# Cell 1c — Purge cache\n",
76
+ "!pip cache purge"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {},
82
+ "source": [
83
+ "**⚠️ Restart kernel again**, then:"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": [
92
+ "# Cell 1d — Clean install, correct order\n",
93
+ "!pip install \"unsloth\""
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "# Cell 1e — Pin TRL (Unsloth may pull a different version)\n",
103
+ "!pip install \"trl==0.24.0\" --no-deps"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "# Cell 1f — Extra deps\n",
113
+ "!pip install \"rich\" \"wandb\""
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "---\n",
121
+ "\n",
122
+ "## Cell 2: Hello World — GPU + Unsloth Verification"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "import torch\n",
132
+ "\n",
133
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
134
+ "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
135
+ "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\n",
136
+ "print(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n",
137
+ "\n",
138
+ "from unsloth import FastLanguageModel\n",
139
+ "print(\"\\n✓ Unsloth loaded successfully\")\n",
140
+ "\n",
141
+ "import trl\n",
142
+ "print(f\"✓ TRL version: {trl.__version__}\")\n",
143
+ "\n",
144
+ "import transformers\n",
145
+ "print(f\"✓ Transformers version: {transformers.__version__}\")"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "metadata": {},
151
+ "source": [
152
+ "---\n",
153
+ "\n",
154
+ "## Cell 3: Config + Constants"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "id": "20160257",
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "import os\n",
165
+ "os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\n",
166
+ "\n",
167
+ "import json\n",
168
+ "import re\n",
169
+ "import time\n",
170
+ "import random\n",
171
+ "import gc\n",
172
+ "from pathlib import Path\n",
173
+ "\n",
174
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
175
+ "# v3 CONFIG — Every change is annotated with paper reference\n",
176
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
177
+ "\n",
178
+ "MODEL_ID = \"Polygl0t/Tucano2-qwen-3.7B-Think\"\n",
179
+ "MAX_SEQ_LENGTH = 8192 # v3: increased from 4096 — model supports 32k, we need room for 4096 completion + prompt\n",
180
+ "\n",
181
+ "# ── Paths ─────────────────────────────────────────────────────────────────────\n",
182
+ "DATA_DIR = Path(\"/home/jupyter/tucano2/data\")\n",
183
+ "MODELS_DIR = Path(\"/home/jupyter/tucano2/models\")\n",
184
+ "SFT_ADAPTER_DIR = MODELS_DIR / \"tucano2-commerce-sft\"\n",
185
+ "GRPO_ADAPTER_DIR = MODELS_DIR / \"tucano2-commerce-grpo-v3\" # v3: separate dir from v2\n",
186
+ "CHECKPOINT_DIR = GRPO_ADAPTER_DIR / \"checkpoints\"\n",
187
+ "\n",
188
+ "# ── Training data ─────────────────────────────────────────────────────────────\n",
189
+ "GRPO_PROMPTS = None # v3: None = use ALL available prompts (was 300 subset in v2)\n",
190
+ "GENERAL_MIX_RATIO = 0.0 # v3: set to 0.3 if general_reasoning.jsonl exists (Cocktail Effect paper)\n",
191
+ "\n",
192
+ "# ── Valid enums for reward scoring (unchanged from v2) ────────────────────────\n",
193
+ "VALID_SENTIMENTS = {\"positive\", \"negative\", \"neutral\"}\n",
194
+ "VALID_CATEGORIES = {\n",
195
+ " \"delivery_delay\", \"product_quality\", \"product_not_received\",\n",
196
+ " \"wrong_product\", \"seller_communication\", \"app_issue\",\n",
197
+ " \"price_value\", \"other\", \"none\",\n",
198
+ "}\n",
199
+ "VALID_CHURN = {\"low\", \"medium\", \"high\"}\n",
200
+ "VALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\n",
201
+ "EXTRACTION_FIELDS = [\n",
202
+ " \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n",
203
+ " \"product_issue\", \"seller_issue\", \"main_complaint\",\n",
204
+ " \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n",
205
+ "]\n",
206
+ "\n",
207
+ "SYSTEM_PT = (\n",
208
+ " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
209
+ " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\"\n",
210
+ ")\n",
211
+ "\n",
212
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
213
+ "# TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)\n",
214
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
215
+ "\n",
216
+ "# ── Core GRPO params ──────────────────────────────────────────────────────────\n",
217
+ "BATCH_SIZE = 4\n",
218
+ "GRAD_ACCUM = 1 # v3: reduced from 2. Effective batch = 4×1 = 4 (was 8)\n",
219
+ " # With G=4: steps = prompts × 4 / 4 = prompts per epoch\n",
220
+ "NUM_GENERATIONS = 4 # v3: reduced from 8 — VRAM tradeoff for longer completions\n",
221
+ " # MC-GRPO (2601.22582): G=4 works if noise is mitigated\n",
222
+ "SCALE_REWARDS = False # Dr. GRPO (2503.20783): remove std normalization bias\n",
223
+ "\n",
224
+ "# ── v3 CRITICAL FIXES ─────────────────────────────────────────────────────────\n",
225
+ "\n",
226
+ "# FIX 1: Temperature — prevent entropy collapse\n",
227
+ "# v2 had 0.8. All published GRPO papers use 1.0.\n",
228
+ "# Skywork-OR1 (2505.22312) ablation: τ=1.0 vs τ=0.6 → 5-8% better test performance\n",
229
+ "TEMPERATURE = 1.0\n",
230
+ "\n",
231
+ "# FIX 2: Completion length — remove the ceiling\n",
232
+ "# v2: every single completion hit 2048 ceiling. Model couldn't finish reasoning.\n",
233
+ "# Dr. GRPO (2503.20783) §3.1: GRPO length bias inflates wrong answers → hitting the ceiling kills the gradient\n",
234
+ "MAX_COMPLETION_LENGTH = 4096\n",
235
+ "\n",
236
+ "# FIX 3: Learning rate — more aggressive\n",
237
+ "# v2: clip_ratio=0 on all steps → updates were too small to matter\n",
238
+ "# Dr. GRPO Appendix G: LR=1e-6 (constant). Reasoning-SQL: LR=1e-6 with cosine.\n",
239
+ "# We go 2× since v2 showed zero clipping (model can absorb stronger push)\n",
240
+ "LEARNING_RATE = 2e-6\n",
241
+ "\n",
242
+ "# FIX 4: β = 0 (no KL penalty)\n",
243
+ "# Dr. GRPO (2503.20783) §3.2: KL penalty is unnecessary for rule-based rewards\n",
244
+ "# v2 used implicit KL through default β — we explicitly disable it\n",
245
+ "BETA = 0.0\n",
246
+ "\n",
247
+ "# ── Training schedule ─────────────────────────────────────────────────────────\n",
248
+ "NUM_EPOCHS = 1\n",
249
+ "MAX_STEPS = 500 # v3: increased for expanded data; early stopping will halt if needed\n",
250
+ " # With ~1400 prompts × 4 gen / (4 batch × 1 accum) = 1400 steps/epoch\n",
251
+ " # MAX_STEPS=500 < 1 epoch — early stopping or manual extension\n",
252
+ "\n",
253
+ "# ── Checkpoint + Eval + Early-Stop ────────────────────────────────────────────\n",
254
+ "EVAL_SPLIT_RATIO = 0.15\n",
255
+ "EVAL_STEPS = 10\n",
256
+ "EARLY_STOPPING_PATIENCE = 15 # v3: increased from 10 — gives 150 steps of runway\n",
257
+ "EARLY_STOPPING_DELTA = 0.005 # v3: reduced from 0.01 — more sensitive to small gains\n",
258
+ "SAVE_STEPS = 10 # v3: more frequent (was 15) — never lose best checkpoint again\n",
259
+ "SAVE_TOTAL_LIMIT = 5 # v3: keep more checkpoints (was 3 — lost best in v2)\n",
260
+ "WANDB_PROJECT = \"tucano2-commerce\"\n",
261
+ "\n",
262
+ "# ── Eval callback ─────────────────────────────────────────────────────────────\n",
263
+ "EVAL_MAX_SAMPLES = 5\n",
264
+ "EVAL_MAX_TOKENS = 4096 # v3: match training max_completion_length (was 2048)\n",
265
+ "EVAL_TEMPERATURE = 0.1 # v3: deterministic eval for less noisy signal (was 0.7)\n",
266
+ "\n",
267
+ "# ── Backend ───────────────────────────────────────────────────────────────────\n",
268
+ "USE_VLLM = False\n",
269
+ "\n",
270
+ "# ── v3: Zero-advantage noise injection ────────────────────────────────────────\n",
271
+ "# Skywork-OR1 (2505.22312) §3.1: zero-std groups destabilize GRPO training\n",
272
+ "# When all G completions get identical rewards, the advantage is undefined.\n",
273
+ "# Noise injection breaks ties without corrupting the signal.\n",
274
+ "ZERO_ADV_NOISE_STD = 0.005 # Small gaussian noise added to zero-variance groups\n",
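+ "# Illustration only (not the trainer's code): GRPO's per-completion advantage is\n",
+ "# A_i = r_i - mean(r) with scale_rewards=False (Dr. GRPO), or\n",
+ "# A_i = (r_i - mean(r)) / (std(r) + eps) with scale_rewards=True.\n",
+ "# If every r_i in a group is identical, all A_i are exactly 0 and the group\n",
+ "# contributes no gradient; the small noise above breaks such ties.\n",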
275
+ "\n",
276
+ "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
277
+ "\n",
278
+ "# ── Version assertion ─────────────────────────────────────────────────────────\n",
279
+ "import trl as _trl\n",
280
+ "assert _trl.__version__ == \"0.24.0\", (\n",
281
+ " f\"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\\n\"\n",
282
+ " \"Verify that GRPOTrainer._generate() still exists before proceeding.\"\n",
283
+ ")\n",
284
+ "\n",
285
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
286
+ "# W&B EARLY INIT — all cells log here, outputs survive notebook disconnect\n",
287
+ "# Prompts once for API key, then caches in ~/.netrc for future runs.\n",
288
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
289
+ "import wandb\n",
290
+ "\n",
291
+ "wandb.login() # reads from ~/.netrc if cached, prompts interactively if not\n",
292
+ "\n",
293
+ "RUN_NAME = f\"grpo-v3-l4-{time.strftime('%Y%m%d-%H%M')}\"\n",
294
+ "wandb.init(\n",
295
+ " project=WANDB_PROJECT,\n",
296
+ " name=RUN_NAME,\n",
297
+ " config={\n",
298
+ " \"model_id\": MODEL_ID,\n",
299
+ " \"version\": \"v3\",\n",
300
+ " \"temperature\": TEMPERATURE,\n",
301
+ " \"max_completion_length\": MAX_COMPLETION_LENGTH,\n",
302
+ " \"num_generations\": NUM_GENERATIONS,\n",
303
+ " \"learning_rate\": LEARNING_RATE,\n",
304
+ " \"beta\": BETA,\n",
305
+ " \"batch_size\": BATCH_SIZE,\n",
306
+ " \"grad_accum\": GRAD_ACCUM,\n",
307
+ " \"max_steps\": MAX_STEPS,\n",
308
+ " \"scale_rewards\": SCALE_REWARDS,\n",
309
+ " \"save_steps\": SAVE_STEPS,\n",
310
+ " \"eval_steps\": EVAL_STEPS,\n",
311
+ " \"eval_max_samples\": EVAL_MAX_SAMPLES,\n",
312
+ " \"eval_max_tokens\": EVAL_MAX_TOKENS,\n",
313
+ " \"eval_temperature\": EVAL_TEMPERATURE,\n",
314
+ " \"patience\": EARLY_STOPPING_PATIENCE,\n",
315
+ " \"delta\": EARLY_STOPPING_DELTA,\n",
316
+ " \"zero_adv_noise_std\": ZERO_ADV_NOISE_STD,\n",
317
+ " \"general_mix_ratio\": GENERAL_MIX_RATIO,\n",
318
+ " \"_ref_temperature\": \"Skywork-OR1 (2505.22312)\",\n",
319
+ " \"_ref_completion_length\": \"Dr. GRPO (2503.20783)\",\n",
320
+ " \"_ref_staged_rewards\": \"Reasoning-SQL (2503.23157)\",\n",
321
+ " \"_ref_zero_adv\": \"Skywork-OR1 (2505.22312)\",\n",
322
+ " },\n",
323
+ ")\n",
324
+ "\n",
325
+ "print(\"✓ v3 Config loaded\")\n",
326
+ "print(f\" SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})\")\n",
327
+ "print(f\" Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})\")\n",
328
+ "print(f\" Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}\")\n",
329
+ "print(f\" GRPO: G={NUM_GENERATIONS}, temp={TEMPERATURE}, LR={LEARNING_RATE}, β={BETA}\")\n",
330
+ "print(f\" Completion: max={MAX_COMPLETION_LENGTH} (v2 was 2048)\")\n",
331
+ "print(f\" ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}\")\n",
332
+ "print(f\"✓ TRL {_trl.__version__} verified\")\n",
333
+ "print(f\"✓ W&B run: {wandb.run.url}\")\n",
334
+ "\n",
335
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
336
+ "# v3 VRAM BUDGET (L4 24GB)\n",
337
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
338
+ "# Model (NF4): ~3.5 GB\n",
339
+ "# KV Cache (8192 seq): ~3.0 GB\n",
340
+ "# Activations: ~4.0 GB\n",
341
+ "# Optimizer states: ~3.0 GB\n",
342
+ "# Generations (4×4096): ~8.0 GB\n",
343
+ "# ─────────────────────────────────\n",
344
+ "# Estimated total: ~21.5 GB\n",
345
+ "# Headroom: ~2.5 GB\n",
346
+ "#\n",
347
+ "# If OOM: reduce MAX_COMPLETION_LENGTH to 3072 first, then 2560.\n",
348
+ "# Do NOT reduce NUM_GENERATIONS below 4 — GRPO needs variance.\n",
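+ "# (the group mean is the only baseline; with fewer than 4 samples per prompt\n",
+ "# it is too noisy to tell good completions from bad ones)\n",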
349
+ "# ══════════════════════════════════════════════════════════════════════════════"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "markdown",
354
+ "metadata": {},
355
+ "source": [
356
+ "---\n",
357
+ "\n",
358
+ "## Cell 4: Load SFT Adapter"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "print(\"Loading SFT adapter...\")\n",
368
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
369
+ " model_name=str(SFT_ADAPTER_DIR),\n",
370
+ " max_seq_length=MAX_SEQ_LENGTH,\n",
371
+ " load_in_4bit=True,\n",
372
+ " dtype=None,\n",
373
+ ")\n",
374
+ "\n",
375
+ "if tokenizer.pad_token is None:\n",
376
+ " tokenizer.pad_token = tokenizer.eos_token\n",
377
+ "\n",
378
+ "# Load chat template from base model (SFT adapter doesn't save it)\n",
379
+ "from transformers import AutoTokenizer\n",
380
+ "base_tok = AutoTokenizer.from_pretrained(MODEL_ID)\n",
381
+ "tokenizer.chat_template = base_tok.chat_template\n",
382
+ "del base_tok\n",
383
+ "\n",
384
+ "# Carried over from v2: force KV cache — Unsloth patching may reset it\n",
385
+ "model.config.use_cache = True\n",
386
+ "model.generation_config.use_cache = True\n",
387
+ "\n",
388
+ "print(f\"✓ Model loaded on {model.device}\")\n",
389
+ "print(f\" use_cache: {model.config.use_cache}\")\n",
390
+ "print(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n",
391
+ "print(f\" Chat template: {tokenizer.chat_template[:50]}...\")"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "markdown",
396
+ "metadata": {},
397
+ "source": [
398
+ "---\n",
399
+ "\n",
400
+ "## Cell 5: Single Inference Test\n",
401
+ "\n",
402
+ "**Gate:** Does the model close `</think>` and produce an answer within 4096 tokens?"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "FastLanguageModel.for_inference(model)\n",
412
+ "\n",
413
+ "test_msgs = [\n",
414
+ " {\"role\": \"system\", \"content\": SYSTEM_PT},\n",
415
+ " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
416
+ "]\n",
417
+ "text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)\n",
418
+ "inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
419
+ "\n",
420
+ "t0 = time.time()\n",
421
+ "outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
422
+ "elapsed = time.time() - t0\n",
423
+ "\n",
424
+ "response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
425
+ "gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
426
+ "\n",
427
+ "print(f\"Generation time: {elapsed:.1f}s ({gen_tokens} tokens, {gen_tokens/elapsed:.1f} tok/s)\")\n",
428
+ "print(f\"Response length: {len(response)} chars, {gen_tokens} tokens\")\n",
429
+ "print(f\"Hit ceiling: {gen_tokens >= MAX_COMPLETION_LENGTH}\") # v3: should NOT hit ceiling with 4096\n",
430
+ "print(f\"closed_think: {'</think>' in response}\")\n",
431
+ "print(f\"\\n{'='*60}\")\n",
432
+ "print(response[:800])\n",
433
+ "\n",
434
+ "# ── Log to W&B ───────────────────────────────────────────────────────────────\n",
435
+ "wandb.log({\n",
436
+ " \"preflight/inference_time_s\": elapsed,\n",
437
+ " \"preflight/inference_tokens\": gen_tokens,\n",
438
+ " \"preflight/tok_per_s\": gen_tokens / elapsed,\n",
439
+ " \"preflight/hit_ceiling\": int(gen_tokens >= MAX_COMPLETION_LENGTH),\n",
440
+ " \"preflight/closed_think\": int(\"</think>\" in response),\n",
441
+ "})\n",
442
+ "wandb.summary[\"preflight/inference_response_preview\"] = response[:500]"
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "markdown",
447
+ "metadata": {},
448
+ "source": [
449
+ "---\n",
450
+ "\n",
451
+ "## Cell 5b: KV Cache Diagnostic"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": null,
457
+ "metadata": {},
458
+ "outputs": [],
459
+ "source": [
460
+ "import time\n",
461
+ "FastLanguageModel.for_inference(model)\n",
462
+ "\n",
463
+ "_kv_msgs = [{\"role\": \"system\", \"content\": SYSTEM_PT},\n",
464
+ " {\"role\": \"user\", \"content\": \"Qual a categoria de reclamação mais frequente?\"}]\n",
465
+ "_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n",
466
+ "_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n",
467
+ "\n",
468
+ "_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\n",
469
+ "with torch.no_grad():\n",
470
+ " for _step in range(50):\n",
471
+ " _t0 = time.time()\n",
472
+ " seq_len = _generated.shape[1]\n",
473
+ " if _past is None:\n",
474
+ " _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n",
475
+ " else:\n",
476
+ " _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n",
477
+ " _out = model(\n",
478
+ " input_ids=_generated[:, -1:] if _past else _generated,\n",
479
+ " position_ids=_position_ids,\n",
480
+ " attention_mask=torch.ones(1, seq_len, device=model.device),\n",
481
+ " past_key_values=_past,\n",
482
+ " use_cache=True,\n",
483
+ " return_dict=True,\n",
484
+ " )\n",
485
+ " _past = _out.past_key_values\n",
486
+ " _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n",
487
+ " _generated = torch.cat([_generated, _next], dim=1)\n",
488
+ " _token_times.append(time.time() - _t0)\n",
489
+ "\n",
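+ "# With a working KV cache, per-token latency stays roughly flat; without it each\n",
+ "# step re-encodes the entire prefix, so late tokens come out much slower than\n",
+ "# early ones. The last/first ratio below makes that visible.\n",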
490
+ "_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\n",
491
+ "print(f\"First 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\n",
492
+ "print(f\"Last 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\n",
493
+ "print(f\"Ratio last/first: {_ratio:.1f}x\")\n",
494
+ "if _ratio < 3:\n",
495
+ " print(\"✓ KV cache is working correctly\")\n",
496
+ "elif _ratio < 6:\n",
497
+ " print(\"⚠ KV cache may be degraded — check model.config.use_cache\")\n",
498
+ "else:\n",
499
+ " print(\"✗ KV cache BROKEN — GRPO generation will be catastrophically slow.\")\n",
500
+ "\n",
501
+ "# ── Log to W&B ────────────────────────────────────────────────────────────────\n",
502
+ "wandb.log({\n",
503
+ " \"preflight/kv_cache_ratio\": _ratio,\n",
504
+ " \"preflight/kv_cache_ok\": int(_ratio < 3),\n",
505
+ "})\n",
506
+ "\n",
507
+ "del _past, _generated, _kv_inputs, _token_times, _out\n",
508
+ "gc.collect()\n",
509
+ "if torch.cuda.is_available(): torch.cuda.empty_cache()"
510
+ ]
511
+ },
512
+ {
513
+ "cell_type": "markdown",
514
+ "metadata": {},
515
+ "source": [
516
+ "---\n",
517
+ "\n",
518
+ "## Cell 6: Reward Functions v3\n",
519
+ "\n",
520
+ "**v3 changes:**\n",
521
+ "- Staged reward design: format → partial content → full task (Reasoning-SQL, 2503.23157)\n",
522
+ "- Zero-advantage noise injection (Skywork-OR1, 2505.22312)\n",
523
+ "- Extraction reward redesigned for completion-length-friendly scoring"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "code",
528
+ "execution_count": null,
529
+ "metadata": {},
530
+ "outputs": [],
531
+ "source": [
532
+ "def strip_think(text: str) -> str:\n",
533
+ " \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n",
534
+ " return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n",
535
+ "\n",
536
+ "\n",
537
+ "def has_think_block(text: str) -> bool:\n",
538
+ " \"\"\"Check if text contains a non-empty <think> block.\"\"\"\n",
539
+ " return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n",
540
+ "\n",
541
+ "\n",
542
+ "def _classify_task_type(prompt_text: str) -> str:\n",
543
+ " \"\"\"Classify prompt into task type by keywords.\"\"\"\n",
544
+ " p = prompt_text.lower()\n",
545
+ " if \"retorne um objeto json\" in p or \"extraia dados\" in p:\n",
546
+ " return \"extraction\"\n",
547
+ " elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
548
+ " return \"push\"\n",
549
+ " elif \"perfil do cliente\" in p:\n",
550
+ " return \"insights\"\n",
551
+ " else:\n",
552
+ " return \"sql_qa\"\n",
553
+ "\n",
554
+ "\n",
555
+ "def _json_similarity(text: str) -> float:\n",
556
+ " \"\"\"Rough heuristic: how JSON-like is this text? 0.0 to 1.0.\"\"\"\n",
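+ " # e.g. '{\"a\": 1, \"b\": 2}' scores 0.5 + 0.2 + 0.2 + 0.1 = 1.0; plain prose scores 0.0\n",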
557
+ " text = text.strip()\n",
558
+ " if not text:\n",
559
+ " return 0.0\n",
560
+ " score = 0.0\n",
561
+ " if text.startswith(\"{\") and text.endswith(\"}\"):\n",
562
+ " score += 0.5\n",
563
+ " if '\"' in text:\n",
564
+ " score += 0.2\n",
565
+ " if \":\" in text:\n",
566
+ " score += 0.2\n",
567
+ " if \",\" in text:\n",
568
+ " score += 0.1\n",
569
+ " return min(score, 1.0)\n",
570
+ "\n",
571
+ "\n",
572
+ "def _string_similarity(a: str, b: str) -> float:\n",
573
+ " \"\"\"Simple Jaccard-like similarity for short strings. 0.0 to 1.0.\"\"\"\n",
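+ " # e.g. ('entrega atrasada', 'entrega rápida') share 1 of 3 tokens → 0.33\n",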
574
+ " if not a or not b:\n",
575
+ " return 0.0\n",
576
+ " a_set = set(a.split())\n",
577
+ " b_set = set(b.split())\n",
578
+ " intersection = len(a_set & b_set)\n",
579
+ " union = len(a_set | b_set)\n",
580
+ " return intersection / union if union > 0 else 0.0\n",
581
+ "\n",
582
+ "\n",
583
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
584
+ "# v3 STAGED REWARD DESIGN\n",
585
+ "# Reference: Reasoning-SQL (2503.23157) §3.2\n",
586
+ "#\n",
587
+ "# Each reward function scores THREE stages independently:\n",
588
+ "# Stage 1 — FORMAT (0.0–0.2): Is the output well-structured?\n",
589
+ "# Stage 2 — PARTIAL (0.0–0.3): Are some content elements correct?\n",
590
+ "# Stage 3 — TASK (0.0–0.5): Is the full task completed correctly?\n",
591
+ "#\n",
592
+ "# Format rewards converge first (easy to learn), which stabilizes training\n",
593
+ "# and enables the model to then learn harder task-specific skills.\n",
594
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
595
+ "\n",
596
+ "\n",
597
+ "def reward_extraction(completion: str) -> float:\n",
598
+ " \"\"\"Staged reward for structured extraction (max 1.0).\"\"\"\n",
599
+ " answer = strip_think(completion)\n",
600
+ "\n",
601
+ " # ── Stage 1: FORMAT (max 0.2) ─────────────────────────────────────────────\n",
602
+ " r_format = 0.0\n",
603
+ " if has_think_block(completion):\n",
604
+ " r_format += 0.1 # Used reasoning\n",
605
+ "\n",
606
+ " try:\n",
607
+ " data = json.loads(answer)\n",
608
+ " if isinstance(data, dict):\n",
609
+ " r_format += 0.1 # Valid JSON object\n",
610
+ " except (json.JSONDecodeError, TypeError):\n",
611
+ " r_format += 0.05 * _json_similarity(answer)\n",
612
+ " return min(r_format, 0.2)\n",
613
+ "\n",
614
+ " if not isinstance(data, dict):\n",
615
+ " return min(r_format, 0.2)\n",
616
+ "\n",
617
+ " # ── Stage 2: PARTIAL CONTENT (max 0.3) ────────────────────────────────────\n",
618
+ " r_partial = 0.0\n",
619
+ "\n",
620
+ " present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n",
621
+ " r_partial += 0.15 * (present / len(EXTRACTION_FIELDS))\n",
622
+ "\n",
623
+ " type_checks = 0\n",
624
+ " type_total = 0\n",
625
+ " for field in EXTRACTION_FIELDS:\n",
626
+ " if field not in data:\n",
627
+ " continue\n",
628
+ " type_total += 1\n",
629
+ " val = data[field]\n",
630
+ " if field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
631
+ " if isinstance(val, bool):\n",
632
+ " type_checks += 1\n",
633
+ " elif field in (\"sentiment_score\",):\n",
634
+ " if isinstance(val, (int, float)):\n",
635
+ " type_checks += 1\n",
636
+ " elif field in (\"main_complaint\", \"sentiment\", \"complaint_category\", \"churn_risk\", \"repeat_intent\"):\n",
637
+ " if isinstance(val, str):\n",
638
+ " type_checks += 1\n",
639
+ " if type_total > 0:\n",
640
+ " r_partial += 0.15 * (type_checks / type_total)\n",
641
+ "\n",
642
+ " # ── Stage 3: FULL TASK (max 0.5) ─────────────────────────────────────────\n",
643
+ " r_task = 0.0\n",
644
+ " cat_checks = 0\n",
645
+ " cat_total = 0\n",
646
+ "\n",
647
+ " checks = [\n",
648
+ " (\"sentiment\", lambda v: v in VALID_SENTIMENTS),\n",
649
+ " (\"complaint_category\", lambda v: v in VALID_CATEGORIES),\n",
650
+ " (\"churn_risk\", lambda v: v in VALID_CHURN),\n",
651
+ " (\"repeat_intent\", lambda v: v in VALID_REPEAT),\n",
652
+ " (\"sentiment_score\", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),\n",
653
+ " ]\n",
654
+ " for field, validator in checks:\n",
655
+ " cat_total += 1\n",
656
+ " if field in data and validator(data[field]):\n",
657
+ " cat_checks += 1\n",
658
+ "\n",
659
+ " for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
660
+ " cat_total += 1\n",
661
+ " if bool_field in data and isinstance(data[bool_field], bool):\n",
662
+ " cat_checks += 1\n",
663
+ "\n",
664
+ " if cat_total > 0:\n",
665
+ " r_task += 0.35 * (cat_checks / cat_total)\n",
666
+ "\n",
667
+ " if \"main_complaint\" in data and isinstance(data[\"main_complaint\"], str):\n",
668
+ " complaint = data[\"main_complaint\"].strip()\n",
669
+ " if len(complaint) > 10:\n",
670
+ " r_task += 0.15\n",
671
+ "\n",
672
+ " return min(r_format + r_partial + r_task, 1.0)\n",
673
+ "\n",
674
+ "\n",
675
+ "def reward_sql_qa(completion: str) -> float:\n",
676
+ " \"\"\"Staged reward for SQL Q&A (max 1.0).\"\"\"\n",
677
+ " answer = strip_think(completion)\n",
678
+ "\n",
679
+ " # ── Stage 1: FORMAT (max 0.2)\n",
680
+ " r_format = 0.0\n",
681
+ " if has_think_block(completion):\n",
682
+ " r_format += 0.1\n",
683
+ " if \"```\" in answer or re.search(r\"SELECT|FROM\", answer, re.IGNORECASE):\n",
684
+ " r_format += 0.1\n",
685
+ "\n",
686
+ " # ── Stage 2: PARTIAL (max 0.3)\n",
687
+ " r_partial = 0.0\n",
688
+ " sql_keywords = r\"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN|HAVING\"\n",
689
+ " matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))\n",
690
+ " r_partial += min(0.15, 0.03 * matches)\n",
691
+ " numbers = re.findall(r\"\\d+(?:[.,]\\d+)?\", answer)\n",
692
+ " r_partial += min(0.15, 0.03 * len(numbers))\n",
693
+ "\n",
694
+ " # ── Stage 3: TASK (max 0.5)\n",
695
+ " r_task = 0.0\n",
696
+ " length = len(answer)\n",
697
+ " if 50 <= length <= 600:\n",
698
+ " r_task += 0.25\n",
699
+ " elif length > 0:\n",
700
+ " r_task += 0.25 * max(0, 1 - abs(length - 325) / 275)\n",
701
+ " explanation_markers = [\"para \", \"porque\", \"resultado\", \"mostra\", \"indica\", \"análise\"]\n",
702
+ " expl_matches = sum(1 for w in explanation_markers if w in answer.lower())\n",
703
+ " r_task += min(0.25, 0.05 * expl_matches)\n",
704
+ "\n",
705
+ " return min(r_format + r_partial + r_task, 1.0)\n",
706
+ "\n",
707
+ "\n",
708
+ "def reward_insights(completion: str) -> float:\n",
709
+ " \"\"\"Staged reward for insights (max 1.0).\"\"\"\n",
710
+ " answer = strip_think(completion)\n",
711
+ "\n",
712
+ " # ── Stage 1: FORMAT (max 0.2)\n",
713
+ " r_format = 0.0\n",
714
+ " if has_think_block(completion):\n",
715
+ " r_format += 0.1\n",
716
+ " structure_marks = len(re.findall(r\"^[-•*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n",
717
+ " r_format += min(0.1, 0.02 * structure_marks)\n",
718
+ "\n",
719
+ " # ── Stage 2: PARTIAL (max 0.3)\n",
720
+ " r_partial = 0.0\n",
721
+ " length = len(answer)\n",
722
+ " if 100 <= length <= 1200:\n",
723
+ " r_partial += 0.15\n",
724
+ " elif length > 0:\n",
725
+ " r_partial += 0.15 * max(0, 1 - abs(length - 650) / 550)\n",
726
+ " pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua|cliente|produto\", answer, re.IGNORECASE)\n",
727
+ " r_partial += min(0.15, 0.01 * len(pt_markers))\n",
728
+ "\n",
729
+ " # ── Stage 3: TASK (max 0.5)\n",
730
+ " r_task = 0.0\n",
731
+ " action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n",
732
+ " \"priorizar\", \"investir\", \"otimizar\", \"estratégi\", \"suger\",\n",
733
+ " \"consider\", \"ação\", \"plano\"]\n",
734
+ " matches = sum(1 for w in action_words if w in answer.lower())\n",
735
+ " r_task += min(0.3, 0.06 * matches)\n",
736
+ " data_refs = len(re.findall(r\"\\d+%|R\\$\\s*\\d|média|percentual|comparad|taxa\", answer, re.IGNORECASE))\n",
737
+ " r_task += min(0.2, 0.04 * data_refs)\n",
738
+ "\n",
739
+ " return min(r_format + r_partial + r_task, 1.0)\n",
740
+ "\n",
741
+ "\n",
742
+ "def reward_push(completion: str) -> float:\n",
743
+ " \"\"\"Staged reward for push notifications (max 1.0).\"\"\"\n",
744
+ " answer = strip_think(completion)\n",
745
+ " if not answer:\n",
746
+ " return 0.0\n",
747
+ "\n",
748
+ " # ── Stage 1: FORMAT (max 0.2)\n",
749
+ " r_format = 0.0\n",
750
+ " if has_think_block(completion):\n",
751
+ " r_format += 0.05\n",
752
+ " length = len(answer)\n",
753
+ " if length <= 160:\n",
754
+ " r_format += 0.15\n",
755
+ " elif length <= 300:\n",
756
+ " r_format += 0.1\n",
757
+ " else:\n",
758
+ " r_format += 0.05\n",
759
+ "\n",
760
+ " # ── Stage 2: PARTIAL (max 0.3)\n",
761
+ " r_partial = 0.0\n",
762
+ " pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua\", answer, re.IGNORECASE)\n",
763
+ " r_partial += min(0.15, 0.02 * len(pt_markers))\n",
764
+ " if re.search(r\"[!?]|[\\U0001F600-\\U0001F64F]|[\\U0001F300-\\U0001F5FF]\", answer):\n",
765
+ " r_partial += 0.05\n",
766
+ " if len(answer.split()) >= 5:\n",
767
+ " r_partial += 0.1\n",
768
+ "\n",
769
+ " # ── Stage 3: TASK (max 0.5)\n",
770
+ " r_task = 0.0\n",
771
+ " if length <= 120:\n",
772
+ " r_task += 0.25\n",
773
+ " else:\n",
774
+ " r_task += 0.25 * max(0, 1 - (length - 120) / 120)\n",
775
+ " generic_phrases = [\n",
776
+ " \"olá! como podemos ajudar\", \"obrigado pela sua compra\",\n",
777
+ " \"seu pedido foi confirmado\", \"agradecemos sua preferência\",\n",
778
+ " ]\n",
779
+ " max_similarity = max(_string_similarity(answer.lower(), g) for g in generic_phrases)\n",
780
+ " r_task += 0.25 * (1 - max_similarity)\n",
781
+ "\n",
782
+ " return min(r_format + r_partial + r_task, 1.0)\n",
783
+ "\n",
784
+ "\n",
785
+ "def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:\n",
786
+ " \"\"\"\n",
787
+ " Master reward function v3: dispatches by task type + zero-advantage noise.\n",
788
+ " \"\"\"\n",
789
+ " rewards = []\n",
790
+ " for completion, prompt in zip(completions, prompts):\n",
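+ " # TRL may pass conversational completions as message lists and plain ones\n",
+ " # as strings; accept both (same handling for prompts below).\n",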
791
+ " if isinstance(completion, list):\n",
792
+ " comp_text = completion[-1][\"content\"] if completion else \"\"\n",
793
+ " else:\n",
794
+ " comp_text = str(completion)\n",
795
+ "\n",
796
+ " if isinstance(prompt, list):\n",
797
+ " prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
798
+ " else:\n",
799
+ " prompt_text = str(prompt)\n",
800
+ "\n",
801
+ " task = _classify_task_type(prompt_text)\n",
802
+ "\n",
803
+ " if task == \"extraction\":\n",
804
+ " rewards.append(reward_extraction(comp_text))\n",
805
+ " elif task == \"sql_qa\":\n",
806
+ " rewards.append(reward_sql_qa(comp_text))\n",
807
+ " elif task == \"insights\":\n",
808
+ " rewards.append(reward_insights(comp_text))\n",
809
+ " elif task == \"push\":\n",
810
+ " rewards.append(reward_push(comp_text))\n",
811
+ " else:\n",
812
+ " r = 0.15 if has_think_block(comp_text) else 0.0\n",
813
+ " r += 0.2 if comp_text.strip() else 0.0\n",
814
+ " rewards.append(r)\n",
815
+ "\n",
816
+ " # ── v3: Zero-advantage noise injection ────────────────────────────────────\n",
817
+ " if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:\n",
818
+ " for i in range(0, len(rewards), NUM_GENERATIONS):\n",
819
+ " group = rewards[i:i+NUM_GENERATIONS]\n",
820
+ " if len(group) < 2:\n",
821
+ " continue\n",
822
+ " if max(group) - min(group) < 0.001:\n",
823
+ " for j in range(i, min(i+NUM_GENERATIONS, len(rewards))):\n",
824
+ " rewards[j] += random.gauss(0, ZERO_ADV_NOISE_STD)\n",
825
+ "\n",
826
+ " return rewards\n",
827
+ "\n",
828
+ "\n",
829
+ "print(\"✓ v3 Reward functions defined (staged: format → partial → task)\")"
830
+ ]
831
+ },
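+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cell 6b: Reward Sanity Check (optional)\n",
+ "\n",
+ "A minimal sketch on toy strings (hypothetical, not drawn from `train.jsonl`): a structurally valid extraction should outscore a bare sentence even before task accuracy is judged, because the format stage pays out first."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy completions (hypothetical): one well-formed, one bare prose.\n",
+ "_toy_good = (\n",
+ " \"<think>entrega atrasou, cliente irritado</think>\"\n",
+ " '{\"sentiment\": \"negative\", \"sentiment_score\": 2, \"churn_risk\": \"high\",'\n",
+ " ' \"delivery_issue\": true, \"product_issue\": false, \"seller_issue\": false,'\n",
+ " ' \"main_complaint\": \"produto chegou com muito atraso\",'\n",
+ " ' \"complaint_category\": \"delivery_delay\", \"repeat_intent\": \"no\",'\n",
+ " ' \"would_recommend\": false}'\n",
+ ")\n",
+ "_toy_bad = \"O cliente parece insatisfeito.\"\n",
+ "print(f\"good={reward_extraction(_toy_good):.2f} bad={reward_extraction(_toy_bad):.2f}\")\n",
+ "assert reward_extraction(_toy_good) > reward_extraction(_toy_bad)"
+ ]
+ },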
832
+ {
833
+ "cell_type": "markdown",
834
+ "metadata": {},
835
+ "source": [
836
+ "---\n",
837
+ "\n",
838
+ "## Cell 7: Reward Calibration\n",
839
+ "\n",
840
+ "**Gate:** Verify reward variance > 0. Compare v3 scoring to v2 calibration (mean=0.38)."
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": null,
846
+ "metadata": {},
847
+ "outputs": [],
848
+ "source": [
849
+ "train_path = DATA_DIR / \"pairs\" / \"train.jsonl\"\n",
850
+ "\n",
851
+ "by_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n",
852
+ "with open(train_path) as f:\n",
853
+ " for line in f:\n",
854
+ " row = json.loads(line)\n",
855
+ " convs = row[\"conversations\"]\n",
856
+ " prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
857
+ " if not prompt_msgs:\n",
858
+ " continue\n",
859
+ " user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n",
860
+ " task = _classify_task_type(user_text)\n",
861
+ " by_type[task].append(prompt_msgs)\n",
862
+ "\n",
863
+ "print(f\"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}\")\n",
864
+ "\n",
865
+ "rng = random.Random(42)\n",
866
+ "cal_samples = []\n",
867
+ "for task_type in [\"extraction\", \"extraction\", \"sql_qa\", \"sql_qa\", \"insights\", \"insights\", \"push\", \"push\"]:\n",
868
+ " cal_samples.append(rng.choice(by_type[task_type]))\n",
869
+ "\n",
870
+ "FastLanguageModel.for_inference(model)\n",
871
+ "print(f\"\\nReward calibration v3 ({len(cal_samples)} samples):\")\n",
872
+ "print(\"-\" * 70)\n",
873
+ "\n",
874
+ "cal_rewards = []\n",
875
+ "cal_rows = [] # collect per-sample data for W&B Table\n",
876
+ "for i, msgs in enumerate(cal_samples):\n",
877
+ " text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
878
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
879
+ " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
880
+ " response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
881
+ " gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
882
+ "\n",
883
+ " r = commerce_reward_fn([response], [text])[0]\n",
884
+ " cal_rewards.append(r)\n",
885
+ " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
886
+ " has_answer = \"</think>\" in response\n",
887
+ " answer_preview = strip_think(response)[:100] if has_answer else \"[stuck in <think>]\"\n",
888
+ " task = _classify_task_type(text)\n",
889
+ " print(f\" [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'HIT' if hit_ceiling else 'ok':6s} | {answer_preview}\")\n",
890
+ "\n",
891
+ " cal_rows.append([i, task, r, gen_tokens, hit_ceiling, has_answer, answer_preview])\n",
892
+ "\n",
893
+ "print(f\"\\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}\")\n",
894
+ "print(f\"v2 calibration was: Mean=0.38, Min=0.02, Max=0.70\")\n",
895
+ "print(f\"Variance > 0: {len(set(cal_rewards)) > 1}\")\n",
896
+ "\n",
897
+ "# ── Log to W&B ───────────────────────────────────────────────────────────────\n",
898
+ "cal_table = wandb.Table(\n",
899
+ " columns=[\"sample\", \"task\", \"reward\", \"tokens\", \"hit_ceiling\", \"closed_think\", \"answer_preview\"],\n",
900
+ " data=cal_rows,\n",
901
+ ")\n",
902
+ "wandb.log({\n",
903
+ " \"calibration/mean_reward\": sum(cal_rewards) / len(cal_rewards),\n",
904
+ " \"calibration/min_reward\": min(cal_rewards),\n",
905
+ " \"calibration/max_reward\": max(cal_rewards),\n",
906
+ " \"calibration/has_variance\": int(len(set(cal_rewards)) > 1),\n",
907
+ " \"calibration/samples\": cal_table,\n",
908
+ "})\n",
909
+ "wandb.config.update({\"prompts_by_type\": {k: len(v) for k, v in by_type.items()}}, allow_val_change=True)"
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "markdown",
914
+ "metadata": {},
915
+ "source": [
916
+ "---\n",
917
+ "\n",
918
+ "## Cell 8: Dataset Preparation v3"
919
+ ]
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": null,
924
+ "metadata": {},
925
+ "outputs": [],
926
+ "source": [
927
+ "from datasets import Dataset\n",
928
+ "\n",
929
+ "def prepare_grpo_datasets_v3(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO,\n",
930
+ " general_mix=GENERAL_MIX_RATIO, seed=42):\n",
931
+ " rng = random.Random(seed)\n",
932
+ "\n",
933
+ " train_pools = {}\n",
934
+ " eval_records = []\n",
935
+ " for task, pool in by_type.items():\n",
936
+ " shuffled = pool.copy()\n",
937
+ " rng.shuffle(shuffled)\n",
938
+ " n_eval = max(1, int(len(shuffled) * eval_ratio))\n",
939
+ " eval_records.extend(shuffled[:n_eval])\n",
940
+ " train_pools[task] = shuffled[n_eval:]\n",
941
+ "\n",
942
+ " if n_prompts is None:\n",
943
+ " train_records = []\n",
944
+ " for task, pool in train_pools.items():\n",
945
+ " train_records.extend(pool)\n",
946
+ " rng.shuffle(train_records)\n",
947
+ " else:\n",
948
+ " targets = {\n",
949
+ " \"extraction\": int(n_prompts * 0.4),\n",
950
+ " \"sql_qa\": int(n_prompts * 0.4),\n",
951
+ " \"insights\": int(n_prompts * 0.1),\n",
952
+ " \"push\": int(n_prompts * 0.1),\n",
953
+ " }\n",
954
+ " train_records = []\n",
955
+ " for task, target_n in targets.items():\n",
956
+ " pool = train_pools[task]\n",
957
+ " n = min(target_n, len(pool))\n",
958
+ " train_records.extend(rng.sample(pool, n))\n",
959
+ " rng.shuffle(train_records)\n",
960
+ "\n",
961
+ " general_path = DATA_DIR / \"pairs\" / \"general_reasoning.jsonl\"\n",
962
+ " if general_mix > 0 and general_path.exists():\n",
963
+ " general_records = []\n",
964
+ " with open(general_path) as f:\n",
965
+ " for line in f:\n",
966
+ " row = json.loads(line)\n",
967
+ " convs = row[\"conversations\"]\n",
968
+ " prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
969
+ " if prompt_msgs:\n",
970
+ " general_records.append(prompt_msgs)\n",
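+ " # With d domain prompts, adding g general ones gives a general share of\n",
+ " # g / (d + g); solving g / (d + g) = mix yields g = d * mix / (1 - mix).\n",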
971
+ " n_general = int(len(train_records) * general_mix / (1 - general_mix))\n",
972
+ " n_general = min(n_general, len(general_records))\n",
973
+ " if n_general > 0:\n",
974
+ " train_records.extend(rng.sample(general_records, n_general))\n",
975
+ " rng.shuffle(train_records)\n",
976
+ " print(f\" Cocktail Effect: added {n_general} general reasoning samples ({general_mix:.0%} mix)\")\n",
977
+ " elif general_mix > 0:\n",
978
+ " print(f\" general_reasoning.jsonl not found — skipping mix\")\n",
979
+ "\n",
980
+ " task_dist = {}\n",
981
+ " for record in train_records:\n",
982
+ " user_text = \" \".join(m[\"content\"] for m in record if m[\"role\"] == \"user\")\n",
983
+ " task = _classify_task_type(user_text)\n",
984
+ " task_dist[task] = task_dist.get(task, 0) + 1\n",
985
+ "\n",
986
+ " n_domain = len(train_records)\n",
987
+ " steps_per_epoch = n_domain * NUM_GENERATIONS // (BATCH_SIZE * GRAD_ACCUM)\n",
988
+ "\n",
989
+ " print(f\"v3 Dataset split (eval_ratio={eval_ratio}):\")\n",
990
+ " print(f\" train : {n_domain} prompts\")\n",
991
+ " print(f\" eval : {len(eval_records)} prompts\")\n",
992
+ " print(f\" distribution: {', '.join(f'{k}={v}' for k, v in sorted(task_dist.items()))}\")\n",
993
+ " print(f\" steps/epoch: {n_domain} x {NUM_GENERATIONS} / ({BATCH_SIZE} x {GRAD_ACCUM}) = {steps_per_epoch}\")\n",
994
+ " print(f\" MAX_STEPS={MAX_STEPS} -> {'< 1 epoch' if MAX_STEPS < steps_per_epoch else f'{MAX_STEPS/steps_per_epoch:.1f} epochs'}\")\n",
995
+ "\n",
996
+ " train_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in train_records])\n",
997
+ " eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n",
998
+ " return train_ds, eval_ds\n",
999
+ "\n",
1000
+ "\n",
1001
+ "train_dataset, eval_dataset = prepare_grpo_datasets_v3()\n",
1002
+ "dataset = train_dataset\n",
1003
+ "print(f\"\\n✓ v3 Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}\")\n",
1004
+ "\n",
1005
+ "# ── Log dataset sizes to W&B ─────────────────────────────────────────────────\n",
1006
+ "wandb.config.update({\n",
1007
+ " \"train_prompts\": len(train_dataset),\n",
1008
+ " \"eval_prompts\": len(eval_dataset),\n",
1009
+ "}, allow_val_change=True)"
1010
+ ]
1011
+ },
1012
+ {
1013
+ "cell_type": "markdown",
1014
+ "metadata": {},
1015
+ "source": [
1016
+ "---\n",
1017
+ "\n",
1018
+ "## Cell 9: Smoke Test\n",
1019
+ "\n",
1020
+ "**Gate:** Runs 1 step without OOM at new completion length (4096)."
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "code",
1025
+ "execution_count": null,
1026
+ "metadata": {},
1027
+ "outputs": [],
1028
+ "source": [
1029
+ "from trl import GRPOConfig, GRPOTrainer\n",
1030
+ "\n",
1031
+ "FastLanguageModel.for_training(model)\n",
1032
+ "\n",
1033
+ "smoke_config = GRPOConfig(\n",
1034
+ " output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n",
1035
+ " num_generations=NUM_GENERATIONS,\n",
1036
+ " scale_rewards=SCALE_REWARDS,\n",
1037
+ " max_completion_length=MAX_COMPLETION_LENGTH,\n",
1038
+ " max_steps=1,\n",
1039
+ " num_train_epochs=1,\n",
1040
+ " temperature=TEMPERATURE,\n",
1041
+ " per_device_train_batch_size=BATCH_SIZE,\n",
1042
+ " gradient_accumulation_steps=1,\n",
1043
+ " learning_rate=LEARNING_RATE,\n",
1044
+ " fp16=False,\n",
1045
+ " bf16=True,\n",
1046
+ " logging_steps=1,\n",
1047
+ " save_steps=999,\n",
1048
+ " report_to=\"none\",\n",
1049
+ " max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
1050
+ " seed=42,\n",
1051
+ " remove_unused_columns=False,\n",
1052
+ ")\n",
1053
+ "\n",
1054
+ "smoke_trainer = GRPOTrainer(\n",
1055
+ " model=model,\n",
1056
+ " reward_funcs=commerce_reward_fn,\n",
1057
+ " args=smoke_config,\n",
1058
+ " train_dataset=dataset,\n",
1059
+ " processing_class=tokenizer, # TRL's GRPOTrainer expects processing_class, not tokenizer\n",
1060
+ ")\n",
1061
+ "\n",
1062
+ "t0 = time.time()\n",
1063
+ "smoke_trainer.train()\n",
1064
+ "step_time = time.time() - t0\n",
1065
+ "\n",
1066
+ "vram_used = torch.cuda.max_memory_allocated() / 1e9\n",
1067
+ "vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
1068
+ "\n",
1069
+ "print(f\"\\n✓ Smoke test passed!\")\n",
1070
+ "print(f\" Step time (grad_accum=1): {step_time:.0f}s\")\n",
1071
+ "print(f\" Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s\")\n",
1072
+ "print(f\" VRAM peak: {vram_used:.1f} GB / {vram_total:.1f} GB\")\n",
1073
+ "\n",
1074
+ "if vram_used > vram_total * 0.95:\n",
1075
+ " print(f\"\\n VRAM at {vram_used/vram_total:.0%} — dangerously close to OOM\")\n",
1076
+ " print(f\" Option 1: Reduce MAX_COMPLETION_LENGTH to 3072\")\n",
1077
+ " print(f\" Option 2: Reduce BATCH_SIZE to 2 (increase GRAD_ACCUM to 2)\")\n",
1078
+ "\n",
1079
+ "# ── Log to W&B ───────────────────────────────────────────────────────────────\n",
1080
+ "wandb.log({\n",
1081
+ " \"smoke/step_time_s\": step_time,\n",
1082
+ " \"smoke/vram_peak_gb\": vram_used,\n",
1083
+ " \"smoke/vram_total_gb\": vram_total,\n",
1084
+ " \"smoke/vram_pct\": vram_used / vram_total,\n",
1085
+ " \"smoke/estimated_step_time_s\": step_time * GRAD_ACCUM,\n",
1086
+ "})\n",
1087
+ "\n",
1088
+ "del smoke_trainer\n",
1089
+ "gc.collect(); torch.cuda.empty_cache()"
1090
+ ]
1091
+ },
1092
+ {
1093
+ "cell_type": "markdown",
1094
+ "metadata": {},
1095
+ "source": [
1096
+ "---\n",
1097
+ "\n",
1098
+ "## Cell 10: Probe Run (3 steps)"
1099
+ ]
1100
+ },
1101
+ {
1102
+ "cell_type": "code",
1103
+ "execution_count": null,
1104
+ "metadata": {},
1105
+ "outputs": [],
1106
+ "source": [
1107
+ "FastLanguageModel.for_training(model)\n",
1108
+ "\n",
1109
+ "probe_config = GRPOConfig(\n",
1110
+ " output_dir=str(CHECKPOINT_DIR / \"probe\"),\n",
1111
+ " num_generations=NUM_GENERATIONS,\n",
1112
+ " scale_rewards=SCALE_REWARDS,\n",
1113
+ " max_completion_length=MAX_COMPLETION_LENGTH,\n",
1114
+ " max_steps=3,\n",
1115
+ " temperature=TEMPERATURE,\n",
1116
+ " num_train_epochs=NUM_EPOCHS,\n",
1117
+ " per_device_train_batch_size=BATCH_SIZE,\n",
1118
+ " gradient_accumulation_steps=GRAD_ACCUM,\n",
1119
+ " learning_rate=LEARNING_RATE,\n",
1120
+ " warmup_ratio=0.1,\n",
1121
+ " lr_scheduler_type=\"cosine\",\n",
1122
+ " fp16=False,\n",
1123
+ " bf16=True,\n",
1124
+ " logging_steps=1,\n",
1125
+ " disable_tqdm=True,\n",
1126
+ " logging_first_step=True,\n",
1127
+ " save_steps=999,\n",
1128
+ " report_to=\"none\",\n",
1129
+ " max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
1130
+ " seed=42,\n",
1131
+ " remove_unused_columns=False,\n",
1132
+ ")\n",
1133
+ "\n",
1134
+ "probe_trainer = GRPOTrainer(\n",
1135
+ " model=model,\n",
1136
+ " reward_funcs=commerce_reward_fn,\n",
1137
+ " args=probe_config,\n",
1138
+ " train_dataset=dataset,\n",
1139
+ " processing_class=tokenizer, # TRL's GRPOTrainer expects processing_class, not tokenizer\n",
1140
+ ")\n",
1141
+ "\n",
1142
+ "t0 = time.time()\n",
1143
+ "result = probe_trainer.train()\n",
1144
+ "elapsed = time.time() - t0\n",
1145
+ "\n",
1146
+ "has_gradient = abs(result.training_loss) >= 1e-6\n",
1147
+ "\n",
1148
+ "print(f\"\\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)\")\n",
1149
+ "print(f\" Train loss: {result.training_loss:.6f}\")\n",
1150
+ "print(f\" Estimated full run ({MAX_STEPS} steps): {elapsed/3 * MAX_STEPS / 3600:.1f}h\")\n",
1151
+ "\n",
1152
+ "if not has_gradient:\n",
1153
+ " print(\" Loss is near-zero — reward variance may be insufficient\")\n",
1154
+ "else:\n",
1155
+ " print(\" ✓ Loss is non-zero — GRPO has gradient signal\")\n",
1156
+ "\n",
1157
+ "# ── Log to W&B ───────────────────────────────────────────────────────────────\n",
1158
+ "wandb.log({\n",
1159
+ " \"probe/loss\": result.training_loss,\n",
1160
+ " \"probe/time_per_step_s\": elapsed / 3,\n",
1161
+ " \"probe/estimated_full_run_h\": elapsed / 3 * MAX_STEPS / 3600,\n",
1162
+ " \"probe/has_gradient\": int(has_gradient),\n",
1163
+ "})\n",
1164
+ "\n",
1165
+ "del probe_trainer\n",
1166
+ "gc.collect(); torch.cuda.empty_cache()"
1167
+ ]
1168
+ },
1169
+ {
1170
+ "cell_type": "markdown",
1171
+ "metadata": {},
1172
+ "source": [
1173
+ "---\n",
1174
+ "\n",
1175
+ "## Cell 11: Full Training Run v3"
1176
+ ]
1177
+ },
1178
+ {
1179
+ "cell_type": "code",
1180
+ "execution_count": null,
1181
+ "metadata": {},
1182
+ "outputs": [],
1183
+ "source": [
1184
+ "import shutil\n",
1185
+ "import torch\n",
1186
+ "from transformers import TrainerCallback\n",
1187
+ "from trl import GRPOConfig, GRPOTrainer\n",
1188
+ "\n",
1189
+ "# W&B run already active from Cell 3 — just update config with dataset counts\n",
1190
+ "wandb.config.update({\n",
1191
+ " \"train_prompts\": len(train_dataset),\n",
1192
+ " \"eval_prompts\": len(eval_dataset),\n",
1193
+ "}, allow_val_change=True)\n",
1194
+ "\n",
1195
+ "FRESH = True\n",
1196
+ "resume_from = None\n",
1197
+ "if FRESH and CHECKPOINT_DIR.exists():\n",
1198
+ " print(\"FRESH: deleting old checkpoints...\")\n",
1199
+ " shutil.rmtree(CHECKPOINT_DIR)\n",
1200
+ "elif CHECKPOINT_DIR.exists():\n",
1201
+ " checkpoints = sorted(\n",
1202
+ " [d for d in CHECKPOINT_DIR.iterdir()\n",
1203
+ " if d.is_dir() and d.name.startswith(\"checkpoint-\")],\n",
1204
+ " key=lambda d: int(d.name.split(\"-\")[-1]),\n",
1205
+ " )\n",
1206
+ " if checkpoints:\n",
1207
+ " resume_from = str(checkpoints[-1])\n",
1208
+ " print(f\"Resuming from: {resume_from}\")\n",
1209
+ "\n",
1210
+ "\n",
1211
+ "class UnslothGRPOTrainer(GRPOTrainer):\n",
1212
+ " \"\"\"Wraps generation with Unsloth for_inference()/for_training().\"\"\"\n",
1213
+ " def _generate(self, prompts, images):\n",
1214
+ " FastLanguageModel.for_inference(self.model)\n",
1215
+ " try:\n",
1216
+ " result = super()._generate(prompts, images)\n",
1217
+ " finally:\n",
1218
+ " FastLanguageModel.for_training(self.model)\n",
1219
+ " return result\n",
1220
+ "\n",
1221
+ "\n",
1222
+ "class EvalRewardCallback(TrainerCallback):\n",
1223
+ " \"\"\"v3: deterministic eval, per-task breakdown, patience=15.\"\"\"\n",
+ " def __init__(self, eval_records, reward_fn, patience=EARLY_STOPPING_PATIENCE,\n",
+ " delta=EARLY_STOPPING_DELTA):\n",
+ " self.eval_records = eval_records\n",
+ " self.reward_fn = reward_fn\n",
+ " self.patience = patience\n",
+ " self.delta = delta\n",
+ " self.best_reward = -float(\"inf\")\n",
+ " self.no_improve_count = 0\n",
+ "\n",
+ " def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n",
+ " if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n",
+ " return control\n",
+ " tokenizer = processing_class\n",
+ " if tokenizer is None:\n",
+ " print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n",
+ " return control\n",
+ "\n",
+ " mean_reward, task_rewards = self._run_eval(model, tokenizer, args)\n",
+ " improved = mean_reward > self.best_reward + self.delta\n",
+ " status = \"improved\" if improved else f\"no gain ({self.no_improve_count + 1}/{self.patience})\"\n",
+ "\n",
+ " log_dict = {\n",
+ " \"eval/mean_reward\": mean_reward,\n",
+ " \"eval/best_reward\": max(self.best_reward, mean_reward),\n",
+ " \"eval/no_improve_count\": self.no_improve_count,\n",
+ " }\n",
+ " for task, rewards in task_rewards.items():\n",
+ " if rewards:\n",
+ " log_dict[f\"eval/{task}_reward\"] = sum(rewards) / len(rewards)\n",
+ " wandb.log(log_dict, step=state.global_step)\n",
+ "\n",
+ " print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n",
+ " for task, rewards in task_rewards.items():\n",
+ " if rewards:\n",
+ " print(f\" {task}: {sum(rewards)/len(rewards):.3f} (n={len(rewards)})\")\n",
+ "\n",
+ " if improved:\n",
+ " self.best_reward = mean_reward\n",
+ " self.no_improve_count = 0\n",
+ " else:\n",
+ " self.no_improve_count += 1\n",
+ " if self.no_improve_count >= self.patience:\n",
+ " print(f\"[EarlyStopping] No improvement >= {self.delta} for {self.patience} consecutive evals. Halting.\")\n",
+ " wandb.log({\"early_stop/step\": state.global_step}, step=state.global_step)\n",
+ " control.should_training_stop = True\n",
+ " return control\n",
+ "\n",
+ " def _run_eval(self, model, tokenizer, args):\n",
+ " FastLanguageModel.for_inference(model)\n",
+ " rewards = []\n",
+ " task_rewards = {}\n",
+ " subset = self.eval_records[:EVAL_MAX_SAMPLES]\n",
+ " for record in subset:\n",
+ " msgs = record[\"prompt\"]\n",
+ " text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
+ " inputs = tokenizer(text, return_tensors=\"pt\", truncation=True,\n",
+ " max_length=args.max_prompt_length).to(model.device)\n",
+ " with torch.no_grad():\n",
+ " out = model.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS,\n",
+ " temperature=EVAL_TEMPERATURE, do_sample=True)\n",
+ " resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
+ " r = self.reward_fn([resp], [text])[0]\n",
+ " rewards.append(r)\n",
+ " user_text = \" \".join(m.get(\"content\", \"\") for m in msgs if m.get(\"role\") == \"user\")\n",
+ " task = _classify_task_type(user_text)\n",
+ " task_rewards.setdefault(task, []).append(r)\n",
+ " FastLanguageModel.for_training(model)\n",
+ " mean = sum(rewards) / len(rewards) if rewards else 0.0\n",
+ " return mean, task_rewards\n",
+ "\n",
+ "\n",
+ "class EntropyMonitorCallback(TrainerCallback):\n",
+ " \"\"\"v3 NEW: Monitor entropy collapse indicators (Skywork-OR1 S4).\"\"\"\n",
+ " def __init__(self):\n",
+ " self.consecutive_ceiling_hits = 0\n",
+ "\n",
+ " def on_log(self, args, state, control, logs=None, **kwargs):\n",
+ " if not logs:\n",
+ " return\n",
+ " step = state.global_step\n",
+ " monitor = {}\n",
+ " comp_len = logs.get(\"completion_length\", 0)\n",
+ " if comp_len > 0:\n",
+ " ratio = comp_len / MAX_COMPLETION_LENGTH\n",
+ " monitor[\"monitor/completion_ratio\"] = ratio\n",
+ " if ratio > 0.95:\n",
+ " self.consecutive_ceiling_hits += 1\n",
+ " if self.consecutive_ceiling_hits >= 3:\n",
+ " print(f\"Step {step}: Completion ceiling hit {self.consecutive_ceiling_hits} consecutive times.\")\n",
+ " else:\n",
+ " self.consecutive_ceiling_hits = 0\n",
+ " reward_std = logs.get(\"reward_std\", logs.get(\"rewards/commerce_reward_fn/std\", 0))\n",
+ " if reward_std is not None:\n",
+ " monitor[\"monitor/reward_std\"] = reward_std\n",
+ " if reward_std < 0.01:\n",
+ " print(f\"Step {step}: reward_std={reward_std:.4f} — near-zero variance\")\n",
+ " clip_high = logs.get(\"clip_ratio/high_mean\", 0)\n",
1321
+ " clip_low = logs.get(\"clip_ratio/low_mean\", 0)\n",
+ " if clip_high is not None and clip_low is not None:\n",
+ " total_clip = clip_high + abs(clip_low)\n",
+ " monitor[\"monitor/total_clip_ratio\"] = total_clip\n",
+ " if total_clip > 0.01 and step > 10:\n",
+ " print(f\"Step {step}: clip_ratio={total_clip:.3f} — policy is updating\")\n",
+ " if monitor and wandb.run:\n",
+ " wandb.log(monitor, step=step)\n",
+ "\n",
+ "\n",
+ "FastLanguageModel.for_training(model)\n",
+ "\n",
+ "grpo_config = GRPOConfig(\n",
+ " output_dir=str(CHECKPOINT_DIR),\n",
+ " num_generations=NUM_GENERATIONS,\n",
+ " scale_rewards=SCALE_REWARDS,\n",
+ " max_completion_length=MAX_COMPLETION_LENGTH,\n",
+ " temperature=TEMPERATURE,\n",
+ " max_steps=MAX_STEPS,\n",
+ " num_train_epochs=NUM_EPOCHS,\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " gradient_accumulation_steps=GRAD_ACCUM,\n",
+ " learning_rate=LEARNING_RATE,\n",
+ " warmup_ratio=0.1,\n",
+ " lr_scheduler_type=\"cosine\",\n",
+ " fp16=False,\n",
+ " bf16=True,\n",
+ " logging_steps=1,\n",
+ " logging_first_step=True,\n",
+ " disable_tqdm=True,\n",
+ " save_steps=SAVE_STEPS,\n",
+ " save_total_limit=SAVE_TOTAL_LIMIT,\n",
+ " save_only_model=True,\n",
+ " eval_steps=EVAL_STEPS,\n",
+ " report_to=\"wandb\",\n",
+ " max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
+ " seed=42,\n",
+ " remove_unused_columns=False,\n",
+ " **({\"use_vllm\": True, \"vllm_mode\": \"colocate\",\n",
+ " \"vllm_enable_sleep_mode\": True} if USE_VLLM else {}),\n",
+ ")\n",
+ "\n",
+ "eval_cb = EvalRewardCallback(eval_records=list(eval_dataset), reward_fn=commerce_reward_fn)\n",
+ "entropy_cb = EntropyMonitorCallback()\n",
+ "\n",
+ "TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer\n",
+ "trainer = TrainerClass(\n",
+ " model=model,\n",
+ " reward_funcs=commerce_reward_fn,\n",
+ " args=grpo_config,\n",
+ " train_dataset=train_dataset,\n",
+ " processing_class=tokenizer,\n",
+ " callbacks=[eval_cb, entropy_cb],\n",
+ ")\n",
+ "\n",
+ "print(f\"{'='*70}\")\n",
+ "print(f\"GRPO v3 Training — Ready to Launch\")\n",
+ "print(f\"{'='*70}\")\n",
+ "print(f\" Trainer: {TrainerClass.__name__}\")\n",
+ "print(f\" Max steps: {MAX_STEPS}\")\n",
+ "print(f\" Temperature: {TEMPERATURE} (v2 was 0.8)\")\n",
+ "print(f\" Completion: {MAX_COMPLETION_LENGTH} tokens (v2 was 2048)\")\n",
+ "print(f\" Generations: {NUM_GENERATIONS} per prompt (v2 was 8)\")\n",
+ "print(f\" Learning rate: {LEARNING_RATE} (v2 was 5e-7)\")\n",
+ "print(f\" Save every: {SAVE_STEPS} steps (keep {SAVE_TOTAL_LIMIT})\")\n",
+ "print(f\" Eval every: {EVAL_STEPS} steps ({EVAL_MAX_SAMPLES} samples x {EVAL_MAX_TOKENS} tok)\")\n",
+ "print(f\" Patience: {EARLY_STOPPING_PATIENCE} evals ({EARLY_STOPPING_PATIENCE * EVAL_STEPS} steps)\")\n",
+ "print(f\" Resume: {resume_from is not None}\")\n",
+ "print(f\" W&B run: {wandb.run.url}\")\n",
+ "print(f\"{'='*70}\")\n",
+ "\n",
+ "t_start = time.time()\n",
+ "result = trainer.train(resume_from_checkpoint=resume_from)\n",
+ "elapsed = time.time() - t_start\n",
+ "\n",
+ "# ── Log final training metrics (run stays open for save + validation) ────────\n",
+ "wandb.log({\n",
+ " \"train/final_loss\": result.training_loss,\n",
+ " \"train/duration_hours\": elapsed / 3600,\n",
+ " \"train/total_steps\": result.global_step,\n",
+ " \"eval/best_reward_final\": eval_cb.best_reward,\n",
+ "})\n",
+ "\n",
+ "print(f\"\\n{'='*70}\")\n",
+ "print(f\"GRPO v3 Training Complete\")\n",
+ "print(f\" Loss: {result.training_loss:.6f}\")\n",
+ "print(f\" Steps: {result.global_step}\")\n",
+ "print(f\" Duration: {elapsed/3600:.1f}h\")\n",
+ "print(f\" Best eval R: {eval_cb.best_reward:.4f}\")\n",
+ "print(f\" Trainer: {TrainerClass.__name__}\")\n",
+ "print(f\" W&B run: {wandb.run.url}\")\n",
+ "print(f\"{'='*70}\")"
+ ]
+ },
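+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Aside: early-stopping budget.** A minimal, standalone sketch of the improvement counter that `EvalRewardCallback` applies above. The helper `stop_step` and the reward trace are invented purely for illustration; only `EARLY_STOPPING_PATIENCE`, `EARLY_STOPPING_DELTA`, and `EVAL_STEPS` come from the notebook config."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dry run of the early-stopping rule on a made-up reward trace (sketch only).\n",
+ "def stop_step(trace, patience=EARLY_STOPPING_PATIENCE, delta=EARLY_STOPPING_DELTA):\n",
+ "    best, misses = -float(\"inf\"), 0\n",
+ "    for i, r in enumerate(trace, start=1):\n",
+ "        if r > best + delta:\n",
+ "            best, misses = r, 0  # improvement resets the counter\n",
+ "        else:\n",
+ "            misses += 1\n",
+ "            if misses >= patience:\n",
+ "                return i * EVAL_STEPS  # global step where training would halt\n",
+ "    return None  # patience never exhausted\n",
+ "\n",
+ "fake_trace = [0.30, 0.35, 0.36] + [0.36] * EARLY_STOPPING_PATIENCE\n",
+ "print(\"Hypothetical trace stops at step:\", stop_step(fake_trace))\n",
+ "print(\"Worst-case steps after best eval:\", EARLY_STOPPING_PATIENCE * EVAL_STEPS)"
+ ]
+ },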
+ {
+ "cell_type": "markdown",
+ "id": "b5a463d0",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "\n",
+ "## Cell 12: Save Adapter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\n",
+ "model.save_pretrained(str(GRPO_ADAPTER_DIR))\n",
+ "tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))\n",
+ "\n",
+ "summary = {\n",
+ " \"model_id\": MODEL_ID,\n",
+ " \"sft_adapter\": str(SFT_ADAPTER_DIR),\n",
+ " \"method\": \"GRPO\",\n",
+ " \"version\": \"v3\",\n",
+ " \"train_loss\": result.training_loss,\n",
+ " \"best_eval_reward\": eval_cb.best_reward,\n",
+ " \"num_prompts\": len(train_dataset),\n",
+ " \"num_generations\": NUM_GENERATIONS,\n",
+ " \"scale_rewards\": SCALE_REWARDS,\n",
+ " \"temperature\": TEMPERATURE,\n",
+ " \"learning_rate\": LEARNING_RATE,\n",
+ " \"beta\": BETA,\n",
+ " \"max_completion_length\": MAX_COMPLETION_LENGTH,\n",
+ " \"max_steps\": MAX_STEPS,\n",
+ " \"actual_steps\": result.global_step,\n",
+ " \"epochs\": NUM_EPOCHS,\n",
+ " \"max_seq_length\": MAX_SEQ_LENGTH,\n",
+ " \"duration_seconds\": round(elapsed),\n",
+ " \"gpu\": \"L4\",\n",
+ " \"platform\": \"vertex-ai-workbench\",\n",
+ " \"v3_fixes\": [\n",
+ " \"temperature=1.0 (Skywork-OR1)\",\n",
+ " \"max_completion_length=4096 (Dr. GRPO)\",\n",
+ " \"learning_rate=2e-6 (4x v2)\",\n",
+ " \"beta=0.0 (Dr. GRPO)\",\n",
+ " \"staged rewards (Reasoning-SQL)\",\n",
+ " \"zero-advantage noise (Skywork-OR1)\",\n",
+ " \"entropy monitoring callback\",\n",
+ " ],\n",
+ "}\n",
+ "with open(GRPO_ADAPTER_DIR / \"training_summary.json\", \"w\") as f:\n",
+ " json.dump(summary, f, indent=2)\n",
+ "\n",
+ "print(f\"✓ Adapter saved to {GRPO_ADAPTER_DIR}\")\n",
+ "print(f\" Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir() if f.is_file()]}\")"
+ ]
+ },
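+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Aside: read-back check.** A small sketch that re-opens the `training_summary.json` just written and spot-checks a few fields. The `required` list mirrors keys from the `summary` dict above; it is illustrative, not a schema."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "# Re-open the summary written above and verify it round-trips.\n",
+ "with open(GRPO_ADAPTER_DIR / \"training_summary.json\") as f:\n",
+ "    saved = json.load(f)\n",
+ "\n",
+ "required = [\"model_id\", \"version\", \"best_eval_reward\", \"actual_steps\"]\n",
+ "missing = [k for k in required if k not in saved]\n",
+ "assert not missing, f\"training_summary.json missing keys: {missing}\"\n",
+ "print(f\"summary ok: version={saved['version']}, steps={saved['actual_steps']}, best eval R={saved['best_eval_reward']:.4f}\")"
+ ]
+ },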
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "\n",
+ "## Cell 13: Validation v3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "FastLanguageModel.for_inference(model)\n",
+ "\n",
+ "system_msg = {\"role\": \"system\", \"content\": SYSTEM_PT}\n",
+ "\n",
+ "test_prompts = [\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
+ " \"nota=2/5 | status=delivered\\ntítulo: decepcionado\\n\"\n",
+ " \"texto: Produto veio com defeito e o vendedor não respondeu.\\n\\n\"\n",
+ " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
+ " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
+ " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
+ " )},\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
+ " \"nota=5/5 | status=delivered\\ntítulo: adorei!\\n\"\n",
+ " \"texto: Entrega rápida e produto exatamente como descrito. Recomendo!\\n\\n\"\n",
+ " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
+ " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
+ " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
+ " )},\n",
+ " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
+ " {\"role\": \"user\", \"content\": \"Analise a retenção de clientes afetados por product_quality.\"},\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Perfil do cliente:\\n- Estado: MG\\n- Valor do pedido: R$150\\n\"\n",
+ " \"- Reclamação: produto com defeito\\n- Nota: 1.0/5\\n\\n\"\n",
+ " \"Este cliente deve receber uma notificação de reengajamento?\"\n",
+ " )},\n",
+ " {\"role\": \"user\", \"content\": \"Compare a satisfação de clientes em SP vs RJ.\"},\n",
+ " {\"role\": \"user\", \"content\": (\n",
+ " \"Crie uma notificação push de reengajamento para um cliente em SP \"\n",
+ " \"que reclamou de atraso na entrega. Nota: 2/5.\"\n",
+ " )},\n",
+ "]\n",
+ "\n",
1512
+ "print(\"=\" * 70)\n",
1513
+ "print(\"GRPO v3 Validation\")\n",
1514
+ "print(\"=\" * 70)\n",
1515
+ "\n",
1516
+ "v3_rewards = []\n",
1517
+ "val_rows = []\n",
1518
+ "for i, prompt in enumerate(test_prompts):\n",
1519
+ " messages = [system_msg, prompt]\n",
1520
+ " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
1521
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
1522
+ "\n",
1523
+ " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)\n",
1524
+ " gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
1525
+ " response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
1526
+ "\n",
1527
+ " reward = commerce_reward_fn([response], [text])[0]\n",
1528
+ " v3_rewards.append(reward)\n",
1529
+ " answer = strip_think(response)\n",
1530
+ " task = _classify_task_type(prompt[\"content\"])\n",
1531
+ " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
1532
+ "\n",
1533
+ " print(f\"\\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---\")\n",
1534
+ " print(f\"Prompt: {prompt['content'][:80]}...\")\n",
1535
+ " print(f\"Answer: {answer[:400]}\")\n",
1536
+ "\n",
1537
+ " val_rows.append([i + 1, task, reward, gen_tokens, hit_ceiling,\n",
1538
+ " prompt[\"content\"][:120], answer[:500]])\n",
1539
+ "\n",
1540
+ "v3_mean = sum(v3_rewards) / len(v3_rewards)\n",
1541
+ "v3_vs_v2 = (v3_mean - 0.54) / 0.54 * 100\n",
1542
+ "\n",
1543
+ "print(f\"\\n{'='*70}\")\n",
1544
+ "print(f\"v3 Validation Summary\")\n",
1545
+ "print(f\"{'='*70}\")\n",
1546
+ "print(f\" Mean reward: {v3_mean:.3f}\")\n",
1547
+ "print(f\" Min: {min(v3_rewards):.3f}\")\n",
1548
+ "print(f\" Max: {max(v3_rewards):.3f}\")\n",
1549
+ "print()\n",
1550
+ "print(f\" Comparison to baselines:\")\n",
1551
+ "print(f\" SFT calibration (Cell 7): mean=0.38\")\n",
1552
+ "print(f\" GRPO v2 validation: mean=0.54\")\n",
1553
+ "print(f\" GRPO v3 validation: mean={v3_mean:.3f}\")\n",
1554
+ "print(f\" v3 vs v2: {v3_vs_v2:+.1f}%\")\n",
1555
+ "\n",
1556
+ "# ── Log validation results to W&B ────────────────────────────────────────────\n",
1557
+ "val_table = wandb.Table(\n",
1558
+ " columns=[\"sample\", \"task\", \"reward\", \"tokens\", \"hit_ceiling\", \"prompt_preview\", \"answer_preview\"],\n",
1559
+ " data=val_rows,\n",
1560
+ ")\n",
1561
+ "wandb.log({\n",
1562
+ " \"validation/mean_reward\": v3_mean,\n",
1563
+ " \"validation/min_reward\": min(v3_rewards),\n",
1564
+ " \"validation/max_reward\": max(v3_rewards),\n",
1565
+ " \"validation/v3_vs_v2_pct\": v3_vs_v2,\n",
1566
+ " \"validation/samples\": val_table,\n",
1567
+ "})\n",
1568
+ "wandb.summary.update({\n",
1569
+ " \"validation/mean_reward\": v3_mean,\n",
1570
+ " \"validation/v3_vs_v2_pct\": v3_vs_v2,\n",
1571
+ "})\n",
1572
+ "\n",
1573
+ "# ── Close the W&B run — all outputs are now persisted ────────────────────────\n",
1574
+ "wandb.finish()\n",
1575
+ "print(f\"\\n✓ W&B run finalized — all outputs saved\")"
1576
+ ]
1577
+ },
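+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Aside: per-task breakdown.** A follow-up sketch that regroups `val_rows` from the validation cell above into per-task mean rewards, so the seven samples can be read by task type instead of a single global mean. Assumes `val_rows` is still in scope; the grouping here is illustrative, not part of the training pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "\n",
+ "# Group validation rows by task and report mean reward and ceiling hits.\n",
+ "by_task = defaultdict(list)\n",
+ "for _, task, reward, tokens, hit_ceiling, *_ in val_rows:\n",
+ "    by_task[task].append((reward, hit_ceiling))\n",
+ "\n",
+ "for task, rows in sorted(by_task.items()):\n",
+ "    mean_r = sum(r for r, _ in rows) / len(rows)\n",
+ "    ceiling_hits = sum(1 for _, h in rows if h)\n",
+ "    print(f\"{task:<24} n={len(rows)}  mean_reward={mean_r:.3f}  ceiling_hits={ceiling_hits}\")"
+ ]
+ }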
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }