rtferraz committed on
Commit 1d514ac · 1 Parent(s): 0f39df7

apply v3 task-aware thinking controls and delete deprecated notebook


- Replaced the monolithic SYSTEM_PT with task-specific system prompts that set verbosity and thinking mode per task (see the prompt-dispatch sketch below).

- Integrated reward_think_efficiency into the reward-function dispatch to penalize bloated thinking against per-task budgets (see the reward sketch below).

- Added dynamic system prompt injection into calibration, dataset preparation, and validation loops.

- Deleted deprecated notebooks/DEPRECATED_grpo_vertex_v3.ipynb.
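
Since the diff below only shows the deleted v3 notebook, here is a minimal sketch of the task-aware prompt dispatch described above. The four task labels match `_classify_task_type()` from the deleted notebook; the `SYSTEM_PROMPTS` wording and the `build_prompt` helper are illustrative assumptions, not the actual commit code.

```python
# Hypothetical sketch — prompt texts and helper names are assumptions, not the commit code.
# Task labels follow _classify_task_type() from the deleted v3 notebook.
SYSTEM_PROMPTS = {
    "extraction": "Analista de e-commerce brasileiro. Responda apenas com JSON. Pense de forma breve.",
    "sql_qa":     "Analista de e-commerce brasileiro. Explique a consulta com raciocínio moderado.",
    "insights":   "Analista de e-commerce brasileiro. Produza análise aprofundada com raciocínio detalhado.",
    "push":       "Redator de notificações push. Mensagem curta, raciocínio mínimo.",
}

def build_prompt(user_msgs: list[dict], task: str) -> list[dict]:
    """Inject the task-specific system prompt ahead of the user turns."""
    system = SYSTEM_PROMPTS.get(task, SYSTEM_PROMPTS["sql_qa"])
    return [{"role": "system", "content": system}, *user_msgs]
```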
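Likewise, one plausible shape for the `reward_think_efficiency` term: full credit while the `<think>` block stays within a per-task budget, decaying as it overshoots. The budget values and the word-count proxy for tokens are assumptions for illustration.

```python
import re

# Assumed per-task thinking budgets (rough token counts); values are illustrative.
THINK_BUDGETS = {"extraction": 256, "sql_qa": 512, "insights": 1024, "push": 128}

def reward_think_efficiency(completion: str, task: str) -> float:
    """1.0 while <think> stays within budget, decaying linearly past it;
    0.0 when there is no closed <think> block (format rewards handle that case)."""
    m = re.search(r"<think>(.*?)</think>", completion, flags=re.DOTALL)
    if m is None:
        return 0.0
    think_len = len(m.group(1).split())  # crude word-count proxy for tokens
    budget = THINK_BUDGETS.get(task, 512)
    overshoot = max(0.0, think_len / budget - 1.0)
    return max(0.0, 1.0 - overshoot)     # e.g. 2x the budget -> 0.0
```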

notebooks/DEPRECATED_grpo_vertex_v3.ipynb DELETED
@@ -1,1517 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 5,
4
- "metadata": {
5
- "kernelspec": {
6
- "display_name": "Python 3 (ipykernel)",
7
- "language": "python",
8
- "name": "python3"
9
- },
10
- "language_info": {
11
- "name": "python",
12
- "version": "3.10.0",
13
- "mimetype": "text/x-python",
14
- "file_extension": ".py"
15
- }
16
- },
17
- "cells": [
18
- {
19
- "cell_type": "markdown",
20
- "metadata": {},
21
- "source": [
22
- "# Tucano2 Commerce — GRPO Training v3 (Vertex AI Workbench / L4)\n",
23
- "\n",
24
- "**v3 changes over v2 — grounded in published research:**\n",
25
- "\n",
26
- "| Change | v2 Value | v3 Value | Paper Reference |\n",
27
- "|--------|----------|----------|----------------|\n",
28
- "| Temperature | 0.8 | **1.0** | Skywork-OR1 (2505.22312) §4: τ=1.0 gives 5-8% better results, delays entropy collapse |\n",
29
- "| Completion length | 2048 | **4096** | Dr. GRPO (2503.20783) §3.1: length bias inflates wrong answers → ceiling hit blocks learning |\n",
30
- "| Num generations | 8 | **4** | VRAM tradeoff: 4×4096 ≈ 8×2048. MC-GRPO (2601.22582): G=4 works with noise mitigation |\n",
31
- "| Learning rate | 5e-7 | **2e-6** | Dr. GRPO Appendix G: LR=1e-6; Reasoning-SQL: LR=1e-6. v2 clip_ratio=0 → room to push 2-4× |\n",
32
- "| β (KL penalty) | implicit | **0.0** | Dr. GRPO §3.2: β=0 optimal for rule-based rewards |\n",
33
- "| Training data | 300 | **ALL (~1400)** | Skywork-OR1 §3.1: small prompt sets → model memorizes → entropy collapse |\n",
34
- "| Reward functions | single composite | **staged (format→partial→task)** | Reasoning-SQL (2503.23157) §3.2: format rewards converge first, enable task learning |\n",
35
- "| Zero-advantage groups | included | **filtered with noise injection** | Skywork-OR1 §3.1: zero-std groups destabilize training |\n",
36
- "| Entropy monitoring | none | **EntropyMonitorCallback** | Skywork-OR1 §4: early detection prevents collapse |\n",
37
- "| Early stopping patience | 10 | **15** | More runway for longer completions |\n",
38
- "| Save total limit | 3 | **5** | Keep more checkpoints — v2 lost the best one |\n",
39
- "| Eval temperature | 0.7 | **0.1** | Deterministic eval = less noisy signal |\n",
40
- "| General reasoning mix | none | **30% (optional)** | Cocktail Effect (2410.01109): multi-task mix boosts domain performance 2-15% |\n",
41
- "\n",
42
- "**Prerequisites:**\n",
43
- "- Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`\n",
44
- "- Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`\n",
45
- "- **NEW:** Optional `data/pairs/general_reasoning.jsonl` for 30% general data mix\n",
46
- "\n",
47
- "**Hardware:** L4 (24GB), PyTorch kernel, bf16 supported\n",
48
- "\n",
49
- "---\n",
50
- "\n",
51
- "## Cell 1: Dependencies\n",
52
- "\n",
53
- "Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:"
54
- ]
55
- },
56
- {
57
- "cell_type": "code",
58
- "execution_count": null,
59
- "metadata": {},
60
- "outputs": [],
61
- "source": [
62
- "# Cell 1a — Nuke everything ML-related\n",
63
- "!pip uninstall -y torch torchvision torchaudio \\\n",
64
- " unsloth unsloth-zoo \\\n",
65
- " trl transformers peft accelerate \\\n",
66
- " bitsandbytes vllm vllm-flash-attn \\\n",
67
- " datasets tokenizers safetensors huggingface-hub \\\n",
68
- " wandb xformers triton \\\n",
69
- " cuda-bindings cuda-python \\\n",
70
- " sentencepiece protobuf \\\n",
71
- " 2>/dev/null"
72
- ]
73
- },
74
- {
75
- "cell_type": "code",
76
- "execution_count": null,
77
- "metadata": {},
78
- "outputs": [],
79
- "source": [
80
- "# Cell 1b — Kill any stragglers\n",
81
- "!pip freeze | grep -iE \"torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate\" | xargs pip uninstall -y 2>/dev/null"
82
- ]
83
- },
84
- {
85
- "cell_type": "code",
86
- "execution_count": null,
87
- "metadata": {},
88
- "outputs": [],
89
- "source": [
90
- "# Cell 1c — Purge cache\n",
91
- "!pip cache purge"
92
- ]
93
- },
94
- {
95
- "cell_type": "markdown",
96
- "metadata": {},
97
- "source": [
98
- "**⚠️ Restart kernel again**, then:"
99
- ]
100
- },
101
- {
102
- "cell_type": "code",
103
- "execution_count": null,
104
- "metadata": {},
105
- "outputs": [],
106
- "source": [
107
- "# Cell 1d — Clean install, correct order\n",
108
- "!pip install \"unsloth\""
109
- ]
110
- },
111
- {
112
- "cell_type": "code",
113
- "execution_count": null,
114
- "metadata": {},
115
- "outputs": [],
116
- "source": [
117
- "# Cell 1e — Pin TRL (Unsloth may pull a different version)\n",
118
- "!pip install \"trl==0.24.0\" --no-deps"
119
- ]
120
- },
121
- {
122
- "cell_type": "code",
123
- "execution_count": null,
124
- "metadata": {},
125
- "outputs": [],
126
- "source": [
127
- "# Cell 1f — Extra deps\n",
128
- "!pip install \"rich\" \"wandb\""
129
- ]
130
- },
131
- {
132
- "cell_type": "markdown",
133
- "metadata": {},
134
- "source": [
135
- "---\n",
136
- "\n",
137
- "## Cell 2: Hello World — GPU + Unsloth Verification"
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": null,
143
- "metadata": {},
144
- "outputs": [],
145
- "source": [
146
- "import torch\n",
147
- "\n",
148
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
149
- "print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
150
- "print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")\n",
151
- "print(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n",
152
- "\n",
153
- "from unsloth import FastLanguageModel\n",
154
- "print(\"\\n✓ Unsloth loaded successfully\")\n",
155
- "\n",
156
- "import trl\n",
157
- "print(f\"✓ TRL version: {trl.__version__}\")\n",
158
- "\n",
159
- "import transformers\n",
160
- "print(f\"✓ Transformers version: {transformers.__version__}\")"
161
- ]
162
- },
163
- {
164
- "cell_type": "markdown",
165
- "metadata": {},
166
- "source": [
167
- "---\n",
168
- "\n",
169
- "## Cell 3: Config + Constants"
170
- ]
171
- },
172
- {
173
- "cell_type": "code",
174
- "execution_count": null,
175
- "metadata": {},
176
- "outputs": [],
177
- "source": [
178
- "import os\n",
179
- "os.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\n",
180
- "\n",
181
- "import json\n",
182
- "import re\n",
183
- "import time\n",
184
- "import random\n",
185
- "import gc\n",
186
- "from pathlib import Path\n",
187
- "\n",
188
- "# ══════════════════════════════════════════════════════════════════════════════\n",
189
- "# v3 CONFIG — Every change is annotated with paper reference\n",
190
- "# ══════════════════════════════════════════════════════════════════════════════\n",
191
- "\n",
192
- "MODEL_ID = \"Polygl0t/Tucano2-qwen-3.7B-Think\"\n",
193
- "MAX_SEQ_LENGTH = 8192 # v3: increased from 4096 — model supports 32k, we need room for 4096 completion + prompt\n",
194
- "\n",
195
- "# ── Paths ─────────────────────────────────────────────────────────────────────\n",
196
- "DATA_DIR = Path(\"/home/jupyter/tucano2/data\")\n",
197
- "MODELS_DIR = Path(\"/home/jupyter/tucano2/models\")\n",
198
- "SFT_ADAPTER_DIR = MODELS_DIR / \"tucano2-commerce-sft\"\n",
199
- "GRPO_ADAPTER_DIR = MODELS_DIR / \"tucano2-commerce-grpo-v3\" # v3: separate dir from v2\n",
200
- "CHECKPOINT_DIR = GRPO_ADAPTER_DIR / \"checkpoints\"\n",
201
- "\n",
202
- "# ── Training data ─────────────────────────────────────────────────────────────\n",
203
- "GRPO_PROMPTS = None # v3: None = use ALL available prompts (was 300 subset in v2)\n",
204
- "GENERAL_MIX_RATIO = 0.0 # v3: set to 0.3 if general_reasoning.jsonl exists (Cocktail Effect paper)\n",
205
- "\n",
206
- "# ── Valid enums for reward scoring (unchanged from v2) ────────────────────────\n",
207
- "VALID_SENTIMENTS = {\"positive\", \"negative\", \"neutral\"}\n",
208
- "VALID_CATEGORIES = {\n",
209
- " \"delivery_delay\", \"product_quality\", \"product_not_received\",\n",
210
- " \"wrong_product\", \"seller_communication\", \"app_issue\",\n",
211
- " \"price_value\", \"other\", \"none\",\n",
212
- "}\n",
213
- "VALID_CHURN = {\"low\", \"medium\", \"high\"}\n",
214
- "VALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\n",
215
- "EXTRACTION_FIELDS = [\n",
216
- " \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n",
217
- " \"product_issue\", \"seller_issue\", \"main_complaint\",\n",
218
- " \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n",
219
- "]\n",
220
- "\n",
221
- "SYSTEM_PT = (\n",
222
- " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
223
- " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\"\n",
224
- ")\n",
225
- "\n",
226
- "# ══════════════════════════════════════════════════════════════════════════════\n",
227
- "# TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)\n",
228
- "# ══════════════════════════════════════════════════════════════════════════════\n",
229
- "\n",
230
- "# ── Core GRPO params ──────────────────────────────────────────────────────────\n",
231
- "BATCH_SIZE = 4\n",
232
- "GRAD_ACCUM = 1 # v3: reduced from 2. Effective batch = 4×1 = 4 (was 8)\n",
233
- " # With G=4: steps = prompts × 4 / 4 = prompts per epoch\n",
234
- "NUM_GENERATIONS = 4 # v3: reduced from 8 — VRAM tradeoff for longer completions\n",
235
- " # MC-GRPO (2601.22582): G=4 works if noise is mitigated\n",
236
- "SCALE_REWARDS = False # Dr. GRPO (2503.20783): remove std normalization bias\n",
237
- "\n",
238
- "# ── v3 CRITICAL FIXES ────────────────────────────────────────────────────────\n",
239
- "\n",
240
- "# FIX 1: Temperature — prevent entropy collapse\n",
241
- "# v2 had 0.8. All published GRPO papers use 1.0.\n",
242
- "# Skywork-OR1 (2505.22312) ablation: τ=1.0 vs τ=0.6 → 5-8% better test performance\n",
243
- "TEMPERATURE = 1.0\n",
244
- "\n",
245
- "# FIX 2: Completion length — remove the ceiling\n",
246
- "# v2: every single completion hit 2048 ceiling. Model couldn't finish reasoning.\n",
247
- "# Dr. GRPO (2503.20783) §3.1: GRPO length bias inflates wrong answers → ceiling kill gradient\n",
248
- "MAX_COMPLETION_LENGTH = 4096\n",
249
- "\n",
250
- "# FIX 3: Learning rate — more aggressive\n",
251
- "# v2: clip_ratio=0 on all steps → updates were too small to matter\n",
252
- "# Dr. GRPO Appendix G: LR=1e-6 (constant). Reasoning-SQL: LR=1e-6 with cosine.\n",
253
- "# We go 2× since v2 showed zero clipping (model can absorb stronger push)\n",
254
- "LEARNING_RATE = 2e-6\n",
255
- "\n",
256
- "# FIX 4: β = 0 (no KL penalty)\n",
257
- "# Dr. GRPO (2503.20783) §3.2: KL penalty is unnecessary for rule-based rewards\n",
258
- "# v2 used implicit KL through default β — we explicitly disable it\n",
259
- "BETA = 0.0\n",
260
- "\n",
261
- "# ── Training schedule ─────────────────────────────────────────────────────────\n",
262
- "NUM_EPOCHS = 1\n",
263
- "MAX_STEPS = 500 # v3: increased for expanded data; early stopping will halt if needed\n",
264
- " # With ~1400 prompts × 4 gen / (4 batch × 1 accum) = 1400 steps/epoch\n",
265
- " # MAX_STEPS=500 < 1 epoch — early stopping or manual extension\n",
266
- "\n",
267
- "# ── Checkpoint + Eval + Early-Stop ────────────────────────────────────────────\n",
268
- "EVAL_SPLIT_RATIO = 0.15\n",
269
- "EVAL_STEPS = 10\n",
270
- "EARLY_STOPPING_PATIENCE = 15 # v3: increased from 10 — gives 150 steps of runway\n",
271
- "EARLY_STOPPING_DELTA = 0.005 # v3: reduced from 0.01 — more sensitive to small gains\n",
272
- "SAVE_STEPS = 10 # v3: more frequent (was 15) — never lose best checkpoint again\n",
273
- "SAVE_TOTAL_LIMIT = 5 # v3: keep more checkpoints (was 3 — lost best in v2)\n",
274
- "WANDB_PROJECT = \"tucano2-commerce\"\n",
275
- "\n",
276
- "# ── Eval callback ─────────────────────────────────────────────────────────────\n",
277
- "EVAL_MAX_SAMPLES = 5\n",
278
- "EVAL_MAX_TOKENS = 4096 # v3: match training max_completion_length (was 2048)\n",
279
- "EVAL_TEMPERATURE = 0.1 # v3: deterministic eval for less noisy signal (was 0.7)\n",
280
- "\n",
281
- "# ── Backend ───────────────────────────────────────────────────────────────────\n",
282
- "USE_VLLM = False\n",
283
- "\n",
284
- "# ── v3: Zero-advantage noise injection ────────────────────────────────────────\n",
285
- "# Skywork-OR1 (2505.22312) §3.1: zero-std groups destabilize GRPO training\n",
286
- "# When all G completions get identical rewards, the advantage is undefined.\n",
287
- "# Noise injection breaks ties without corrupting the signal.\n",
288
- "ZERO_ADV_NOISE_STD = 0.005 # Small gaussian noise added to zero-variance groups\n",
289
- "\n",
290
- "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n",
291
- "\n",
292
- "# ── Version assertion ─────────────────────────────────────────────────────────\n",
293
- "import trl as _trl\n",
294
- "assert _trl.__version__ == \"0.24.0\", (\n",
295
- " f\"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\\n\"\n",
296
- " \"Verify that GRPOTrainer._generate() still exists before proceeding.\"\n",
297
- ")\n",
298
- "\n",
299
- "print(\"✓ v3 Config loaded\")\n",
300
- "print(f\" SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})\")\n",
301
- "print(f\" Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})\")\n",
302
- "print(f\" Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}\")\n",
303
- "print(f\" GRPO: G={NUM_GENERATIONS}, temp={TEMPERATURE}, LR={LEARNING_RATE}, β={BETA}\")\n",
304
- "print(f\" Completion: max={MAX_COMPLETION_LENGTH} (v2 was 2048)\")\n",
305
- "print(f\" ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}\")\n",
306
- "print(f\"✓ TRL {_trl.__version__} verified\")\n",
307
- "\n",
308
- "# ══════════════════════════════════════════════════════════════════════════════\n",
309
- "# v3 VRAM BUDGET (L4 24GB)\n",
310
- "# ══════════════════════════════════════════════════════════════════════════════\n",
311
- "# Model (NF4): ~3.5 GB\n",
312
- "# KV Cache (8192 seq): ~3.0 GB\n",
313
- "# Activations: ~4.0 GB\n",
314
- "# Optimizer states: ~3.0 GB\n",
315
- "# Generations (4×4096): ~8.0 GB\n",
316
- "# ─────────────────────────────────\n",
317
- "# Estimated total: ~21.5 GB\n",
318
- "# Headroom: ~2.5 GB\n",
319
- "#\n",
320
- "# If OOM: reduce MAX_COMPLETION_LENGTH to 3072 first, then 2560.\n",
321
- "# Do NOT reduce NUM_GENERATIONS below 4 — GRPO needs variance.\n",
322
- "# ══════════════════════════════════════════════════════════════════════════════"
323
- ]
324
- },
325
- {
326
- "cell_type": "markdown",
327
- "metadata": {},
328
- "source": [
329
- "---\n",
330
- "\n",
331
- "## Cell 4: Load SFT Adapter"
332
- ]
333
- },
334
- {
335
- "cell_type": "code",
336
- "execution_count": null,
337
- "metadata": {},
338
- "outputs": [],
339
- "source": [
340
- "print(\"Loading SFT adapter...\")\n",
341
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
342
- " model_name=str(SFT_ADAPTER_DIR),\n",
343
- " max_seq_length=MAX_SEQ_LENGTH,\n",
344
- " load_in_4bit=True,\n",
345
- " dtype=None,\n",
346
- ")\n",
347
- "\n",
348
- "if tokenizer.pad_token is None:\n",
349
- " tokenizer.pad_token = tokenizer.eos_token\n",
350
- "\n",
351
- "# Load chat template from base model (SFT adapter doesn't save it)\n",
352
- "from transformers import AutoTokenizer\n",
353
- "base_tok = AutoTokenizer.from_pretrained(MODEL_ID)\n",
354
- "tokenizer.chat_template = base_tok.chat_template\n",
355
- "del base_tok\n",
356
- "\n",
357
- "# v2: Force KV cache — Unsloth patching may reset this\n",
358
- "model.config.use_cache = True\n",
359
- "model.generation_config.use_cache = True\n",
360
- "\n",
361
- "print(f\"✓ Model loaded on {model.device}\")\n",
362
- "print(f\" use_cache: {model.config.use_cache}\")\n",
363
- "print(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n",
364
- "print(f\" Chat template: {tokenizer.chat_template[:50]}...\")"
365
- ]
366
- },
367
- {
368
- "cell_type": "markdown",
369
- "metadata": {},
370
- "source": [
371
- "---\n",
372
- "\n",
373
- "## Cell 5: Single Inference Test\n",
374
- "\n",
375
- "**Gate:** Does the model close `</think>` and produce an answer within 4096 tokens?"
376
- ]
377
- },
378
- {
379
- "cell_type": "code",
380
- "execution_count": null,
381
- "metadata": {},
382
- "outputs": [],
383
- "source": [
384
- "FastLanguageModel.for_inference(model)\n",
385
- "\n",
386
- "test_msgs = [\n",
387
- " {\"role\": \"system\", \"content\": SYSTEM_PT},\n",
388
- " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
389
- "]\n",
390
- "text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)\n",
391
- "inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
392
- "\n",
393
- "t0 = time.time()\n",
394
- "outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
395
- "elapsed = time.time() - t0\n",
396
- "\n",
397
- "response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
398
- "gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
399
- "\n",
400
- "print(f\"Generation time: {elapsed:.1f}s ({gen_tokens} tokens, {gen_tokens/elapsed:.1f} tok/s)\")\n",
401
- "print(f\"Response length: {len(response)} chars, {gen_tokens} tokens\")\n",
402
- "print(f\"Hit ceiling: {gen_tokens >= MAX_COMPLETION_LENGTH}\") # v3: should NOT hit ceiling with 4096\n",
403
- "print(f\"closed_think: {'</think>' in response}\")\n",
404
- "print(f\"\\n{'='*60}\")\n",
405
- "print(response[:800])"
406
- ]
407
- },
408
- {
409
- "cell_type": "markdown",
410
- "metadata": {},
411
- "source": [
412
- "---\n",
413
- "\n",
414
- "## Cell 5b: KV Cache Diagnostic"
415
- ]
416
- },
417
- {
418
- "cell_type": "code",
419
- "execution_count": null,
420
- "metadata": {},
421
- "outputs": [],
422
- "source": [
423
- "import time\n",
424
- "FastLanguageModel.for_inference(model)\n",
425
- "\n",
426
- "_kv_msgs = [{\"role\": \"system\", \"content\": SYSTEM_PT},\n",
427
- " {\"role\": \"user\", \"content\": \"Qual a categoria de reclamação mais frequente?\"}]\n",
428
- "_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n",
429
- "_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n",
430
- "\n",
431
- "_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\n",
432
- "with torch.no_grad():\n",
433
- " for _step in range(50):\n",
434
- " _t0 = time.time()\n",
435
- " seq_len = _generated.shape[1]\n",
436
- " if _past is None:\n",
437
- " _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n",
438
- " else:\n",
439
- " _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n",
440
- " _out = model(\n",
441
- " input_ids=_generated[:, -1:] if _past else _generated,\n",
442
- " position_ids=_position_ids,\n",
443
- " attention_mask=torch.ones(1, seq_len, device=model.device),\n",
444
- " past_key_values=_past,\n",
445
- " use_cache=True,\n",
446
- " return_dict=True,\n",
447
- " )\n",
448
- " _past = _out.past_key_values\n",
449
- " _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n",
450
- " _generated = torch.cat([_generated, _next], dim=1)\n",
451
- " _token_times.append(time.time() - _t0)\n",
452
- "\n",
453
- "_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\n",
454
- "print(f\"First 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\n",
455
- "print(f\"Last 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\n",
456
- "print(f\"Ratio last/first: {_ratio:.1f}x\")\n",
457
- "if _ratio < 3:\n",
458
- " print(\"✓ KV cache is working correctly\")\n",
459
- "elif _ratio < 6:\n",
460
- " print(\"⚠ KV cache may be degraded — check model.config.use_cache\")\n",
461
- "else:\n",
462
- " print(\"✗ KV cache BROKEN — GRPO generation will be catastrophically slow.\")\n",
463
- "\n",
464
- "del _past, _generated, _kv_inputs, _token_times, _out\n",
465
- "gc.collect()\n",
466
- "if torch.cuda.is_available(): torch.cuda.empty_cache()"
467
- ]
468
- },
469
- {
470
- "cell_type": "markdown",
471
- "metadata": {},
472
- "source": [
473
- "---\n",
474
- "\n",
475
- "## Cell 6: Reward Functions v3\n",
476
- "\n",
477
- "**v3 changes:**\n",
478
- "- Staged reward design: format → partial content → full task (Reasoning-SQL, 2503.23157)\n",
479
- "- Zero-advantage noise injection (Skywork-OR1, 2505.22312)\n",
480
- "- Extraction reward redesigned for completion-length-friendly scoring"
481
- ]
482
- },
483
- {
484
- "cell_type": "code",
485
- "execution_count": null,
486
- "metadata": {},
487
- "outputs": [],
488
- "source": [
489
- "def strip_think(text: str) -> str:\n",
490
- " \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n",
491
- " return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n",
492
- "\n",
493
- "\n",
494
- "def has_think_block(text: str) -> bool:\n",
495
- " \"\"\"Check if text contains a non-empty <think> block.\"\"\"\n",
496
- " return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n",
497
- "\n",
498
- "\n",
499
- "def _classify_task_type(prompt_text: str) -> str:\n",
500
- " \"\"\"Classify prompt into task type by keywords.\"\"\"\n",
501
- " p = prompt_text.lower()\n",
502
- " if \"retorne um objeto json\" in p or \"extraia dados\" in p:\n",
503
- " return \"extraction\"\n",
504
- " elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
505
- " return \"push\"\n",
506
- " elif \"perfil do cliente\" in p:\n",
507
- " return \"insights\"\n",
508
- " else:\n",
509
- " return \"sql_qa\"\n",
510
- "\n",
511
- "\n",
512
- "def _json_similarity(text: str) -> float:\n",
513
- " \"\"\"Rough heuristic: how JSON-like is this text? 0.0 to 1.0.\"\"\"\n",
514
- " text = text.strip()\n",
515
- " if not text:\n",
516
- " return 0.0\n",
517
- " score = 0.0\n",
518
- " if text.startswith(\"{\") and text.endswith(\"}\"):\n",
519
- " score += 0.5\n",
520
- " if '\"' in text:\n",
521
- " score += 0.2\n",
522
- " if \":\" in text:\n",
523
- " score += 0.2\n",
524
- " if \",\" in text:\n",
525
- " score += 0.1\n",
526
- " return min(score, 1.0)\n",
527
- "\n",
528
- "\n",
529
- "def _string_similarity(a: str, b: str) -> float:\n",
530
- " \"\"\"Simple Jaccard-like similarity for short strings. 0.0 to 1.0.\"\"\"\n",
531
- " if not a or not b:\n",
532
- " return 0.0\n",
533
- " a_set = set(a.split())\n",
534
- " b_set = set(b.split())\n",
535
- " intersection = len(a_set & b_set)\n",
536
- " union = len(a_set | b_set)\n",
537
- " return intersection / union if union > 0 else 0.0\n",
538
- "\n",
539
- "\n",
540
- "# ══════════════════════════════════════════════════════════════════════════════\n",
541
- "# v3 STAGED REWARD DESIGN\n",
542
- "# Reference: Reasoning-SQL (2503.23157) §3.2\n",
543
- "#\n",
544
- "# Each reward function scores THREE stages independently:\n",
545
- "# Stage 1 — FORMAT (0.0–0.2): Is the output well-structured?\n",
546
- "# Stage 2 — PARTIAL (0.0–0.3): Are some content elements correct?\n",
547
- "# Stage 3 — TASK (0.0–0.5): Is the full task completed correctly?\n",
548
- "#\n",
549
- "# Format rewards converge first (easy to learn), which stabilizes training\n",
550
- "# and enables the model to then learn harder task-specific skills.\n",
551
- "# ══════════════════════════════════════════════════════════════════════════════\n",
552
- "\n",
553
- "\n",
554
- "def reward_extraction(completion: str) -> float:\n",
555
- " \"\"\"Staged reward for structured extraction (max 1.0).\"\"\"\n",
556
- " answer = strip_think(completion)\n",
557
- "\n",
558
- " # ── Stage 1: FORMAT (max 0.2) ─────────────────────────────────────────────\n",
559
- " r_format = 0.0\n",
560
- " if has_think_block(completion):\n",
561
- " r_format += 0.1 # Used reasoning\n",
562
- "\n",
563
- " try:\n",
564
- " data = json.loads(answer)\n",
565
- " if isinstance(data, dict):\n",
566
- " r_format += 0.1 # Valid JSON object\n",
567
- " except (json.JSONDecodeError, TypeError):\n",
568
- " r_format += 0.05 * _json_similarity(answer)\n",
569
- " return min(r_format, 0.2)\n",
570
- "\n",
571
- " if not isinstance(data, dict):\n",
572
- " return min(r_format, 0.2)\n",
573
- "\n",
574
- " # ── Stage 2: PARTIAL CONTENT (max 0.3) ────────────────────────────────────\n",
575
- " r_partial = 0.0\n",
576
- "\n",
577
- " present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n",
578
- " r_partial += 0.15 * (present / len(EXTRACTION_FIELDS))\n",
579
- "\n",
580
- " type_checks = 0\n",
581
- " type_total = 0\n",
582
- " for field in EXTRACTION_FIELDS:\n",
583
- " if field not in data:\n",
584
- " continue\n",
585
- " type_total += 1\n",
586
- " val = data[field]\n",
587
- " if field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
588
- " if isinstance(val, bool):\n",
589
- " type_checks += 1\n",
590
- " elif field in (\"sentiment_score\",):\n",
591
- " if isinstance(val, (int, float)):\n",
592
- " type_checks += 1\n",
593
- " elif field in (\"main_complaint\", \"sentiment\", \"complaint_category\", \"churn_risk\", \"repeat_intent\"):\n",
594
- " if isinstance(val, str):\n",
595
- " type_checks += 1\n",
596
- " if type_total > 0:\n",
597
- " r_partial += 0.15 * (type_checks / type_total)\n",
598
- "\n",
599
- " # ── Stage 3: FULL TASK (max 0.5) ─────────────────────────────────────────\n",
600
- " r_task = 0.0\n",
601
- " cat_checks = 0\n",
602
- " cat_total = 0\n",
603
- "\n",
604
- " checks = [\n",
605
- " (\"sentiment\", lambda v: v in VALID_SENTIMENTS),\n",
606
- " (\"complaint_category\", lambda v: v in VALID_CATEGORIES),\n",
607
- " (\"churn_risk\", lambda v: v in VALID_CHURN),\n",
608
- " (\"repeat_intent\", lambda v: v in VALID_REPEAT),\n",
609
- " (\"sentiment_score\", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),\n",
610
- " ]\n",
611
- " for field, validator in checks:\n",
612
- " cat_total += 1\n",
613
- " if field in data and validator(data[field]):\n",
614
- " cat_checks += 1\n",
615
- "\n",
616
- " for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
617
- " cat_total += 1\n",
618
- " if bool_field in data and isinstance(data[bool_field], bool):\n",
619
- " cat_checks += 1\n",
620
- "\n",
621
- " if cat_total > 0:\n",
622
- " r_task += 0.35 * (cat_checks / cat_total)\n",
623
- "\n",
624
- " if \"main_complaint\" in data and isinstance(data[\"main_complaint\"], str):\n",
625
- " complaint = data[\"main_complaint\"].strip()\n",
626
- " if len(complaint) > 10:\n",
627
- " r_task += 0.15\n",
628
- "\n",
629
- " return min(r_format + r_partial + r_task, 1.0)\n",
630
- "\n",
631
- "\n",
632
- "def reward_sql_qa(completion: str) -> float:\n",
633
- " \"\"\"Staged reward for SQL Q&A (max 1.0).\"\"\"\n",
634
- " answer = strip_think(completion)\n",
635
- "\n",
636
- " # ── Stage 1: FORMAT (max 0.2)\n",
637
- " r_format = 0.0\n",
638
- " if has_think_block(completion):\n",
639
- " r_format += 0.1\n",
640
- " if \"```\" in answer or re.search(r\"SELECT|FROM\", answer, re.IGNORECASE):\n",
641
- " r_format += 0.1\n",
642
- "\n",
643
- " # ── Stage 2: PARTIAL (max 0.3)\n",
644
- " r_partial = 0.0\n",
645
- " sql_keywords = r\"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN|HAVING\"\n",
646
- " matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))\n",
647
- " r_partial += min(0.15, 0.03 * matches)\n",
648
- " numbers = re.findall(r\"\\d+(?:[.,]\\d+)?\", answer)\n",
649
- " r_partial += min(0.15, 0.03 * len(numbers))\n",
650
- "\n",
651
- " # ── Stage 3: TASK (max 0.5)\n",
652
- " r_task = 0.0\n",
653
- " length = len(answer)\n",
654
- " if 50 <= length <= 600:\n",
655
- " r_task += 0.25\n",
656
- " elif length > 0:\n",
657
- " r_task += 0.25 * max(0, 1 - abs(length - 325) / 275)\n",
658
- " explanation_markers = [\"para \", \"porque\", \"resultado\", \"mostra\", \"indica\", \"análise\"]\n",
659
- " expl_matches = sum(1 for w in explanation_markers if w in answer.lower())\n",
660
- " r_task += min(0.25, 0.05 * expl_matches)\n",
661
- "\n",
662
- " return min(r_format + r_partial + r_task, 1.0)\n",
663
- "\n",
664
- "\n",
665
- "def reward_insights(completion: str) -> float:\n",
666
- " \"\"\"Staged reward for insights (max 1.0).\"\"\"\n",
667
- " answer = strip_think(completion)\n",
668
- "\n",
669
- " # ── Stage 1: FORMAT (max 0.2)\n",
670
- " r_format = 0.0\n",
671
- " if has_think_block(completion):\n",
672
- " r_format += 0.1\n",
673
- " structure_marks = len(re.findall(r\"^[-•*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n",
674
- " r_format += min(0.1, 0.02 * structure_marks)\n",
675
- "\n",
676
- " # ── Stage 2: PARTIAL (max 0.3)\n",
677
- " r_partial = 0.0\n",
678
- " length = len(answer)\n",
679
- " if 100 <= length <= 1200:\n",
680
- " r_partial += 0.15\n",
681
- " elif length > 0:\n",
682
- " r_partial += 0.15 * max(0, 1 - abs(length - 650) / 550)\n",
683
- " pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua|cliente|produto\", answer, re.IGNORECASE)\n",
684
- " r_partial += min(0.15, 0.01 * len(pt_markers))\n",
685
- "\n",
686
- " # ── Stage 3: TASK (max 0.5)\n",
687
- " r_task = 0.0\n",
688
- " action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n",
689
- " \"priorizar\", \"investir\", \"otimizar\", \"estratégi\", \"suger\",\n",
690
- " \"consider\", \"ação\", \"plano\"]\n",
691
- " matches = sum(1 for w in action_words if w in answer.lower())\n",
692
- " r_task += min(0.3, 0.06 * matches)\n",
693
- " data_refs = len(re.findall(r\"\\d+%|R\\$\\s*\\d|média|percentual|comparad|taxa\", answer, re.IGNORECASE))\n",
694
- " r_task += min(0.2, 0.04 * data_refs)\n",
695
- "\n",
696
- " return min(r_format + r_partial + r_task, 1.0)\n",
697
- "\n",
698
- "\n",
699
- "def reward_push(completion: str) -> float:\n",
700
- " \"\"\"Staged reward for push notifications (max 1.0).\"\"\"\n",
701
- " answer = strip_think(completion)\n",
702
- " if not answer:\n",
703
- " return 0.0\n",
704
- "\n",
705
- " # ── Stage 1: FORMAT (max 0.2)\n",
706
- " r_format = 0.0\n",
707
- " if has_think_block(completion):\n",
708
- " r_format += 0.05\n",
709
- " length = len(answer)\n",
710
- " if length <= 160:\n",
711
- " r_format += 0.15\n",
712
- " elif length <= 300:\n",
713
- " r_format += 0.1\n",
714
- " else:\n",
715
- " r_format += 0.05\n",
716
- "\n",
717
- " # ── Stage 2: PARTIAL (max 0.3)\n",
718
- " r_partial = 0.0\n",
719
- " pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua\", answer, re.IGNORECASE)\n",
720
- " r_partial += min(0.15, 0.02 * len(pt_markers))\n",
721
- " if re.search(r\"[!?]|[\\U0001F600-\\U0001F64F]|[\\U0001F300-\\U0001F5FF]\", answer):\n",
722
- " r_partial += 0.05\n",
723
- " if len(answer.split()) >= 5:\n",
724
- " r_partial += 0.1\n",
725
- "\n",
726
- " # ── Stage 3: TASK (max 0.5)\n",
727
- " r_task = 0.0\n",
728
- " if length <= 120:\n",
729
- " r_task += 0.25\n",
730
- " else:\n",
731
- " r_task += 0.25 * max(0, 1 - (length - 120) / 120)\n",
732
- " generic_phrases = [\n",
733
- " \"olá! como podemos ajudar\", \"obrigado pela sua compra\",\n",
734
- " \"seu pedido foi confirmado\", \"agradecemos sua preferência\",\n",
735
- " ]\n",
736
- " max_similarity = max(_string_similarity(answer.lower(), g) for g in generic_phrases)\n",
737
- " r_task += 0.25 * (1 - max_similarity)\n",
738
- "\n",
739
- " return min(r_format + r_partial + r_task, 1.0)\n",
740
- "\n",
741
- "\n",
742
- "def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:\n",
743
- " \"\"\"\n",
744
- " Master reward function v3: dispatches by task type + zero-advantage noise.\n",
745
- " \"\"\"\n",
746
- " rewards = []\n",
747
- " for completion, prompt in zip(completions, prompts):\n",
748
- " if isinstance(completion, list):\n",
749
- " comp_text = completion[-1][\"content\"] if completion else \"\"\n",
750
- " else:\n",
751
- " comp_text = str(completion)\n",
752
- "\n",
753
- " if isinstance(prompt, list):\n",
754
- " prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
755
- " else:\n",
756
- " prompt_text = str(prompt)\n",
757
- "\n",
758
- " task = _classify_task_type(prompt_text)\n",
759
- "\n",
760
- " if task == \"extraction\":\n",
761
- " rewards.append(reward_extraction(comp_text))\n",
762
- " elif task == \"sql_qa\":\n",
763
- " rewards.append(reward_sql_qa(comp_text))\n",
764
- " elif task == \"insights\":\n",
765
- " rewards.append(reward_insights(comp_text))\n",
766
- " elif task == \"push\":\n",
767
- " rewards.append(reward_push(comp_text))\n",
768
- " else:\n",
769
- " r = 0.15 if has_think_block(comp_text) else 0.0\n",
770
- " r += 0.2 if comp_text.strip() else 0.0\n",
771
- " rewards.append(r)\n",
772
- "\n",
773
- " # ── v3: Zero-advantage noise injection ────────────────────────────────────\n",
774
- " if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:\n",
775
- " for i in range(0, len(rewards), NUM_GENERATIONS):\n",
776
- " group = rewards[i:i+NUM_GENERATIONS]\n",
777
- " if len(group) < 2:\n",
778
- " continue\n",
779
- " if max(group) - min(group) < 0.001:\n",
780
- " for j in range(i, min(i+NUM_GENERATIONS, len(rewards))):\n",
781
- " rewards[j] += random.gauss(0, ZERO_ADV_NOISE_STD)\n",
782
- "\n",
783
- " return rewards\n",
784
- "\n",
785
- "\n",
786
- "print(\"✓ v3 Reward functions defined (staged: format → partial → task)\")"
787
- ]
788
- },
789
- {
790
- "cell_type": "markdown",
791
- "metadata": {},
792
- "source": [
793
- "---\n",
794
- "\n",
795
- "## Cell 7: Reward Calibration\n",
796
- "\n",
797
- "**Gate:** Verify reward variance > 0. Compare v3 scoring to v2 calibration (mean=0.38)."
798
- ]
799
- },
800
- {
801
- "cell_type": "code",
802
- "execution_count": null,
803
- "metadata": {},
804
- "outputs": [],
805
- "source": [
806
- "train_path = DATA_DIR / \"pairs\" / \"train.jsonl\"\n",
807
- "\n",
808
- "by_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n",
809
- "with open(train_path) as f:\n",
810
- " for line in f:\n",
811
- " row = json.loads(line)\n",
812
- " convs = row[\"conversations\"]\n",
813
- " prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
814
- " if not prompt_msgs:\n",
815
- " continue\n",
816
- " user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n",
817
- " task = _classify_task_type(user_text)\n",
818
- " by_type[task].append(prompt_msgs)\n",
819
- "\n",
820
- "print(f\"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}\")\n",
821
- "\n",
822
- "rng = random.Random(42)\n",
823
- "cal_samples = []\n",
824
- "for task_type in [\"extraction\", \"extraction\", \"sql_qa\", \"sql_qa\", \"insights\", \"insights\", \"push\", \"push\"]:\n",
825
- " cal_samples.append(rng.choice(by_type[task_type]))\n",
826
- "\n",
827
- "FastLanguageModel.for_inference(model)\n",
828
- "print(f\"\\nReward calibration v3 ({len(cal_samples)} samples):\")\n",
829
- "print(\"-\" * 70)\n",
830
- "\n",
831
- "cal_rewards = []\n",
832
- "for i, msgs in enumerate(cal_samples):\n",
833
- " text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
834
- " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
835
- " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
836
- " response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
837
- " gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
838
- "\n",
839
- " r = commerce_reward_fn([response], [text])[0]\n",
840
- " cal_rewards.append(r)\n",
841
- " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
842
- " has_answer = \"</think>\" in response\n",
843
- " answer_preview = strip_think(response)[:100] if has_answer else \"[stuck in <think>]\"\n",
844
- " task = _classify_task_type(text)\n",
845
- " print(f\" [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'⚠️ HIT' if hit_ceiling else 'ok':6s} | {answer_preview}\")\n",
846
- "\n",
847
- "print(f\"\\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}\")\n",
848
- "print(f\"v2 calibration was: Mean=0.38, Min=0.02, Max=0.70\")\n",
849
- "print(f\"Variance > 0: {len(set(cal_rewards)) > 1}\")"
850
- ]
851
- },
852
- {
853
- "cell_type": "markdown",
854
- "metadata": {},
855
- "source": [
856
- "---\n",
857
- "\n",
858
- "## Cell 8: Dataset Preparation v3"
859
- ]
860
- },
861
- {
862
- "cell_type": "code",
863
- "execution_count": null,
864
- "metadata": {},
865
- "outputs": [],
866
- "source": [
867
- "from datasets import Dataset\n",
868
- "\n",
869
- "def prepare_grpo_datasets_v3(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO,\n",
870
- " general_mix=GENERAL_MIX_RATIO, seed=42):\n",
871
- " rng = random.Random(seed)\n",
872
- "\n",
873
- " train_pools = {}\n",
874
- " eval_records = []\n",
875
- " for task, pool in by_type.items():\n",
876
- " shuffled = pool.copy()\n",
877
- " rng.shuffle(shuffled)\n",
878
- " n_eval = max(1, int(len(shuffled) * eval_ratio))\n",
879
- " eval_records.extend(shuffled[:n_eval])\n",
880
- " train_pools[task] = shuffled[n_eval:]\n",
881
- "\n",
882
- " if n_prompts is None:\n",
883
- " train_records = []\n",
884
- " for task, pool in train_pools.items():\n",
885
- " train_records.extend(pool)\n",
886
- " rng.shuffle(train_records)\n",
887
- " else:\n",
888
- " targets = {\n",
889
- " \"extraction\": int(n_prompts * 0.4),\n",
890
- " \"sql_qa\": int(n_prompts * 0.4),\n",
891
- " \"insights\": int(n_prompts * 0.1),\n",
892
- " \"push\": int(n_prompts * 0.1),\n",
893
- " }\n",
894
- " train_records = []\n",
895
- " for task, target_n in targets.items():\n",
896
- " pool = train_pools[task]\n",
897
- " n = min(target_n, len(pool))\n",
898
- " train_records.extend(rng.sample(pool, n))\n",
899
- " rng.shuffle(train_records)\n",
900
- "\n",
901
- " general_path = DATA_DIR / \"pairs\" / \"general_reasoning.jsonl\"\n",
902
- " if general_mix > 0 and general_path.exists():\n",
903
- " general_records = []\n",
904
- " with open(general_path) as f:\n",
905
- " for line in f:\n",
906
- " row = json.loads(line)\n",
907
- " convs = row[\"conversations\"]\n",
908
- " prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
909
- " if prompt_msgs:\n",
910
- " general_records.append(prompt_msgs)\n",
911
- " n_general = int(len(train_records) * general_mix / (1 - general_mix))\n",
912
- " n_general = min(n_general, len(general_records))\n",
913
- " if n_general > 0:\n",
914
- " train_records.extend(rng.sample(general_records, n_general))\n",
915
- " rng.shuffle(train_records)\n",
916
- " print(f\" Cocktail Effect: added {n_general} general reasoning samples ({general_mix:.0%} mix)\")\n",
917
- " elif general_mix > 0:\n",
918
- " print(f\" ⚠️ general_reasoning.jsonl not found — skipping mix\")\n",
919
- "\n",
920
- " task_dist = {}\n",
921
- " for record in train_records:\n",
922
- " user_text = \" \".join(m[\"content\"] for m in record if m[\"role\"] == \"user\")\n",
923
- " task = _classify_task_type(user_text)\n",
924
- " task_dist[task] = task_dist.get(task, 0) + 1\n",
925
- "\n",
926
- " n_domain = len(train_records)\n",
927
- " steps_per_epoch = n_domain * NUM_GENERATIONS // (BATCH_SIZE * GRAD_ACCUM)\n",
928
- "\n",
929
- " print(f\"v3 Dataset split (eval_ratio={eval_ratio}):\")\n",
930
- " print(f\" train : {n_domain} prompts\")\n",
931
- " print(f\" eval : {len(eval_records)} prompts\")\n",
932
- " print(f\" distribution: {', '.join(f'{k}={v}' for k, v in sorted(task_dist.items()))}\")\n",
933
- " print(f\" steps/epoch: {n_domain} × {NUM_GENERATIONS} / ({BATCH_SIZE} × {GRAD_ACCUM}) = {steps_per_epoch}\")\n",
934
- " print(f\" MAX_STEPS={MAX_STEPS} → {'< 1 epoch' if MAX_STEPS < steps_per_epoch else f'{MAX_STEPS/steps_per_epoch:.1f} epochs'}\")\n",
935
- "\n",
936
- " train_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in train_records])\n",
937
- " eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n",
938
- " return train_ds, eval_ds\n",
939
- "\n",
940
- "\n",
941
- "train_dataset, eval_dataset = prepare_grpo_datasets_v3()\n",
942
- "dataset = train_dataset\n",
943
- "print(f\"\\n✓ v3 Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}\")"
944
- ]
945
- },
946
- {
947
- "cell_type": "markdown",
948
- "metadata": {},
949
- "source": [
950
- "---\n",
951
- "\n",
952
- "## Cell 9: Smoke Test\n",
953
- "\n",
954
- "**Gate:** Runs 1 step without OOM at new completion length (4096)."
955
- ]
956
- },
957
- {
958
- "cell_type": "code",
959
- "execution_count": null,
960
- "metadata": {},
961
- "outputs": [],
962
- "source": [
963
- "from trl import GRPOConfig, GRPOTrainer\n",
964
- "\n",
965
- "FastLanguageModel.for_training(model)\n",
966
- "\n",
967
- "smoke_config = GRPOConfig(\n",
968
- " output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n",
969
- " num_generations=NUM_GENERATIONS,\n",
970
- " scale_rewards=SCALE_REWARDS,\n",
971
- " max_completion_length=MAX_COMPLETION_LENGTH,\n",
972
- " max_steps=1,\n",
973
- " num_train_epochs=1,\n",
974
- " temperature=TEMPERATURE,\n",
975
- " per_device_train_batch_size=BATCH_SIZE,\n",
976
- " gradient_accumulation_steps=1,\n",
977
- " learning_rate=LEARNING_RATE,\n",
978
- " fp16=False,\n",
979
- " bf16=True,\n",
980
- " logging_steps=1,\n",
981
- " save_steps=999,\n",
982
- " report_to=\"none\",\n",
983
- " max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
984
- " seed=42,\n",
985
- " remove_unused_columns=False,\n",
986
- ")\n",
987
- "\n",
988
- "smoke_trainer = GRPOTrainer(\n",
989
- " model=model,\n",
990
- " reward_funcs=commerce_reward_fn,\n",
991
- " args=smoke_config,\n",
992
- " train_dataset=dataset,\n",
993
- " tokenizer=tokenizer,\n",
994
- ")\n",
995
- "\n",
996
- "t0 = time.time()\n",
997
- "smoke_trainer.train()\n",
998
- "step_time = time.time() - t0\n",
999
- "\n",
1000
- "print(f\"\\n✓ Smoke test passed!\")\n",
1001
- "print(f\" Step time (grad_accum=1): {step_time:.0f}s\")\n",
1002
- "print(f\" Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s\")\n",
1003
- "print(f\" VRAM peak: {torch.cuda.max_memory_allocated()/1e9:.1f} GB / {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB\")\n",
1004
- "\n",
1005
- "vram_used = torch.cuda.max_memory_allocated() / 1e9\n",
1006
- "vram_total = torch.cuda.get_device_properties(0).total_mem / 1e9\n",
1007
- "if vram_used > vram_total * 0.95:\n",
1008
- " print(f\"\\n⚠️ VRAM at {vram_used/vram_total:.0%} — dangerously close to OOM\")\n",
1009
- " print(f\" Option 1: Reduce MAX_COMPLETION_LENGTH to 3072\")\n",
1010
- " print(f\" Option 2: Reduce BATCH_SIZE to 2 (increase GRAD_ACCUM to 2)\")\n",
1011
- "\n",
1012
- "del smoke_trainer\n",
1013
- "gc.collect(); torch.cuda.empty_cache()"
1014
- ]
1015
- },
1016
- {
1017
- "cell_type": "markdown",
1018
- "metadata": {},
1019
- "source": [
1020
- "---\n",
1021
- "\n",
1022
- "## Cell 10: Probe Run (3 steps)"
1023
- ]
1024
- },
1025
- {
1026
- "cell_type": "code",
1027
- "execution_count": null,
1028
- "metadata": {},
1029
- "outputs": [],
1030
- "source": [
1031
- "FastLanguageModel.for_training(model)\n",
1032
- "\n",
1033
- "probe_config = GRPOConfig(\n",
1034
- " output_dir=str(CHECKPOINT_DIR / \"probe\"),\n",
1035
- " num_generations=NUM_GENERATIONS,\n",
1036
- " scale_rewards=SCALE_REWARDS,\n",
1037
- " max_completion_length=MAX_COMPLETION_LENGTH,\n",
1038
- " max_steps=3,\n",
1039
- " temperature=TEMPERATURE,\n",
1040
- " num_train_epochs=NUM_EPOCHS,\n",
1041
- " per_device_train_batch_size=BATCH_SIZE,\n",
1042
- " gradient_accumulation_steps=GRAD_ACCUM,\n",
1043
- " learning_rate=LEARNING_RATE,\n",
1044
- " warmup_ratio=0.1,\n",
1045
- " lr_scheduler_type=\"cosine\",\n",
1046
- " fp16=False,\n",
1047
- " bf16=True,\n",
1048
- " logging_steps=1,\n",
1049
- " disable_tqdm=True,\n",
1050
- " logging_first_step=True,\n",
1051
- " save_steps=999,\n",
1052
- " report_to=\"none\",\n",
1053
- " max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
1054
- " seed=42,\n",
1055
- " remove_unused_columns=False,\n",
1056
- ")\n",
1057
- "\n",
1058
- "probe_trainer = GRPOTrainer(\n",
1059
- " model=model,\n",
1060
- " reward_funcs=commerce_reward_fn,\n",
1061
- " args=probe_config,\n",
1062
- " train_dataset=dataset,\n",
1063
- " tokenizer=tokenizer,\n",
1064
- ")\n",
1065
- "\n",
1066
- "t0 = time.time()\n",
1067
- "result = probe_trainer.train()\n",
1068
- "elapsed = time.time() - t0\n",
1069
- "\n",
1070
- "print(f\"\\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)\")\n",
1071
- "print(f\" Train loss: {result.training_loss:.6f}\")\n",
1072
- "print(f\" Estimated full run ({MAX_STEPS} steps): {elapsed/3 * MAX_STEPS / 3600:.1f}h\")\n",
1073
- "\n",
1074
- "if abs(result.training_loss) < 1e-6:\n",
1075
- " print(\" ⚠️ Loss is near-zero — reward variance may be insufficient\")\n",
1076
- "else:\n",
1077
- " print(\" ✓ Loss is non-zero — GRPO has gradient signal\")\n",
1078
- "\n",
1079
- "del probe_trainer\n",
1080
- "gc.collect(); torch.cuda.empty_cache()"
1081
- ]
1082
- },
1083
- {
1084
- "cell_type": "markdown",
1085
- "metadata": {},
1086
- "source": [
1087
- "---\n",
1088
- "\n",
1089
- "## Cell 11: Full Training Run v3"
1090
- ]
1091
- },
1092
- {
1093
- "cell_type": "code",
1094
- "execution_count": null,
1095
- "metadata": {},
1096
- "outputs": [],
1097
- "source": [
1098
- "import wandb\n",
1099
- "\n",
1100
- "_wandb_key = os.environ.get(\"WANDB_API_KEY\", \"\").strip()\n",
1101
- "if not _wandb_key:\n",
1102
- " raise EnvironmentError(\"WANDB_API_KEY is not set.\")\n",
1103
- "wandb.login(key=_wandb_key, relogin=True)\n",
1104
- "print(f\"✓ W&B authenticated\")"
1105
- ]
1106
- },
1107
- {
1108
- "cell_type": "code",
1109
- "execution_count": null,
1110
- "metadata": {},
1111
- "outputs": [],
1112
- "source": [
1113
- "import shutil\n",
1114
- "import torch\n",
1115
- "from transformers import TrainerCallback\n",
1116
- "from trl import GRPOConfig, GRPOTrainer\n",
1117
- "\n",
1118
- "wandb.init(\n",
1119
- " project=WANDB_PROJECT,\n",
1120
- " name=f\"grpo-v3-l4-{time.strftime('%Y%m%d-%H%M')}\",\n",
1121
- " config={\n",
1122
- " \"model_id\": MODEL_ID,\n",
1123
- " \"version\": \"v3\",\n",
1124
- " \"temperature\": TEMPERATURE,\n",
1125
- " \"max_completion_length\": MAX_COMPLETION_LENGTH,\n",
1126
- " \"num_generations\": NUM_GENERATIONS,\n",
1127
- " \"learning_rate\": LEARNING_RATE,\n",
1128
- " \"beta\": BETA,\n",
1129
- " \"batch_size\": BATCH_SIZE,\n",
1130
- " \"grad_accum\": GRAD_ACCUM,\n",
1131
- " \"max_steps\": MAX_STEPS,\n",
1132
- " \"scale_rewards\": SCALE_REWARDS,\n",
1133
- " \"save_steps\": SAVE_STEPS,\n",
1134
- " \"eval_steps\": EVAL_STEPS,\n",
1135
- " \"eval_max_samples\": EVAL_MAX_SAMPLES,\n",
1136
- " \"eval_max_tokens\": EVAL_MAX_TOKENS,\n",
1137
- " \"eval_temperature\": EVAL_TEMPERATURE,\n",
1138
- " \"patience\": EARLY_STOPPING_PATIENCE,\n",
1139
- " \"delta\": EARLY_STOPPING_DELTA,\n",
1140
- " \"train_prompts\": len(train_dataset),\n",
1141
- " \"eval_prompts\": len(eval_dataset),\n",
1142
- " \"zero_adv_noise_std\": ZERO_ADV_NOISE_STD,\n",
1143
- " \"general_mix_ratio\": GENERAL_MIX_RATIO,\n",
1144
- " \"_ref_temperature\": \"Skywork-OR1 (2505.22312)\",\n",
1145
- " \"_ref_completion_length\": \"Dr. GRPO (2503.20783)\",\n",
1146
- " \"_ref_staged_rewards\": \"Reasoning-SQL (2503.23157)\",\n",
1147
- " \"_ref_zero_adv\": \"Skywork-OR1 (2505.22312)\",\n",
1148
- " },\n",
1149
- ")\n",
1150
- "print(f\"✓ W&B run: {wandb.run.url}\")\n",
1151
- "\n",
1152
- "FRESH = True\n",
1153
- "resume_from = None\n",
1154
- "if FRESH and CHECKPOINT_DIR.exists():\n",
1155
- " print(\"FRESH: deleting old checkpoints...\")\n",
1156
- " shutil.rmtree(CHECKPOINT_DIR)\n",
1157
- "elif CHECKPOINT_DIR.exists():\n",
1158
- " checkpoints = sorted(\n",
1159
- " [d for d in CHECKPOINT_DIR.iterdir()\n",
1160
- " if d.is_dir() and d.name.startswith(\"checkpoint-\")],\n",
1161
- " key=lambda d: int(d.name.split(\"-\")[-1]),\n",
1162
- " )\n",
1163
- " if checkpoints:\n",
1164
- " resume_from = str(checkpoints[-1])\n",
1165
- " print(f\"Resuming from: {resume_from}\")\n",
1166
- "\n",
1167
- "\n",
1168
- "class UnslothGRPOTrainer(GRPOTrainer):\n",
1169
- " \"\"\"Wraps generation with Unsloth for_inference()/for_training().\"\"\"\n",
1170
- " def _generate(self, prompts, images):\n",
1171
- " FastLanguageModel.for_inference(self.model)\n",
1172
- " try:\n",
1173
- " result = super()._generate(prompts, images)\n",
1174
- " finally:\n",
1175
- " FastLanguageModel.for_training(self.model)\n",
1176
- " return result\n",
1177
- "\n",
1178
- "\n",
1179
- "class EvalRewardCallback(TrainerCallback):\n",
1180
- " \"\"\"v3: deterministic eval, per-task breakdown, patience=15.\"\"\"\n",
1181
- " def __init__(self, eval_records, reward_fn, patience=EARLY_STOPPING_PATIENCE,\n",
1182
- " delta=EARLY_STOPPING_DELTA):\n",
1183
- " self.eval_records = eval_records\n",
1184
- " self.reward_fn = reward_fn\n",
1185
- " self.patience = patience\n",
1186
- " self.delta = delta\n",
1187
- " self.best_reward = -float(\"inf\")\n",
1188
- " self.no_improve_count = 0\n",
1189
- "\n",
1190
- " def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n",
1191
- " if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n",
1192
- " return control\n",
1193
- " tokenizer = processing_class\n",
1194
- " if tokenizer is None:\n",
1195
- " print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n",
1196
- " return control\n",
1197
- "\n",
1198
- " mean_reward, task_rewards = self._run_eval(model, tokenizer, args)\n",
1199
- " improved = mean_reward > self.best_reward + self.delta\n",
1200
- " status = \"↑ improved\" if improved else f\"↔ no gain ({self.no_improve_count + 1}/{self.patience})\"\n",
1201
- "\n",
1202
- " log_dict = {\n",
1203
- " \"eval/mean_reward\": mean_reward,\n",
1204
- " \"eval/best_reward\": max(self.best_reward, mean_reward),\n",
1205
- " \"eval/no_improve_count\": self.no_improve_count,\n",
1206
- " }\n",
1207
- " for task, rewards in task_rewards.items():\n",
1208
- " if rewards:\n",
1209
- " log_dict[f\"eval/{task}_reward\"] = sum(rewards) / len(rewards)\n",
1210
- " wandb.log(log_dict, step=state.global_step)\n",
1211
- "\n",
1212
- " print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n",
1213
- " for task, rewards in task_rewards.items():\n",
1214
- " if rewards:\n",
1215
- " print(f\" {task}: {sum(rewards)/len(rewards):.3f} (n={len(rewards)})\")\n",
1216
- "\n",
1217
- " if improved:\n",
1218
- " self.best_reward = mean_reward\n",
1219
- " self.no_improve_count = 0\n",
1220
- " else:\n",
1221
- " self.no_improve_count += 1\n",
1222
- " if self.no_improve_count >= self.patience:\n",
1223
- " print(f\"[EarlyStopping] No improvement ≥ {self.delta} for {self.patience} consecutive evals. Halting.\")\n",
1224
- " wandb.log({\"early_stop/step\": state.global_step}, step=state.global_step)\n",
1225
- " control.should_training_stop = True\n",
1226
- " return control\n",
1227
- "\n",
1228
- " def _run_eval(self, model, tokenizer, args):\n",
1229
- " FastLanguageModel.for_inference(model)\n",
1230
- " rewards = []\n",
1231
- " task_rewards = {}\n",
1232
- " subset = self.eval_records[:EVAL_MAX_SAMPLES]\n",
1233
- " for record in subset:\n",
1234
- " msgs = record[\"prompt\"]\n",
1235
- " text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
1236
- " inputs = tokenizer(text, return_tensors=\"pt\", truncation=True,\n",
1237
- " max_length=args.max_prompt_length).to(model.device)\n",
1238
- " with torch.no_grad():\n",
1239
- " out = model.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS,\n",
1240
- " temperature=EVAL_TEMPERATURE, do_sample=True)\n",
1241
- " resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
1242
- " r = self.reward_fn([resp], [text])[0]\n",
1243
- " rewards.append(r)\n",
1244
- " user_text = \" \".join(m.get(\"content\", \"\") for m in msgs if m.get(\"role\") == \"user\")\n",
1245
- " task = _classify_task_type(user_text)\n",
1246
- " task_rewards.setdefault(task, []).append(r)\n",
1247
- " FastLanguageModel.for_training(model)\n",
1248
- " mean = sum(rewards) / len(rewards) if rewards else 0.0\n",
1249
- " return mean, task_rewards\n",
1250
- "\n",
1251
- "\n",
1252
- "class EntropyMonitorCallback(TrainerCallback):\n",
1253
- " \"\"\"v3 NEW: Monitor entropy collapse indicators (Skywork-OR1 §4).\"\"\"\n",
1254
- " def __init__(self):\n",
1255
- " self.consecutive_ceiling_hits = 0\n",
1256
- "\n",
1257
- " def on_log(self, args, state, control, logs=None, **kwargs):\n",
1258
- " if not logs:\n",
1259
- " return\n",
1260
- " step = state.global_step\n",
1261
- " monitor = {}\n",
1262
- " comp_len = logs.get(\"completion_length\", 0)\n",
1263
- " if comp_len > 0:\n",
1264
- " ratio = comp_len / MAX_COMPLETION_LENGTH\n",
1265
- " monitor[\"monitor/completion_ratio\"] = ratio\n",
1266
- " if ratio > 0.95:\n",
1267
- " self.consecutive_ceiling_hits += 1\n",
1268
- " if self.consecutive_ceiling_hits >= 3:\n",
1269
- " print(f\"⚠️ Step {step}: Completion ceiling hit {self.consecutive_ceiling_hits} consecutive times.\")\n",
1270
- " else:\n",
1271
- " self.consecutive_ceiling_hits = 0\n",
1272
- " reward_std = logs.get(\"reward_std\", logs.get(\"rewards/commerce_reward_fn/std\", 0))\n",
1273
- " if reward_std is not None:\n",
1274
- " monitor[\"monitor/reward_std\"] = reward_std\n",
1275
- " if reward_std < 0.01:\n",
1276
- " print(f\"⚠️ Step {step}: reward_std={reward_std:.4f} — near-zero variance\")\n",
1277
- " clip_high = logs.get(\"clip_ratio/high_mean\", 0)\n",
1278
- " clip_low = logs.get(\"clip_ratio/low_mean\", 0)\n",
1279
- " if clip_high is not None and clip_low is not None:\n",
1280
- " total_clip = clip_high + abs(clip_low)\n",
1281
- " monitor[\"monitor/total_clip_ratio\"] = total_clip\n",
1282
- " if total_clip > 0.01 and step > 10:\n",
1283
- " print(f\"✓ Step {step}: clip_ratio={total_clip:.3f} — policy is updating\")\n",
1284
- " if monitor and wandb.run:\n",
1285
- " wandb.log(monitor, step=step)\n",
1286
- "\n",
1287
- "\n",
1288
- "FastLanguageModel.for_training(model)\n",
1289
- "\n",
1290
- "grpo_config = GRPOConfig(\n",
1291
- " output_dir=str(CHECKPOINT_DIR),\n",
1292
- " num_generations=NUM_GENERATIONS,\n",
1293
- " scale_rewards=SCALE_REWARDS,\n",
1294
- " max_completion_length=MAX_COMPLETION_LENGTH,\n",
1295
- " temperature=TEMPERATURE,\n",
1296
- " max_steps=MAX_STEPS,\n",
1297
- " num_train_epochs=NUM_EPOCHS,\n",
1298
- " per_device_train_batch_size=BATCH_SIZE,\n",
1299
- " gradient_accumulation_steps=GRAD_ACCUM,\n",
1300
- " learning_rate=LEARNING_RATE,\n",
1301
- " warmup_ratio=0.1,\n",
1302
- " lr_scheduler_type=\"cosine\",\n",
1303
- " fp16=False,\n",
1304
- " bf16=True,\n",
1305
- " logging_steps=1,\n",
1306
- " logging_first_step=True,\n",
1307
- " disable_tqdm=True,\n",
1308
- " save_steps=SAVE_STEPS,\n",
1309
- " save_total_limit=SAVE_TOTAL_LIMIT,\n",
1310
- " save_only_model=True,\n",
1311
- " eval_steps=EVAL_STEPS,\n",
1312
- " report_to=\"wandb\",\n",
1313
- " max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,\n",
1314
- " seed=42,\n",
1315
- " remove_unused_columns=False,\n",
1316
- " **({\"use_vllm\": True, \"vllm_mode\": \"colocate\",\n",
1317
- " \"vllm_enable_sleep_mode\": True} if USE_VLLM else {}),\n",
1318
- ")\n",
1319
- "\n",
1320
- "eval_cb = EvalRewardCallback(eval_records=list(eval_dataset), reward_fn=commerce_reward_fn)\n",
1321
- "entropy_cb = EntropyMonitorCallback()\n",
1322
- "\n",
1323
- "TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer\n",
1324
- "trainer = TrainerClass(\n",
1325
- " model=model,\n",
1326
- " reward_funcs=commerce_reward_fn,\n",
1327
- " args=grpo_config,\n",
1328
- " train_dataset=train_dataset,\n",
1329
- " processing_class=tokenizer,\n",
1330
- " callbacks=[eval_cb, entropy_cb],\n",
1331
- ")\n",
1332
- "\n",
1333
- "print(f\"{'='*70}\")\n",
1334
- "print(f\"GRPO v3 Training — Ready to Launch\")\n",
1335
- "print(f\"{'='*70}\")\n",
1336
- "print(f\" Trainer: {TrainerClass.__name__}\")\n",
1337
- "print(f\" Max steps: {MAX_STEPS}\")\n",
1338
- "print(f\" Temperature: {TEMPERATURE} (v2 was 0.8)\")\n",
1339
- "print(f\" Completion: {MAX_COMPLETION_LENGTH} tokens (v2 was 2048)\")\n",
1340
- "print(f\" Generations: {NUM_GENERATIONS} per prompt (v2 was 8)\")\n",
1341
- "print(f\" Learning rate: {LEARNING_RATE} (v2 was 5e-7)\")\n",
1342
- "print(f\" Save every: {SAVE_STEPS} steps (keep {SAVE_TOTAL_LIMIT})\")\n",
1343
- "print(f\" Eval every: {EVAL_STEPS} steps ({EVAL_MAX_SAMPLES} samples × {EVAL_MAX_TOKENS} tok)\")\n",
1344
- "print(f\" Patience: {EARLY_STOPPING_PATIENCE} evals ({EARLY_STOPPING_PATIENCE * EVAL_STEPS} steps)\")\n",
1345
- "print(f\" Resume: {resume_from is not None}\")\n",
1346
- "print(f\"{'='*70}\")\n",
1347
- "\n",
1348
- "t_start = time.time()\n",
1349
- "result = trainer.train(resume_from_checkpoint=resume_from)\n",
1350
- "elapsed = time.time() - t_start\n",
1351
- "\n",
1352
- "wandb.log({\n",
1353
- " \"train/final_loss\": result.training_loss,\n",
1354
- " \"train/duration_hours\": elapsed / 3600,\n",
1355
- " \"train/total_steps\": result.global_step,\n",
1356
- " \"eval/best_reward_final\": eval_cb.best_reward,\n",
1357
- "})\n",
1358
- "wandb.finish()\n",
1359
- "\n",
1360
- "print(f\"\\n{'='*70}\")\n",
1361
- "print(f\"GRPO v3 Training Complete\")\n",
1362
- "print(f\" Loss: {result.training_loss:.6f}\")\n",
1363
- "print(f\" Steps: {result.global_step}\")\n",
1364
- "print(f\" Duration: {elapsed/3600:.1f}h\")\n",
1365
- "print(f\" Best eval R: {eval_cb.best_reward:.4f}\")\n",
1366
- "print(f\" Trainer: {TrainerClass.__name__}\")\n",
1367
- "print(f\"{'='*70}\")"
1368
- ]
1369
- },
1370
- {
1371
- "cell_type": "markdown",
1372
- "metadata": {},
1373
- "source": [
1374
- "---\n",
1375
- "\n",
1376
- "## Cell 12: Save Adapter"
1377
- ]
1378
- },
1379
- {
1380
- "cell_type": "code",
1381
- "execution_count": null,
1382
- "metadata": {},
1383
- "outputs": [],
1384
- "source": [
1385
- "GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\n",
1386
- "model.save_pretrained(str(GRPO_ADAPTER_DIR))\n",
1387
- "tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))\n",
1388
- "\n",
1389
- "summary = {\n",
1390
- " \"model_id\": MODEL_ID,\n",
1391
- " \"sft_adapter\": str(SFT_ADAPTER_DIR),\n",
1392
- " \"method\": \"GRPO\",\n",
1393
- " \"version\": \"v3\",\n",
1394
- " \"train_loss\": result.training_loss,\n",
1395
- " \"best_eval_reward\": eval_cb.best_reward,\n",
1396
- " \"num_prompts\": len(train_dataset),\n",
1397
- " \"num_generations\": NUM_GENERATIONS,\n",
1398
- " \"scale_rewards\": SCALE_REWARDS,\n",
1399
- " \"temperature\": TEMPERATURE,\n",
1400
- " \"learning_rate\": LEARNING_RATE,\n",
1401
- " \"beta\": BETA,\n",
1402
- " \"max_completion_length\": MAX_COMPLETION_LENGTH,\n",
1403
- " \"max_steps\": MAX_STEPS,\n",
1404
- " \"actual_steps\": result.global_step,\n",
1405
- " \"epochs\": NUM_EPOCHS,\n",
1406
- " \"max_seq_length\": MAX_SEQ_LENGTH,\n",
1407
- " \"duration_seconds\": round(elapsed),\n",
1408
- " \"gpu\": \"L4\",\n",
1409
- " \"platform\": \"vertex-ai-workbench\",\n",
1410
- " \"v3_fixes\": [\n",
1411
- " \"temperature=1.0 (Skywork-OR1)\",\n",
1412
- " \"max_completion_length=4096 (Dr. GRPO)\",\n",
1413
- " \"learning_rate=2e-6 (4x v2)\",\n",
1414
- " \"beta=0.0 (Dr. GRPO)\",\n",
1415
- " \"staged rewards (Reasoning-SQL)\",\n",
1416
- " \"zero-advantage noise (Skywork-OR1)\",\n",
1417
- " \"entropy monitoring callback\",\n",
1418
- " ],\n",
1419
- "}\n",
1420
- "with open(GRPO_ADAPTER_DIR / \"training_summary.json\", \"w\") as f:\n",
1421
- " json.dump(summary, f, indent=2)\n",
1422
- "\n",
1423
- "print(f\"✓ Adapter saved to {GRPO_ADAPTER_DIR}\")\n",
1424
- "print(f\" Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir() if f.is_file()]}\")"
1425
- ]
1426
- },
1427
- {
1428
- "cell_type": "markdown",
1429
- "metadata": {},
1430
- "source": [
1431
- "---\n",
1432
- "\n",
1433
- "## Cell 13: Validation"
1434
- ]
1435
- },
1436
- {
1437
- "cell_type": "code",
1438
- "execution_count": null,
1439
- "metadata": {},
1440
- "outputs": [],
1441
- "source": [
1442
- "FastLanguageModel.for_inference(model)\n",
1443
- "\n",
1444
- "system_msg = {\"role\": \"system\", \"content\": SYSTEM_PT}\n",
1445
- "\n",
1446
- "test_prompts = [\n",
1447
- " {\"role\": \"user\", \"content\": (\n",
1448
- " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
1449
- " \"nota=2/5 | status=delivered\\ntítulo: decepcionado\\n\"\n",
1450
- " \"texto: Produto veio com defeito e o vendedor não respondeu.\\n\\n\"\n",
1451
- " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
1452
- " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
1453
- " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
1454
- " )},\n",
1455
- " {\"role\": \"user\", \"content\": (\n",
1456
- " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
1457
- " \"nota=5/5 | status=delivered\\ntítulo: adorei!\\n\"\n",
1458
- " \"texto: Entrega rápida e produto exatamente como descrito. Recomendo!\\n\\n\"\n",
1459
- " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
1460
- " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
1461
- " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
1462
- " )},\n",
1463
- " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
1464
- " {\"role\": \"user\", \"content\": \"Analise a retenção de clientes afetados por product_quality.\"},\n",
1465
- " {\"role\": \"user\", \"content\": (\n",
1466
- " \"Perfil do cliente:\\n- Estado: MG\\n- Valor do pedido: R$150\\n\"\n",
1467
- " \"- Reclamação: produto com defeito\\n- Nota: 1.0/5\\n\\n\"\n",
1468
- " \"Este cliente deve receber uma notificação de reengajamento?\"\n",
1469
- " )},\n",
1470
- " {\"role\": \"user\", \"content\": \"Compare a satisfação de clientes em SP vs RJ.\"},\n",
1471
- " {\"role\": \"user\", \"content\": (\n",
1472
- " \"Crie uma notificação push de reengajamento para um cliente em SP \"\n",
1473
- " \"que reclamou de atraso na entrega. Nota: 2/5.\"\n",
1474
- " )},\n",
1475
- "]\n",
1476
- "\n",
1477
- "print(\"=\" * 70)\n",
1478
- "print(\"GRPO v3 Validation\")\n",
1479
- "print(\"=\" * 70)\n",
1480
- "\n",
1481
- "v3_rewards = []\n",
1482
- "for i, prompt in enumerate(test_prompts):\n",
1483
- " messages = [system_msg, prompt]\n",
1484
- " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
1485
- " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
1486
- "\n",
1487
- " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)\n",
1488
- " gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
1489
- " response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
1490
- "\n",
1491
- " reward = commerce_reward_fn([response], [text])[0]\n",
1492
- " v3_rewards.append(reward)\n",
1493
- " answer = strip_think(response)\n",
1494
- " task = _classify_task_type(prompt[\"content\"])\n",
1495
- " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
1496
- "\n",
1497
- " print(f\"\\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---\")\n",
1498
- " print(f\"Prompt: {prompt['content'][:80]}...\")\n",
1499
- " print(f\"Answer: {answer[:400]}\")\n",
1500
- "\n",
1501
- "print(f\"\\n{'='*70}\")\n",
1502
- "print(f\"v3 Validation Summary\")\n",
1503
- "print(f\"{'='*70}\")\n",
1504
- "print(f\" Mean reward: {sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
1505
- "print(f\" Min: {min(v3_rewards):.3f}\")\n",
1506
- "print(f\" Max: {max(v3_rewards):.3f}\")\n",
1507
- "print()\n",
1508
- "print(f\" Comparison to baselines:\")\n",
1509
- "print(f\" SFT calibration (Cell 7): mean=0.38\")\n",
1510
- "print(f\" GRPO v2 validation: mean=0.54\")\n",
1511
- "print(f\" GRPO v3 validation: mean={sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
1512
- "v3_vs_v2 = (sum(v3_rewards)/len(v3_rewards) - 0.54) / 0.54 * 100\n",
1513
- "print(f\" v3 vs v2: {v3_vs_v2:+.1f}%\")"
1514
- ]
1515
- }
1516
- ]
1517
- }
notebooks/grpo_vertex_v3.ipynb CHANGED
@@ -204,11 +204,78 @@
  " \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n",
  "]\n",
  "\n",
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
+ "# v3: TASK-AWARE SYSTEM PROMPTS\n",
+ "# ══════════════════════════════════════════════════════════════════════════════\n",
+ "# Research basis:\n",
+ "# - OptimalThinkingBench (2508.13141): \"Don't overthink\" → -23% tokens, +7.7pp accuracy on Qwen3\n",
+ "# - Mid-Think (2601.07036): task-specific thinking control in GRPO → +2.6pp AIME, -15% train time\n",
+ "# - L1 (2503.04697): token budgets in prompts work when trained with RL reward signal\n",
+ "# - User's proven extraction prompt: XML-tagged structure + few-shot + schema enforcement\n",
+ "\n",
+ "SYSTEM_EXTRACTION = (\n",
+ " \"Você é um motor de extração de dados de e-commerce brasileiro. \"\n",
+ " \"Retorne APENAS um objeto JSON válido, sem nenhum texto antes ou depois. \"\n",
+ " \"NÃO USE blocos de código markdown (```json). \"\n",
+ " \"O primeiro caractere da sua resposta deve ser { e o último deve ser }. \"\n",
+ " \"Campos não mencionados na avaliação devem ser null — nunca invente valores. \"\n",
+ " \"Sem explicação. Sem comentários. Não pense em excesso.\"\n",
+ ")\n",
+ "\n",
+ "SYSTEM_SQL = (\n",
+ " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
+ " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\\n\\n\"\n",
+ " \"Para consultas e análises de dados: pense brevemente sobre a estrutura necessária, \"\n",
+ " \"depois apresente a resposta de forma direta com números e dados concretos. \"\n",
+ " \"Seja conciso no raciocínio. Não pense em excesso.\"\n",
+ ")\n",
+ "\n",
+ "SYSTEM_INSIGHTS = (\n",
+ " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
+ " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\\n\\n\"\n",
+ " \"Para análises estratégicas: raciocine de forma estruturada e concisa, \"\n",
+ " \"focando nos pontos principais e recomendações acionáveis. \"\n",
+ " \"Use no máximo 500 tokens para raciocinar antes de responder.\"\n",
+ ")\n",
+ "\n",
+ "SYSTEM_PUSH = (\n",
+ " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
+ " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\\n\\n\"\n",
+ " \"Para notificações push: seja direto e criativo. \"\n",
+ " \"A notificação deve ter no máximo 120 caracteres. \"\n",
+ " \"Responda diretamente sem pensar em excesso.\"\n",
+ ")\n",
+ "\n",
+ "# Legacy fallback — used only in cells that don't have task context\n",
  "SYSTEM_PT = (\n",
  " \"Você é um assistente de IA especializado em análise de e-commerce brasileiro. \"\n",
  " \"Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\"\n",
  ")\n",
  "\n",
+ "def get_system_prompt(task_type: str) -> str:\n",
+ " \"\"\"Return task-optimized system prompt.\"\"\"\n",
+ " return {\n",
+ " \"extraction\": SYSTEM_EXTRACTION,\n",
+ " \"sql_qa\": SYSTEM_SQL,\n",
+ " \"insights\": SYSTEM_INSIGHTS,\n",
+ " \"push\": SYSTEM_PUSH,\n",
+ " }.get(task_type, SYSTEM_PT)\n",
+ "\n",
+ "# ── Think token budgets per task (for reward function) ────────────────────────\n",
+ "# These are soft targets — the reward function nudges, not enforces\n",
+ "THINK_BUDGETS = {\n",
+ " \"extraction\": 150, # Extraction barely needs thinking — pattern matching\n",
+ " \"push\": 100, # Push is creative writing, not reasoning\n",
+ " \"sql_qa\": 400, # SQL benefits from brief query planning\n",
+ " \"insights\": 800, # Insights need structured multi-step analysis\n",
+ "}\n",
+ "\n",
+ "print(\"✓ v3 Task-aware system prompts defined\")\n",
+ "print(f\" extraction: '{SYSTEM_EXTRACTION[:60]}...'\")\n",
+ "print(f\" sql_qa: '{SYSTEM_SQL[:60]}...'\")\n",
+ "print(f\" insights: '{SYSTEM_INSIGHTS[:60]}...'\")\n",
+ "print(f\" push: '{SYSTEM_PUSH[:60]}...'\")\n",
+ "\n",
  "# ══════════════════════════════════════════════════════════════════════════════\n",
  "# TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)\n",
  "# ══════════════════════════════════════════════════════════════════════════════\n",
@@ -351,6 +418,7 @@
  },
  {
  "cell_type": "markdown",
+ "id": "1187f9d3",
  "metadata": {},
  "source": [
  "---\n",
@@ -361,6 +429,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
+ "id": "4d77bfc1",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -393,6 +462,7 @@
  },
  {
  "cell_type": "markdown",
+ "id": "e0bcb82e",
  "metadata": {},
  "source": [
  "---\n",
@@ -405,6 +475,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
+ "id": "9baaaedb",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -444,6 +515,7 @@
  },
  {
  "cell_type": "markdown",
+ "id": "fe81d051",
  "metadata": {},
  "source": [
  "---\n",
@@ -454,6 +526,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
+ "id": "5161aca2",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -511,6 +584,7 @@
  },
  {
  "cell_type": "markdown",
+ "id": "29020870",
  "metadata": {},
  "source": [
  "---\n",
@@ -526,6 +600,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
+ "id": "f1ec57fb",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -782,6 +857,44 @@
  " return min(r_format + r_partial + r_task, 1.0)\n",
  "\n",
  "\n",
+ "def reward_think_efficiency(completion: str, task_type: str) -> float:\n",
+ " \"\"\"\n",
+ " Reward concise thinking, penalize bloated <think> blocks.\n",
+ " \n",
+ " v3 NEW — Research basis:\n",
+ " - OptimalThinkingBench (2508.13141): overthinking hurts accuracy on simple tasks\n",
+ " - L1 (2503.04697): token budget rewards teach models to control reasoning length\n",
+ " - Train Long Think Short (2508.08940): triangular length reward around target budget\n",
+ " \n",
+ " Returns: -0.05 to +0.1 (small component — nudge, not dominate)\n",
+ " \"\"\"\n",
+ " think_match = re.search(r\"<think>(.*?)</think>\", completion, re.DOTALL)\n",
+ " budget = THINK_BUDGETS.get(task_type, 500)\n",
+ " \n",
+ " if not think_match:\n",
+ " # No think block at all\n",
+ " if task_type in (\"extraction\", \"push\"):\n",
+ " return 0.1 # Great — these tasks don't need thinking\n",
+ " else:\n",
+ " return 0.0 # Neutral for analytical tasks\n",
+ " \n",
+ " think_content = think_match.group(1).strip()\n",
+ " think_chars = len(think_content) # chars as proxy (cheaper than tokenizing)\n",
+ " # Rough conversion: ~4 chars per token for Portuguese\n",
+ " think_tokens_approx = think_chars / 4\n",
+ " \n",
+ " if think_tokens_approx <= budget:\n",
+ " # Within budget — reward proportional to how concise\n",
+ " return 0.1\n",
+ " elif think_tokens_approx <= budget * 2:\n",
+ " # Over budget but not catastrophic — linear decay\n",
+ " overshoot = (think_tokens_approx - budget) / budget\n",
+ " return 0.1 * (1.0 - overshoot) # 0.1 → 0.0\n",
+ " else:\n",
+ " # Way over budget — mild penalty\n",
+ " return -0.05\n",
+ "\n",
+ "\n",
  "def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:\n",
  " \"\"\"\n",
  " Master reward function v3: dispatches by task type + zero-advantage noise.\n",
@@ -801,17 +914,20 @@
  " task = _classify_task_type(prompt_text)\n",
  "\n",
  " if task == \"extraction\":\n",
- " rewards.append(reward_extraction(comp_text))\n",
+ " task_r = reward_extraction(comp_text)\n",
  " elif task == \"sql_qa\":\n",
- " rewards.append(reward_sql_qa(comp_text))\n",
+ " task_r = reward_sql_qa(comp_text)\n",
  " elif task == \"insights\":\n",
- " rewards.append(reward_insights(comp_text))\n",
+ " task_r = reward_insights(comp_text)\n",
  " elif task == \"push\":\n",
- " rewards.append(reward_push(comp_text))\n",
+ " task_r = reward_push(comp_text)\n",
  " else:\n",
- " r = 0.15 if has_think_block(comp_text) else 0.0\n",
- " r += 0.2 if comp_text.strip() else 0.0\n",
- " rewards.append(r)\n",
+ " task_r = 0.15 if has_think_block(comp_text) else 0.0\n",
+ " task_r += 0.2 if comp_text.strip() else 0.0\n",
+ "\n",
+ " # v3: Think efficiency bonus/penalty (small weight — nudge, not dominate)\n",
+ " think_r = reward_think_efficiency(comp_text, task)\n",
+ " rewards.append(task_r + think_r)\n",
  "\n",
  " # ── v3: Zero-advantage noise injection ────────────────────────────────────\n",
  " if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:\n",
@@ -831,6 +947,7 @@
  },
  {
  "cell_type": "markdown",
+ "id": "6f3d27d3",
  "metadata": {},
  "source": [
  "---\n",
@@ -843,6 +960,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
+ "id": "e992af27",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -862,6 +980,21 @@
  "\n",
  "print(f\"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}\")\n",
  "\n",
+ "def inject_task_system_prompt(msgs, task_type):\n",
+ " \"\"\"Replace generic system prompt with task-specific one.\"\"\"\n",
+ " new_msgs = []\n",
+ " system_prompt = get_system_prompt(task_type)\n",
+ " has_system = False\n",
+ " for m in msgs:\n",
+ " if m[\"role\"] == \"system\":\n",
+ " new_msgs.append({\"role\": \"system\", \"content\": system_prompt})\n",
+ " has_system = True\n",
+ " else:\n",
+ " new_msgs.append(m)\n",
+ " if not has_system:\n",
+ " new_msgs.insert(0, {\"role\": \"system\", \"content\": system_prompt})\n",
+ " return new_msgs\n",
+ "\n",
  "rng = random.Random(42)\n",
  "cal_samples = []\n",
  "for task_type in [\"extraction\", \"extraction\", \"sql_qa\", \"sql_qa\", \"insights\", \"insights\", \"push\", \"push\"]:\n",
@@ -874,6 +1007,13 @@
  "cal_rewards = []\n",
  "cal_rows = [] # collect per-sample data for W&B Table\n",
  "for i, msgs in enumerate(cal_samples):\n",
+ " # Determine task type from user content\n",
+ " user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n",
+ " task = _classify_task_type(user_text)\n",
+ " \n",
+ " # v3: Inject task-aware system prompt\n",
+ " msgs = inject_task_system_prompt(msgs, task)\n",
+ " \n",
  " text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
  " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
  " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)\n",
@@ -885,7 +1025,6 @@
  " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
  " has_answer = \"</think>\" in response\n",
  " answer_preview = strip_think(response)[:100] if has_answer else \"[stuck in <think>]\"\n",
- " task = _classify_task_type(text)\n",
  " print(f\" [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'HIT' if hit_ceiling else 'ok':6s} | {answer_preview}\")\n",
  "\n",
  " cal_rows.append([i, task, r, gen_tokens, hit_ceiling, has_answer, answer_preview])\n",
@@ -977,6 +1116,20 @@
  " elif general_mix > 0:\n",
  " print(f\" general_reasoning.jsonl not found — skipping mix\")\n",
  "\n",
+ " # v3: Inject task-aware system prompts into each training record\n",
+ " for i, record in enumerate(train_records):\n",
+ " user_text = \" \".join(m[\"content\"] for m in record if m[\"role\"] == \"user\")\n",
+ " task = _classify_task_type(user_text)\n",
+ " train_records[i] = inject_task_system_prompt(record, task)\n",
+ " \n",
+ " # Same for eval records\n",
+ " for i, record in enumerate(eval_records):\n",
+ " user_text = \" \".join(m[\"content\"] for m in record if m[\"role\"] == \"user\")\n",
+ " task = _classify_task_type(user_text)\n",
+ " eval_records[i] = inject_task_system_prompt(record, task)\n",
+ " \n",
+ " print(f\" ✓ Task-aware system prompts injected\")\n",
+ "\n",
  " task_dist = {}\n",
  " for record in train_records:\n",
  " user_text = \" \".join(m[\"content\"] for m in record if m[\"role\"] == \"user\")\n",
@@ -1471,40 +1624,42 @@
  ]
  },
  {
- "cell_type": "markdown",
+ "cell_type": "code",
+ "execution_count": null,
  "metadata": {},
+ "outputs": [],
  "source": [
  "FastLanguageModel.for_inference(model)\n",
  "\n",
- "system_msg = {\"role\": \"system\", \"content\": SYSTEM_PT}\n",
+ "# REMOVED static system_msg\n",
  "\n",
  "test_prompts = [\n",
  " {\"role\": \"user\", \"content\": (\n",
- " \"Analise esta avaliacao de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
- " \"nota=2/5 | status=delivered\\ntitulo: decepcionado\\n\"\n",
- " \"texto: Produto veio com defeito e o vendedor nao respondeu.\\n\\n\"\n",
+ " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
+ " \"nota=2/5 | status=delivered\\ntítulo: decepcionado\\n\"\n",
+ " \"texto: Produto veio com defeito e o vendedor não respondeu.\\n\\n\"\n",
  " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
  " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
  " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
  " )},\n",
  " {\"role\": \"user\", \"content\": (\n",
- " \"Analise esta avaliacao de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
- " \"nota=5/5 | status=delivered\\ntitulo: adorei!\\n\"\n",
- " \"texto: Entrega rapida e produto exatamente como descrito. Recomendo!\\n\\n\"\n",
+ " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
+ " \"nota=5/5 | status=delivered\\ntítulo: adorei!\\n\"\n",
+ " \"texto: Entrega rápida e produto exatamente como descrito. Recomendo!\\n\\n\"\n",
  " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
  " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
  " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
  " )},\n",
- " {\"role\": \"user\", \"content\": \"Quais sao as categorias de reclamacao mais frequentes e como afetam a nota media?\"},\n",
- " {\"role\": \"user\", \"content\": \"Analise a retencao de clientes afetados por product_quality.\"},\n",
+ " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
+ " {\"role\": \"user\", \"content\": \"Analise a retenção de clientes afetados por product_quality.\"},\n",
  " {\"role\": \"user\", \"content\": (\n",
  " \"Perfil do cliente:\\n- Estado: MG\\n- Valor do pedido: R$150\\n\"\n",
- " \"- Reclamacao: produto com defeito\\n- Nota: 1.0/5\\n\\n\"\n",
- " \"Este cliente deve receber uma notificacao de reengajamento?\"\n",
+ " \"- Reclamação: produto com defeito\\n- Nota: 1.0/5\\n\\n\"\n",
+ " \"Este cliente deve receber uma notificação de reengajamento?\"\n",
  " )},\n",
- " {\"role\": \"user\", \"content\": \"Compare a satisfacao de clientes em SP vs RJ.\"},\n",
+ " {\"role\": \"user\", \"content\": \"Compare a satisfação de clientes em SP vs RJ.\"},\n",
  " {\"role\": \"user\", \"content\": (\n",
- " \"Crie uma notificacao push de reengajamento para um cliente em SP \"\n",
+ " \"Crie uma notificação push de reengajamento para um cliente em SP \"\n",
  " \"que reclamou de atraso na entrega. Nota: 2/5.\"\n",
  " )},\n",
  "]\n",
@@ -1514,8 +1669,9 @@
  "print(\"=\" * 70)\n",
  "\n",
  "v3_rewards = []\n",
- "val_rows = []\n",
  "for i, prompt in enumerate(test_prompts):\n",
+ " task = _classify_task_type(prompt[\"content\"])\n",
+ " system_msg = {\"role\": \"system\", \"content\": get_system_prompt(task)}\n",
  " messages = [system_msg, prompt]\n",
  " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
  " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
@@ -1534,23 +1690,18 @@
  " print(f\"Prompt: {prompt['content'][:80]}...\")\n",
  " print(f\"Answer: {answer[:400]}\")\n",
  "\n",
- " val_rows.append([i + 1, task, reward, gen_tokens, hit_ceiling,\n",
- " prompt[\"content\"][:120], answer[:500]])\n",
- "\n",
- "v3_mean = sum(v3_rewards) / len(v3_rewards)\n",
- "v3_vs_v2 = (v3_mean - 0.54) / 0.54 * 100\n",
- "\n",
  "print(f\"\\n{'='*70}\")\n",
  "print(f\"v3 Validation Summary\")\n",
  "print(f\"{'='*70}\")\n",
- "print(f\" Mean reward: {v3_mean:.3f}\")\n",
+ "print(f\" Mean reward: {sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
  "print(f\" Min: {min(v3_rewards):.3f}\")\n",
  "print(f\" Max: {max(v3_rewards):.3f}\")\n",
  "print()\n",
  "print(f\" Comparison to baselines:\")\n",
  "print(f\" SFT calibration (Cell 7): mean=0.38\")\n",
  "print(f\" GRPO v2 validation: mean=0.54\")\n",
- "print(f\" GRPO v3 validation: mean={v3_mean:.3f}\")\n",
+ "print(f\" GRPO v3 validation: mean={sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
+ "v3_vs_v2 = (sum(v3_rewards)/len(v3_rewards) - 0.54) / 0.54 * 100\n",
  "print(f\" v3 vs v2: {v3_vs_v2:+.1f}%\")\n",
  "\n",
  "# ── Log validation results to W&B ────────────────────────────────────────────\n",
@@ -1574,86 +1725,6 @@
  "wandb.finish()\n",
  "print(f\"\\n✓ W&B run finalized — all outputs saved\")"
  ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "FastLanguageModel.for_inference(model)\n",
- "\n",
- "system_msg = {\"role\": \"system\", \"content\": SYSTEM_PT}\n",
- "\n",
- "test_prompts = [\n",
- " {\"role\": \"user\", \"content\": (\n",
- " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
- " \"nota=2/5 | status=delivered\\ntítulo: decepcionado\\n\"\n",
- " \"texto: Produto veio com defeito e o vendedor não respondeu.\\n\\n\"\n",
- " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
- " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
- " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
- " )},\n",
- " {\"role\": \"user\", \"content\": (\n",
- " \"Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\\n\\n\"\n",
- " \"nota=5/5 | status=delivered\\ntítulo: adorei!\\n\"\n",
- " \"texto: Entrega rápida e produto exatamente como descrito. Recomendo!\\n\\n\"\n",
- " \"Retorne um objeto JSON com exatamente estas chaves:\\n\"\n",
- " \"sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, \"\n",
- " \"seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend\"\n",
- " )},\n",
- " {\"role\": \"user\", \"content\": \"Quais são as categorias de reclamação mais frequentes e como afetam a nota média?\"},\n",
- " {\"role\": \"user\", \"content\": \"Analise a retenção de clientes afetados por product_quality.\"},\n",
- " {\"role\": \"user\", \"content\": (\n",
- " \"Perfil do cliente:\\n- Estado: MG\\n- Valor do pedido: R$150\\n\"\n",
- " \"- Reclamação: produto com defeito\\n- Nota: 1.0/5\\n\\n\"\n",
- " \"Este cliente deve receber uma notificação de reengajamento?\"\n",
- " )},\n",
- " {\"role\": \"user\", \"content\": \"Compare a satisfação de clientes em SP vs RJ.\"},\n",
- " {\"role\": \"user\", \"content\": (\n",
- " \"Crie uma notificação push de reengajamento para um cliente em SP \"\n",
- " \"que reclamou de atraso na entrega. Nota: 2/5.\"\n",
- " )},\n",
- "]\n",
- "\n",
- "print(\"=\" * 70)\n",
- "print(\"GRPO v3 Validation\")\n",
- "print(\"=\" * 70)\n",
- "\n",
- "v3_rewards = []\n",
- "for i, prompt in enumerate(test_prompts):\n",
- " messages = [system_msg, prompt]\n",
- " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
- " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
- "\n",
- " outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)\n",
- " gen_tokens = outputs.shape[1] - inputs[\"input_ids\"].shape[1]\n",
- " response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
- "\n",
- " reward = commerce_reward_fn([response], [text])[0]\n",
- " v3_rewards.append(reward)\n",
- " answer = strip_think(response)\n",
- " task = _classify_task_type(prompt[\"content\"])\n",
- " hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH\n",
- "\n",
- " print(f\"\\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---\")\n",
- " print(f\"Prompt: {prompt['content'][:80]}...\")\n",
- " print(f\"Answer: {answer[:400]}\")\n",
- "\n",
- "print(f\"\\n{'='*70}\")\n",
- "print(f\"v3 Validation Summary\")\n",
- "print(f\"{'='*70}\")\n",
- "print(f\" Mean reward: {sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
- "print(f\" Min: {min(v3_rewards):.3f}\")\n",
- "print(f\" Max: {max(v3_rewards):.3f}\")\n",
- "print()\n",
- "print(f\" Comparison to baselines:\")\n",
- "print(f\" SFT calibration (Cell 7): mean=0.38\")\n",
- "print(f\" GRPO v2 validation: mean=0.54\")\n",
- "print(f\" GRPO v3 validation: mean={sum(v3_rewards)/len(v3_rewards):.3f}\")\n",
- "v3_vs_v2 = (sum(v3_rewards)/len(v3_rewards) - 0.54) / 0.54 * 100\n",
- "print(f\" v3 vs v2: {v3_vs_v2:+.1f}%\")"
- ]
  }
  ],
  "metadata": {