rtferraz committed on
Commit a6a8b11 · verified · 1 Parent(s): 042d2b9

feat: add GRPO v3 implementation with entropy collapse fixes

Files changed (1)
  1. grpo_vertex_v3.md +1322 -0
grpo_vertex_v3.md ADDED
@@ -0,0 +1,1322 @@
1
+ # Tucano2 Commerce — GRPO Training v3 (Vertex AI Workbench / L4)
2
+
3
+ **v3 changes over v2 — grounded in published research:**
4
+
5
+ | Change | v2 Value | v3 Value | Paper Reference |
6
+ |--------|----------|----------|----------------|
7
+ | Temperature | 0.8 | **1.0** | Skywork-OR1 (2505.22312) §4: τ=1.0 gives 5-8% better results, delays entropy collapse (toy check below the table) |
8
+ | Completion length | 2048 | **4096** | Dr. GRPO (2503.20783) §3.1: length bias inflates wrong answers → ceiling hit blocks learning |
9
+ | Num generations | 8 | **4** | VRAM tradeoff: 4×4096 ≈ 8×2048. MC-GRPO (2601.22582): G=4 works with noise mitigation |
10
+ | Learning rate | 5e-7 | **2e-6** | Dr. GRPO Appendix G: LR=1e-6; Reasoning-SQL: LR=1e-6. v2 clip_ratio=0 → room to push 2-4× |
11
+ | β (KL penalty) | implicit | **0.0** | Dr. GRPO §3.2: β=0 optimal for rule-based rewards |
12
+ | Training data | 300 | **ALL (~1400)** | Skywork-OR1 §3.1: small prompt sets → model memorizes → entropy collapse |
13
+ | Reward functions | single composite | **staged (format→partial→task)** | Reasoning-SQL (2503.23157) §3.2: format rewards converge first, enable task learning |
14
+ | Zero-advantage groups | included | **filtered with noise injection** | Skywork-OR1 §3.1: zero-std groups destabilize training |
15
+ | Entropy monitoring | none | **EntropyMonitorCallback** | Skywork-OR1 §4: early detection prevents collapse |
16
+ | Early stopping patience | 10 | **15** | More runway for longer completions |
17
+ | Save total limit | 3 | **5** | Keep more checkpoints — v2 lost the best one |
18
+ | Eval temperature | 0.7 | **0.1** | Near-deterministic eval = less noisy signal |
19
+ | General reasoning mix | none | **30% (optional)** | Cocktail Effect (2410.01109): multi-task mix boosts domain performance 2-15% |
20
+
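+ To make the temperature row above concrete, here is how sampling entropy scales with τ — a minimal sketch over hypothetical logits, not model outputs:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # Toy logits standing in for a single decoding step (hypothetical values).
+ logits = torch.tensor([4.0, 3.0, 2.0, 0.5])
+
+ for tau in (0.8, 1.0):
+     probs = F.softmax(logits / tau, dim=-1)
+     entropy = -(probs * probs.log()).sum()
+     print(f"tau={tau}: entropy={entropy:.3f} nats, top-1 p={probs.max():.3f}")
+
+ # Higher τ → higher entropy → more diverse completions per group, which is
+ # exactly the within-group variance GRPO's relative advantage depends on.
+ ```
+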
21
+ **Prerequisites:**
22
+ - Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`
23
+ - Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`
24
+ - **NEW:** Optional `data/pairs/general_reasoning.jsonl` for 30% general data mix
25
+
26
+ **Hardware:** L4 (24GB), PyTorch kernel, bf16 supported
27
+
28
+ ---
29
+
30
+ ## Cell 1: Dependencies
31
+
32
+ Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:
33
+
34
+ ```python
35
+ # Cell 1a — Nuke everything ML-related
36
+ !pip uninstall -y torch torchvision torchaudio \
37
+ unsloth unsloth-zoo \
38
+ trl transformers peft accelerate \
39
+ bitsandbytes vllm vllm-flash-attn \
40
+ datasets tokenizers safetensors huggingface-hub \
41
+ wandb xformers triton \
42
+ cuda-bindings cuda-python \
43
+ sentencepiece protobuf \
44
+ 2>/dev/null
45
+ ```
46
+
47
+ ```python
48
+ # Cell 1b — Kill any stragglers
49
+ !pip freeze | grep -iE "torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate" | xargs -r pip uninstall -y 2>/dev/null
50
+ ```
51
+
52
+ ```python
53
+ # Cell 1c — Purge cache
54
+ !pip cache purge
55
+ ```
56
+
57
+ **⚠️ Restart kernel again**, then:
58
+
59
+ ```python
60
+ # Cell 1d — Clean install, correct order
61
+ !pip install "unsloth"
62
+ ```
63
+
64
+ ```python
65
+ # Cell 1e — Pin TRL (Unsloth may pull a different version)
66
+ !pip install "trl==0.24.0" --no-deps
67
+ ```
68
+
69
+ ```python
70
+ # Cell 1f — Extra deps
71
+ !pip install "rich" "wandb"
72
+ ```
73
+
74
+ ---
75
+
76
+ ## Cell 2: Hello World — GPU + Unsloth Verification
77
+
78
+ ```python
79
+ import torch
80
+
81
+ print(f"CUDA available: {torch.cuda.is_available()}")
82
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
83
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
84
+ print(f"bf16 support: {torch.cuda.is_bf16_supported()}")
85
+
86
+ from unsloth import FastLanguageModel
87
+ print("\n✓ Unsloth loaded successfully")
88
+
89
+ import trl
90
+ print(f"✓ TRL version: {trl.__version__}")
91
+
92
+ import transformers
93
+ print(f"✓ Transformers version: {transformers.__version__}")
94
+ ```
95
+
96
+ ---
97
+
98
+ ## Cell 3: Config + Constants
99
+
100
+ ```python
101
+ import os
102
+ os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
103
+
104
+ import json
105
+ import re
106
+ import time
107
+ import random
108
+ import gc
109
+ from pathlib import Path
110
+
111
+ # ══════════════════════════════════════════════════════════════════════════════
112
+ # v3 CONFIG — Every change is annotated with paper reference
113
+ # ══════════════════════════════════════════════════════════════════════════════
114
+
115
+ MODEL_ID = "Polygl0t/Tucano2-qwen-3.7B-Think"
116
+ MAX_SEQ_LENGTH = 8192 # v3: increased from 4096 — model supports 32k, we need room for 4096 completion + prompt
117
+
118
+ # ── Paths ─────────────────────────────────────────────────────────────────────
119
+ DATA_DIR = Path("/home/jupyter/tucano2/data")
120
+ MODELS_DIR = Path("/home/jupyter/tucano2/models")
121
+ SFT_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-sft"
122
+ GRPO_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-grpo-v3" # v3: separate dir from v2
123
+ CHECKPOINT_DIR = GRPO_ADAPTER_DIR / "checkpoints"
124
+
125
+ # ── Training data ─────────────────────────────────────────────────────────────
126
+ GRPO_PROMPTS = None # v3: None = use ALL available prompts (was 300 subset in v2)
127
+ GENERAL_MIX_RATIO = 0.0 # v3: set to 0.3 if general_reasoning.jsonl exists (Cocktail Effect paper)
128
+
129
+ # ── Valid enums for reward scoring (unchanged from v2) ────────────────────────
130
+ VALID_SENTIMENTS = {"positive", "negative", "neutral"}
131
+ VALID_CATEGORIES = {
132
+ "delivery_delay", "product_quality", "product_not_received",
133
+ "wrong_product", "seller_communication", "app_issue",
134
+ "price_value", "other", "none",
135
+ }
136
+ VALID_CHURN = {"low", "medium", "high"}
137
+ VALID_REPEAT = {"yes", "no", "maybe"}
138
+ EXTRACTION_FIELDS = [
139
+ "sentiment", "sentiment_score", "churn_risk", "delivery_issue",
140
+ "product_issue", "seller_issue", "main_complaint",
141
+ "complaint_category", "repeat_intent", "would_recommend",
142
+ ]
143
+
144
+ SYSTEM_PT = (
145
+ "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
146
+ "Você compreende avaliações de clientes em português e padrões de comércio brasileiro."
147
+ )
148
+
149
+ # ══════════════════════════════════════════════════════════════════════════════
150
+ # TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)
151
+ # ══════════════════════════════════════════════════════════════════════════════
152
+
153
+ # ── Core GRPO params ──────────────────────────────────────────────────────────
154
+ BATCH_SIZE = 4
155
+ GRAD_ACCUM = 1 # v3: reduced from 2. Effective batch = 4×1 = 4 (was 8)
156
+ # With G=4: steps = prompts × 4 / 4 = prompts per epoch
157
+ NUM_GENERATIONS = 4 # v3: reduced from 8 — VRAM tradeoff for longer completions
158
+ # MC-GRPO (2601.22582): G=4 works if noise is mitigated
159
+ SCALE_REWARDS = False # Dr. GRPO (2503.20783): remove std normalization bias
160
+
161
+ # ── v3 CRITICAL FIXES ────────────────────────────────────────────────────────
162
+
163
+ # FIX 1: Temperature — prevent entropy collapse
164
+ # v2 had 0.8. All published GRPO papers use 1.0.
165
+ # Skywork-OR1 (2505.22312) ablation: τ=1.0 vs τ=0.6 → 5-8% better test performance
166
+ TEMPERATURE = 1.0
167
+
168
+ # FIX 2: Completion length — remove the ceiling
169
+ # v2: every single completion hit 2048 ceiling. Model couldn't finish reasoning.
170
+ # Dr. GRPO (2503.20783) §3.1: GRPO length bias inflates wrong answers → the ceiling kills the gradient
171
+ MAX_COMPLETION_LENGTH = 4096
172
+
173
+ # FIX 3: Learning rate — more aggressive
174
+ # v2: clip_ratio=0 on all steps → updates were too small to matter
175
+ # Dr. GRPO Appendix G: LR=1e-6 (constant). Reasoning-SQL: LR=1e-6 with cosine.
176
+ # We go 2× the papers' 1e-6 (4× v2's 5e-7) since v2 showed zero clipping (model can absorb a stronger push)
177
+ LEARNING_RATE = 2e-6
178
+
179
+ # FIX 4: β = 0 (no KL penalty)
180
+ # Dr. GRPO (2503.20783) §3.2: KL penalty is unnecessary for rule-based rewards
181
+ # v2 used implicit KL through default β — we explicitly disable it
182
+ BETA = 0.0
183
+
184
+ # ── Training schedule ─────────────────────────────────────────────────────────
185
+ NUM_EPOCHS = 1
186
+ MAX_STEPS = 500 # v3: increased for expanded data; early stopping will halt if needed
187
+ # With ~1400 prompts × 4 gen / (4 batch × 1 accum) = 1400 steps/epoch
188
+ # MAX_STEPS=500 < 1 epoch — early stopping or manual extension
189
+
190
+ # ── Checkpoint + Eval + Early-Stop ────────────────────────────────────────────
191
+ EVAL_SPLIT_RATIO = 0.15
192
+ EVAL_STEPS = 10
193
+ EARLY_STOPPING_PATIENCE = 15 # v3: increased from 10 — gives 150 steps of runway
194
+ EARLY_STOPPING_DELTA = 0.005 # v3: reduced from 0.01 — more sensitive to small gains
195
+ SAVE_STEPS = 10 # v3: more frequent (was 15) — never lose best checkpoint again
196
+ SAVE_TOTAL_LIMIT = 5 # v3: keep more checkpoints (was 3 — lost best in v2)
197
+ WANDB_PROJECT = "tucano2-commerce"
198
+
199
+ # ── Eval callback ─────────────────────────────────────────────────────────────
200
+ EVAL_MAX_SAMPLES = 5
201
+ EVAL_MAX_TOKENS = 4096 # v3: match training max_completion_length (was 2048)
202
+ EVAL_TEMPERATURE = 0.1 # v3: near-deterministic eval for a less noisy signal (was 0.7)
203
+
204
+ # ── Backend ───────────────────────────────────────────────────────────────────
205
+ USE_VLLM = False
206
+
207
+ # ── v3: Zero-advantage noise injection ────────────────────────────────────────
208
+ # Skywork-OR1 (2505.22312) §3.1: zero-std groups destabilize GRPO training
209
+ # When all G completions get identical rewards, the advantage is undefined.
210
+ # Noise injection breaks ties without corrupting the signal.
211
+ ZERO_ADV_NOISE_STD = 0.005 # Small gaussian noise added to zero-variance groups
212
+
213
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
214
+
215
+ # ── Version assertion ─────────────────────────────────────────────────────────
216
+ import trl as _trl
217
+ assert _trl.__version__ == "0.24.0", (
218
+ f"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\n"
219
+ "Verify that GRPOTrainer._generate() still exists before proceeding."
220
+ )
221
+
222
+ print("✓ v3 Config loaded")
223
+ print(f" SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})")
224
+ print(f" Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})")
225
+ print(f" Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}")
226
+ print(f" GRPO: G={NUM_GENERATIONS}, temp={TEMPERATURE}, LR={LEARNING_RATE}, β={BETA}")
227
+ print(f" Completion: max={MAX_COMPLETION_LENGTH} (v2 was 2048)")
228
+ print(f" ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}")
229
+ print(f"✓ TRL {_trl.__version__} verified")
230
+
231
+ # ══════════════════════════════════════════════════════════════════════════════
232
+ # v3 VRAM BUDGET (L4 24GB)
233
+ # ══════════════════════════════════════════════════════════════════════════════
234
+ # Model (NF4): ~3.5 GB
235
+ # KV Cache (8192 seq): ~3.0 GB
236
+ # Activations: ~4.0 GB
237
+ # Optimizer states: ~3.0 GB
238
+ # Generations (4×4096): ~8.0 GB
239
+ # ─────────────────────────────────
240
+ # Estimated total: ~21.5 GB
241
+ # Headroom: ~2.5 GB
242
+ #
243
+ # If OOM: reduce MAX_COMPLETION_LENGTH to 3072 first, then 2560.
244
+ # Do NOT reduce NUM_GENERATIONS below 4 — GRPO needs variance.
245
+ # ══════════════════════════════════════════════════════════════════════════════
246
+ ```
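+
+ Before loading the model, it is worth checking the static budget above against what CUDA actually reports. A small probe — the 22 GB threshold is an assumption (≈21.5 GB budget plus margin):
+
+ ```python
+ import torch
+
+ # Compare the hand-computed budget with the device's real free memory.
+ free_b, total_b = torch.cuda.mem_get_info()  # returns (free, total) in bytes
+ print(f"Free: {free_b / 1e9:.1f} GB / Total: {total_b / 1e9:.1f} GB")
+ if free_b / 1e9 < 22.0:  # assumed threshold: ~21.5 GB budget + margin
+     print("⚠️ Less free VRAM than the v3 budget assumes — "
+           "consider MAX_COMPLETION_LENGTH=3072 before training.")
+ ```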
247
+
248
+ ---
249
+
250
+ ## Cell 4: Load SFT Adapter
251
+
252
+ ```python
253
+ print("Loading SFT adapter...")
254
+ model, tokenizer = FastLanguageModel.from_pretrained(
255
+ model_name=str(SFT_ADAPTER_DIR),
256
+ max_seq_length=MAX_SEQ_LENGTH,
257
+ load_in_4bit=True,
258
+ dtype=None,
259
+ )
260
+
261
+ if tokenizer.pad_token is None:
262
+ tokenizer.pad_token = tokenizer.eos_token
263
+
264
+ # Load chat template from base model (SFT adapter doesn't save it)
265
+ from transformers import AutoTokenizer
266
+ base_tok = AutoTokenizer.from_pretrained(MODEL_ID)
267
+ tokenizer.chat_template = base_tok.chat_template
268
+ del base_tok
269
+
270
+ # v2: Force KV cache — Unsloth patching may reset this
271
+ model.config.use_cache = True
272
+ model.generation_config.use_cache = True
273
+
274
+ print(f"✓ Model loaded on {model.device}")
275
+ print(f" use_cache: {model.config.use_cache}")
276
+ print(f" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M")
277
+ print(f" Chat template: {tokenizer.chat_template[:50]}...")
278
+ ```
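+
+ Since the chat template is copied over from the base model, a cheap render check catches a missing or broken template before Cell 5 — a sketch; the exact template text depends on the base model:
+
+ ```python
+ # Render a toy conversation and confirm the user turn survives templating.
+ _probe = tokenizer.apply_chat_template(
+     [{"role": "system", "content": SYSTEM_PT},
+      {"role": "user", "content": "teste"}],
+     tokenize=False, add_generation_prompt=True,
+ )
+ assert "teste" in _probe, "chat template dropped the user turn"
+ print(f"✓ Template renders ({len(_probe)} chars): ...{_probe[-60:]!r}")
+ ```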
279
+
280
+ ---
281
+
282
+ ## Cell 5: Single Inference Test
283
+
284
+ **Gate:** Does the model close `</think>` and produce an answer within 4096 tokens?
285
+
286
+ ```python
287
+ FastLanguageModel.for_inference(model)
288
+
289
+ test_msgs = [
290
+ {"role": "system", "content": SYSTEM_PT},
291
+ {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
292
+ ]
293
+ text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)
294
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
295
+
296
+ t0 = time.time()
297
+ outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
298
+ elapsed = time.time() - t0
299
+
300
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
301
+ gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
302
+
303
+ print(f"Generation time: {elapsed:.1f}s ({gen_tokens} tokens, {gen_tokens/elapsed:.1f} tok/s)")
304
+ print(f"Response length: {len(response)} chars, {gen_tokens} tokens")
305
+ print(f"Hit ceiling: {gen_tokens >= MAX_COMPLETION_LENGTH}") # v3: should NOT hit ceiling with 4096
306
+ print(f"closed_think: {'</think>' in response}")
307
+ print(f"\n{'='*60}")
308
+ print(response[:800])
309
+ ```
310
+
311
+ ---
312
+
313
+ ## Cell 5b: KV Cache Diagnostic
314
+
315
+ ```python
316
+ import time
317
+ FastLanguageModel.for_inference(model)
318
+
319
+ _kv_msgs = [{"role": "system", "content": SYSTEM_PT},
320
+ {"role": "user", "content": "Qual a categoria de reclamação mais frequente?"}]
321
+ _kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)
322
+ _kv_inputs = tokenizer(_kv_text, return_tensors="pt").to(model.device)
323
+
324
+ _token_times, _past, _generated = [], None, _kv_inputs["input_ids"]
325
+ with torch.no_grad():
326
+ for _step in range(50):
327
+ _t0 = time.time()
328
+ seq_len = _generated.shape[1]
329
+ if _past is None:
330
+ _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)
331
+ else:
332
+ _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)
333
+ _out = model(
334
+ input_ids=_generated[:, -1:] if _past else _generated,
335
+ position_ids=_position_ids,
336
+ attention_mask=torch.ones(1, seq_len, device=model.device),
337
+ past_key_values=_past,
338
+ use_cache=True,
339
+ return_dict=True,
340
+ )
341
+ _past = _out.past_key_values
342
+ _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
343
+ _generated = torch.cat([_generated, _next], dim=1)
344
+ _token_times.append(time.time() - _t0)
345
+
346
+ _ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)
347
+ print(f"First 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}")
348
+ print(f"Last 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}")
349
+ print(f"Ratio last/first: {_ratio:.1f}x")
350
+ if _ratio < 3:
351
+ print("✓ KV cache is working correctly")
352
+ elif _ratio < 6:
353
+ print("⚠ KV cache may be degraded — check model.config.use_cache")
354
+ else:
355
+ print("✗ KV cache BROKEN — GRPO generation will be catastrophically slow.")
356
+
357
+ del _past, _generated, _kv_inputs, _token_times, _out
358
+ gc.collect()
359
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
360
+ ```
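+
+ A coarser cross-check of the same thing through the public `generate()` path — the no-cache run recomputes attention over the whole prefix at every step, so it should be clearly slower when caching is actually engaged (a sketch; timings will vary):
+
+ ```python
+ _msgs = [{"role": "system", "content": SYSTEM_PT},
+          {"role": "user", "content": "Qual a categoria de reclamação mais frequente?"}]
+ _in = tokenizer(tokenizer.apply_chat_template(_msgs, tokenize=False, add_generation_prompt=True),
+                 return_tensors="pt").to(model.device)
+ for _uc in (True, False):
+     torch.cuda.synchronize(); _t0 = time.time()
+     model.generate(**_in, max_new_tokens=64, do_sample=False, use_cache=_uc)
+     torch.cuda.synchronize()
+     print(f"use_cache={_uc}: {time.time() - _t0:.2f}s")
+ ```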
361
+
362
+ ---
363
+
364
+ ## Cell 6: Reward Functions v3
365
+
366
+ **v3 changes:**
367
+ - Staged reward design: format → partial content → full task (Reasoning-SQL, 2503.23157)
368
+ - Zero-advantage noise injection (Skywork-OR1, 2505.22312)
369
+ - Extraction reward redesigned for completion-length-friendly scoring
370
+
371
+ ```python
372
+ def strip_think(text: str) -> str:
373
+ """Remove <think>...</think> block, return the answer portion."""
374
+ return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
375
+
376
+
377
+ def has_think_block(text: str) -> bool:
378
+ """Check if text contains a non-empty <think> block."""
379
+ return bool(re.search(r"<think>.+</think>", text, flags=re.DOTALL))
380
+
381
+
382
+ def _classify_task_type(prompt_text: str) -> str:
383
+ """Classify prompt into task type by keywords."""
384
+ p = prompt_text.lower()
385
+ if "retorne um objeto json" in p or "extraia dados" in p:
386
+ return "extraction"
387
+ elif "notificação push" in p or "notificação de reengajamento" in p:
388
+ return "push"
389
+ elif "perfil do cliente" in p:
390
+ return "insights"
391
+ else:
392
+ return "sql_qa"
393
+
394
+
395
+ def _json_similarity(text: str) -> float:
396
+ """Rough heuristic: how JSON-like is this text? 0.0 to 1.0."""
397
+ text = text.strip()
398
+ if not text:
399
+ return 0.0
400
+ score = 0.0
401
+ if text.startswith("{") and text.endswith("}"):
402
+ score += 0.5
403
+ if '"' in text:
404
+ score += 0.2
405
+ if ":" in text:
406
+ score += 0.2
407
+ if "," in text:
408
+ score += 0.1
409
+ return min(score, 1.0)
410
+
411
+
412
+ def _string_similarity(a: str, b: str) -> float:
413
+ """Simple Jaccard-like similarity for short strings. 0.0 to 1.0."""
414
+ if not a or not b:
415
+ return 0.0
416
+ a_set = set(a.split())
417
+ b_set = set(b.split())
418
+ intersection = len(a_set & b_set)
419
+ union = len(a_set | b_set)
420
+ return intersection / union if union > 0 else 0.0
421
+
422
+
423
+ # ══════════════════════════════════════════════════════════════════════════════
424
+ # v3 STAGED REWARD DESIGN
425
+ # Reference: Reasoning-SQL (2503.23157) §3.2
426
+ #
427
+ # Each reward function scores THREE stages independently:
428
+ # Stage 1 — FORMAT (0.0–0.2): Is the output well-structured?
429
+ # Stage 2 — PARTIAL (0.0–0.3): Are some content elements correct?
430
+ # Stage 3 — TASK (0.0–0.5): Is the full task completed correctly?
431
+ #
432
+ # Format rewards converge first (easy to learn), which stabilizes training
433
+ # and enables the model to then learn harder task-specific skills.
434
+ # ══════════════════════════════════════════════════════════════════════════════
435
+
436
+
437
+ def reward_extraction(completion: str) -> float:
438
+ """Staged reward for structured extraction (max 1.0)."""
439
+ answer = strip_think(completion)
440
+
441
+ # ── Stage 1: FORMAT (max 0.2) ─────────────────────────────────────────────
442
+ r_format = 0.0
443
+ if has_think_block(completion):
444
+ r_format += 0.1 # Used reasoning
445
+
446
+ try:
447
+ data = json.loads(answer)
448
+ if isinstance(data, dict):
449
+ r_format += 0.1 # Valid JSON object
450
+ except (json.JSONDecodeError, TypeError):
451
+ r_format += 0.05 * _json_similarity(answer)
452
+ return min(r_format, 0.2)
453
+
454
+ if not isinstance(data, dict):
455
+ return min(r_format, 0.2)
456
+
457
+ # ── Stage 2: PARTIAL CONTENT (max 0.3) ────────────────────────────────────
458
+ r_partial = 0.0
459
+
460
+ present = sum(1 for f in EXTRACTION_FIELDS if f in data)
461
+ r_partial += 0.15 * (present / len(EXTRACTION_FIELDS))
462
+
463
+ type_checks = 0
464
+ type_total = 0
465
+ for field in EXTRACTION_FIELDS:
466
+ if field not in data:
467
+ continue
468
+ type_total += 1
469
+ val = data[field]
470
+ if field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
471
+ if isinstance(val, bool):
472
+ type_checks += 1
473
+ elif field in ("sentiment_score",):
474
+ if isinstance(val, (int, float)):
475
+ type_checks += 1
476
+ elif field in ("main_complaint", "sentiment", "complaint_category", "churn_risk", "repeat_intent"):
477
+ if isinstance(val, str):
478
+ type_checks += 1
479
+ if type_total > 0:
480
+ r_partial += 0.15 * (type_checks / type_total)
481
+
482
+ # ── Stage 3: FULL TASK (max 0.5) ─────────────────────────────────────────
483
+ r_task = 0.0
484
+ cat_checks = 0
485
+ cat_total = 0
486
+
487
+ checks = [
488
+ ("sentiment", lambda v: v in VALID_SENTIMENTS),
489
+ ("complaint_category", lambda v: v in VALID_CATEGORIES),
490
+ ("churn_risk", lambda v: v in VALID_CHURN),
491
+ ("repeat_intent", lambda v: v in VALID_REPEAT),
492
+ ("sentiment_score", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),
493
+ ]
494
+ for field, validator in checks:
495
+ cat_total += 1
496
+ if field in data and validator(data[field]):
497
+ cat_checks += 1
498
+
499
+ for bool_field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
500
+ cat_total += 1
501
+ if bool_field in data and isinstance(data[bool_field], bool):
502
+ cat_checks += 1
503
+
504
+ if cat_total > 0:
505
+ r_task += 0.35 * (cat_checks / cat_total)
506
+
507
+ if "main_complaint" in data and isinstance(data["main_complaint"], str):
508
+ complaint = data["main_complaint"].strip()
509
+ if len(complaint) > 10:
510
+ r_task += 0.15
511
+
512
+ return min(r_format + r_partial + r_task, 1.0)
513
+
514
+
515
+ def reward_sql_qa(completion: str) -> float:
516
+ """Staged reward for SQL Q&A (max 1.0)."""
517
+ answer = strip_think(completion)
518
+
519
+ # ── Stage 1: FORMAT (max 0.2)
520
+ r_format = 0.0
521
+ if has_think_block(completion):
522
+ r_format += 0.1
523
+ if "```" in answer or re.search(r"SELECT|FROM", answer, re.IGNORECASE):
524
+ r_format += 0.1
525
+
526
+ # ── Stage 2: PARTIAL (max 0.3)
527
+ r_partial = 0.0
528
+ sql_keywords = r"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN|HAVING"
529
+ matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))
530
+ r_partial += min(0.15, 0.03 * matches)
531
+ numbers = re.findall(r"\d+(?:[.,]\d+)?", answer)
532
+ r_partial += min(0.15, 0.03 * len(numbers))
533
+
534
+ # ── Stage 3: TASK (max 0.5)
535
+ r_task = 0.0
536
+ length = len(answer)
537
+ if 50 <= length <= 600:
538
+ r_task += 0.25
539
+ elif length > 0:
540
+ r_task += 0.25 * max(0, 1 - abs(length - 325) / 275)
541
+ explanation_markers = ["para ", "porque", "resultado", "mostra", "indica", "análise"]
542
+ expl_matches = sum(1 for w in explanation_markers if w in answer.lower())
543
+ r_task += min(0.25, 0.05 * expl_matches)
544
+
545
+ return min(r_format + r_partial + r_task, 1.0)
546
+
547
+
548
+ def reward_insights(completion: str) -> float:
549
+ """Staged reward for insights (max 1.0)."""
550
+ answer = strip_think(completion)
551
+
552
+ # ── Stage 1: FORMAT (max 0.2)
553
+ r_format = 0.0
554
+ if has_think_block(completion):
555
+ r_format += 0.1
556
+ structure_marks = len(re.findall(r"^[-•*]\s|^\d+[.)]\s|^#{1,3}\s", answer, re.MULTILINE))
557
+ r_format += min(0.1, 0.02 * structure_marks)
558
+
559
+ # ── Stage 2: PARTIAL (max 0.3)
560
+ r_partial = 0.0
561
+ length = len(answer)
562
+ if 100 <= length <= 1200:
563
+ r_partial += 0.15
564
+ elif length > 0:
565
+ r_partial += 0.15 * max(0, 1 - abs(length - 650) / 550)
566
+ pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua|cliente|produto", answer, re.IGNORECASE)
567
+ r_partial += min(0.15, 0.01 * len(pt_markers))
568
+
569
+ # ── Stage 3: TASK (max 0.5)
570
+ r_task = 0.0
571
+ action_words = ["recomend", "implement", "melhor", "reduzir", "aumentar",
572
+ "priorizar", "investir", "otimizar", "estratégi", "suger",
573
+ "consider", "ação", "plano"]
574
+ matches = sum(1 for w in action_words if w in answer.lower())
575
+ r_task += min(0.3, 0.06 * matches)
576
+ data_refs = len(re.findall(r"\d+%|R\$\s*\d|média|percentual|comparad|taxa", answer, re.IGNORECASE))
577
+ r_task += min(0.2, 0.04 * data_refs)
578
+
579
+ return min(r_format + r_partial + r_task, 1.0)
580
+
581
+
582
+ def reward_push(completion: str) -> float:
583
+ """Staged reward for push notifications (max 1.0)."""
584
+ answer = strip_think(completion)
585
+ if not answer:
586
+ return 0.0
587
+
588
+ # ── Stage 1: FORMAT (max 0.2)
589
+ r_format = 0.0
590
+ if has_think_block(completion):
591
+ r_format += 0.05
592
+ length = len(answer)
593
+ if length <= 160:
594
+ r_format += 0.15
595
+ elif length <= 300:
596
+ r_format += 0.1
597
+ else:
598
+ r_format += 0.05
599
+
600
+ # ── Stage 2: PARTIAL (max 0.3)
601
+ r_partial = 0.0
602
+ pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua", answer, re.IGNORECASE)
603
+ r_partial += min(0.15, 0.02 * len(pt_markers))
604
+ if re.search(r"[!?]|[\U0001F600-\U0001F64F]|[\U0001F300-\U0001F5FF]", answer):
605
+ r_partial += 0.05
606
+ if len(answer.split()) >= 5:
607
+ r_partial += 0.1
608
+
609
+ # ── Stage 3: TASK (max 0.5)
610
+ r_task = 0.0
611
+ if length <= 120:
612
+ r_task += 0.25
613
+ else:
614
+ r_task += 0.25 * max(0, 1 - (length - 120) / 120)
615
+ generic_phrases = [
616
+ "olá! como podemos ajudar", "obrigado pela sua compra",
617
+ "seu pedido foi confirmado", "agradecemos sua preferência",
618
+ ]
619
+ max_similarity = max(_string_similarity(answer.lower(), g) for g in generic_phrases)
620
+ r_task += 0.25 * (1 - max_similarity)
621
+
622
+ return min(r_format + r_partial + r_task, 1.0)
623
+
624
+
625
+ def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:
626
+ """
627
+ Master reward function v3: dispatches by task type + zero-advantage noise.
628
+ """
629
+ rewards = []
630
+ for completion, prompt in zip(completions, prompts):
631
+ if isinstance(completion, list):
632
+ comp_text = completion[-1]["content"] if completion else ""
633
+ else:
634
+ comp_text = str(completion)
635
+
636
+ if isinstance(prompt, list):
637
+ prompt_text = " ".join(m.get("content", "") for m in prompt)
638
+ else:
639
+ prompt_text = str(prompt)
640
+
641
+ task = _classify_task_type(prompt_text)
642
+
643
+ if task == "extraction":
644
+ rewards.append(reward_extraction(comp_text))
645
+ elif task == "sql_qa":
646
+ rewards.append(reward_sql_qa(comp_text))
647
+ elif task == "insights":
648
+ rewards.append(reward_insights(comp_text))
649
+ elif task == "push":
650
+ rewards.append(reward_push(comp_text))
651
+ else:
652
+ r = 0.15 if has_think_block(comp_text) else 0.0
653
+ r += 0.2 if comp_text.strip() else 0.0
654
+ rewards.append(r)
655
+
656
+ # ── v3: Zero-advantage noise injection ────────────────────────────────────
657
+ if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:
658
+ for i in range(0, len(rewards), NUM_GENERATIONS):
659
+ group = rewards[i:i+NUM_GENERATIONS]
660
+ if len(group) < 2:
661
+ continue
662
+ if max(group) - min(group) < 0.001:
663
+ for j in range(i, min(i+NUM_GENERATIONS, len(rewards))):
664
+ rewards[j] += random.gauss(0, ZERO_ADV_NOISE_STD)
665
+
666
+ return rewards
667
+
668
+
669
+ print("✓ v3 Reward functions defined (staged: format → partial → task)")
670
+ ```
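+
+ The staged design and the tie-breaking noise can be exercised offline with hypothetical completion strings (no model needed):
+
+ ```python
+ # Staged scoring: a fully valid extraction vs. a format-credit-only answer.
+ _good = ('<think>análise</think>'
+          '{"sentiment": "negative", "sentiment_score": 2, "churn_risk": "high", '
+          '"delivery_issue": false, "product_issue": true, "seller_issue": true, '
+          '"main_complaint": "produto veio com defeito", '
+          '"complaint_category": "product_quality", "repeat_intent": "no", '
+          '"would_recommend": false}')
+ _bad = "<think>análise</think>não sei"
+ print(f"valid JSON : {reward_extraction(_good):.2f}")  # format + partial + task
+ print(f"no JSON    : {reward_extraction(_bad):.2f}")   # format credit only
+
+ # Zero-advantage noise: NUM_GENERATIONS identical completions tie on reward,
+ # so the injected gaussian noise should break the tie in the 3rd–4th decimal.
+ _prompt = [{"role": "user", "content": "retorne um objeto json ..."}]
+ _rs = commerce_reward_fn([_bad] * NUM_GENERATIONS, [_prompt] * NUM_GENERATIONS)
+ print("tied group after noise:", [f"{r:.4f}" for r in _rs])
+ ```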
671
+
672
+ ---
673
+
674
+ ## Cell 7: Reward Calibration
675
+
676
+ **Gate:** Verify reward variance > 0. Compare v3 scoring to v2 calibration (mean=0.38).
677
+
678
+ ```python
679
+ train_path = DATA_DIR / "pairs" / "train.jsonl"
680
+
681
+ by_type = {"extraction": [], "sql_qa": [], "insights": [], "push": []}
682
+ with open(train_path) as f:
683
+ for line in f:
684
+ row = json.loads(line)
685
+ convs = row["conversations"]
686
+ prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
687
+ if not prompt_msgs:
688
+ continue
689
+ user_text = " ".join(m["content"] for m in prompt_msgs if m["role"] == "user")
690
+ task = _classify_task_type(user_text)
691
+ by_type[task].append(prompt_msgs)
692
+
693
+ print(f"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}")
694
+
695
+ rng = random.Random(42)
696
+ cal_samples = []
697
+ for task_type in ["extraction", "extraction", "sql_qa", "sql_qa", "insights", "insights", "push", "push"]:
698
+ cal_samples.append(rng.choice(by_type[task_type]))
699
+
700
+ FastLanguageModel.for_inference(model)
701
+ print(f"\nReward calibration v3 ({len(cal_samples)} samples):")
702
+ print("-" * 70)
703
+
704
+ cal_rewards = []
705
+ for i, msgs in enumerate(cal_samples):
706
+ text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
707
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
708
+ outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
709
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
710
+ gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
711
+
712
+ r = commerce_reward_fn([response], [text])[0]
713
+ cal_rewards.append(r)
714
+ hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH
715
+ has_answer = "</think>" in response
716
+ answer_preview = strip_think(response)[:100] if has_answer else "[stuck in <think>]"
717
+ task = _classify_task_type(text)
718
+ print(f" [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'⚠️ HIT' if hit_ceiling else 'ok':6s} | {answer_preview}")
719
+
720
+ print(f"\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}")
721
+ print(f"v2 calibration was: Mean=0.38, Min=0.02, Max=0.70")
722
+ print(f"Variance > 0: {len(set(cal_rewards)) > 1}")
723
+ ```
724
+
725
+ ---
726
+
727
+ ## Cell 8: Dataset Preparation v3
728
+
729
+ ```python
730
+ from datasets import Dataset
731
+
732
+ def prepare_grpo_datasets_v3(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO,
733
+ general_mix=GENERAL_MIX_RATIO, seed=42):
734
+ rng = random.Random(seed)
735
+
736
+ train_pools = {}
737
+ eval_records = []
738
+ for task, pool in by_type.items():
739
+ shuffled = pool.copy()
740
+ rng.shuffle(shuffled)
741
+ n_eval = max(1, int(len(shuffled) * eval_ratio))
742
+ eval_records.extend(shuffled[:n_eval])
743
+ train_pools[task] = shuffled[n_eval:]
744
+
745
+ if n_prompts is None:
746
+ train_records = []
747
+ for task, pool in train_pools.items():
748
+ train_records.extend(pool)
749
+ rng.shuffle(train_records)
750
+ else:
751
+ targets = {
752
+ "extraction": int(n_prompts * 0.4),
753
+ "sql_qa": int(n_prompts * 0.4),
754
+ "insights": int(n_prompts * 0.1),
755
+ "push": int(n_prompts * 0.1),
756
+ }
757
+ train_records = []
758
+ for task, target_n in targets.items():
759
+ pool = train_pools[task]
760
+ n = min(target_n, len(pool))
761
+ train_records.extend(rng.sample(pool, n))
762
+ rng.shuffle(train_records)
763
+
764
+ general_path = DATA_DIR / "pairs" / "general_reasoning.jsonl"
765
+ if general_mix > 0 and general_path.exists():
766
+ general_records = []
767
+ with open(general_path) as f:
768
+ for line in f:
769
+ row = json.loads(line)
770
+ convs = row["conversations"]
771
+ prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
772
+ if prompt_msgs:
773
+ general_records.append(prompt_msgs)
774
+ n_general = int(len(train_records) * general_mix / (1 - general_mix))
775
+ n_general = min(n_general, len(general_records))
776
+ if n_general > 0:
777
+ train_records.extend(rng.sample(general_records, n_general))
778
+ rng.shuffle(train_records)
779
+ print(f" Cocktail Effect: added {n_general} general reasoning samples ({general_mix:.0%} mix)")
780
+ elif general_mix > 0:
781
+ print(f" ⚠️ general_reasoning.jsonl not found — skipping mix")
782
+
783
+ task_dist = {}
784
+ for record in train_records:
785
+ user_text = " ".join(m["content"] for m in record if m["role"] == "user")
786
+ task = _classify_task_type(user_text)
787
+ task_dist[task] = task_dist.get(task, 0) + 1
788
+
789
+ n_domain = len(train_records)
790
+ steps_per_epoch = n_domain * NUM_GENERATIONS // (BATCH_SIZE * GRAD_ACCUM)
791
+
792
+ print(f"v3 Dataset split (eval_ratio={eval_ratio}):")
793
+ print(f" train : {n_domain} prompts")
794
+ print(f" eval : {len(eval_records)} prompts")
795
+ print(f" distribution: {', '.join(f'{k}={v}' for k, v in sorted(task_dist.items()))}")
796
+ print(f" steps/epoch: {n_domain} × {NUM_GENERATIONS} / ({BATCH_SIZE} × {GRAD_ACCUM}) = {steps_per_epoch}")
797
+ print(f" MAX_STEPS={MAX_STEPS} → {'< 1 epoch' if MAX_STEPS < steps_per_epoch else f'{MAX_STEPS/steps_per_epoch:.1f} epochs'}")
798
+
799
+ train_ds = Dataset.from_list([{"prompt": msgs} for msgs in train_records])
800
+ eval_ds = Dataset.from_list([{"prompt": msgs} for msgs in eval_records])
801
+ return train_ds, eval_ds
802
+
803
+
804
+ train_dataset, eval_dataset = prepare_grpo_datasets_v3()
805
+ dataset = train_dataset
806
+ print(f"\n✓ v3 Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}")
807
+ ```
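+
+ The mix arithmetic is easy to get wrong: to make general data a fraction r of the *final* set, the count added is n_domain × r / (1 − r), not n_domain × r. A worked example with assumed counts (~1400 prompts minus the 15% eval split):
+
+ ```python
+ # Assumed: ~1190 domain prompts survive the eval split; target 30% general mix.
+ n_domain, r = 1190, 0.30
+ n_general = int(n_domain * r / (1 - r))   # 510
+ total = n_domain + n_general              # 1700
+ print(f"add {n_general} general → {n_general / total:.0%} of {total} prompts")
+ ```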
808
+
809
+ ---
810
+
811
+ ## Cell 9: Smoke Test
812
+
813
+ **Gate:** Runs 1 step without OOM at new completion length (4096).
814
+
815
+ ```python
816
+ from trl import GRPOConfig, GRPOTrainer
817
+
818
+ FastLanguageModel.for_training(model)
819
+
820
+ smoke_config = GRPOConfig(
821
+ output_dir=str(CHECKPOINT_DIR / "smoke"),
822
+ num_generations=NUM_GENERATIONS,
823
+ scale_rewards=SCALE_REWARDS,
824
+ max_completion_length=MAX_COMPLETION_LENGTH,
825
+ max_steps=1,
826
+ num_train_epochs=1,
827
+ temperature=TEMPERATURE,
828
+ per_device_train_batch_size=BATCH_SIZE,
829
+ gradient_accumulation_steps=1,
830
+ learning_rate=LEARNING_RATE,
831
+ fp16=False,
832
+ bf16=True,
833
+ logging_steps=1,
834
+ save_steps=999,
835
+ report_to="none",
836
+ max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
837
+ seed=42,
838
+ remove_unused_columns=False,
839
+ )
840
+
841
+ smoke_trainer = GRPOTrainer(
842
+ model=model,
843
+ reward_funcs=commerce_reward_fn,
844
+ args=smoke_config,
845
+ train_dataset=dataset,
846
+ processing_class=tokenizer, # TRL 0.24 takes processing_class, not tokenizer
847
+ )
848
+
849
+ t0 = time.time()
850
+ smoke_trainer.train()
851
+ step_time = time.time() - t0
852
+
853
+ print(f"\n✓ Smoke test passed!")
854
+ print(f" Step time (grad_accum=1): {step_time:.0f}s")
855
+ print(f" Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s")
856
+ print(f" VRAM peak: {torch.cuda.max_memory_allocated()/1e9:.1f} GB / {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB")
857
+
858
+ vram_used = torch.cuda.max_memory_allocated() / 1e9
859
+ vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
860
+ if vram_used > vram_total * 0.95:
861
+ print(f"\n⚠️ VRAM at {vram_used/vram_total:.0%} — dangerously close to OOM")
862
+ print(f" Option 1: Reduce MAX_COMPLETION_LENGTH to 3072")
863
+ print(f" Option 2: Reduce BATCH_SIZE to 2 (increase GRAD_ACCUM to 2)")
864
+
865
+ del smoke_trainer
866
+ gc.collect(); torch.cuda.empty_cache()
867
+ ```
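+
+ If the smoke test OOMs even after the warnings above, one pattern is to retry the single step at shrinking completion lengths, mirroring the fallback order from the Cell 3 VRAM note — a sketch, not part of the main pipeline:
+
+ ```python
+ # Sketch: find the largest completion length that fits in one training step.
+ for cand_len in (4096, 3072, 2560):
+     try:
+         gc.collect(); torch.cuda.empty_cache()
+         cfg = GRPOConfig(
+             output_dir=str(CHECKPOINT_DIR / "smoke"),
+             num_generations=NUM_GENERATIONS,
+             max_completion_length=cand_len,
+             max_prompt_length=MAX_SEQ_LENGTH - cand_len,
+             max_steps=1, per_device_train_batch_size=BATCH_SIZE,
+             bf16=True, report_to="none", remove_unused_columns=False,
+         )
+         GRPOTrainer(model=model, reward_funcs=commerce_reward_fn, args=cfg,
+                     train_dataset=dataset, processing_class=tokenizer).train()
+         print(f"✓ fits at max_completion_length={cand_len}")
+         break
+     except torch.cuda.OutOfMemoryError:
+         print(f"✗ OOM at {cand_len}, stepping down")
+ ```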
868
+
869
+ ---
870
+
871
+ ## Cell 10: Probe Run (3 steps)
872
+
873
+ ```python
874
+ FastLanguageModel.for_training(model)
875
+
876
+ probe_config = GRPOConfig(
877
+ output_dir=str(CHECKPOINT_DIR / "probe"),
878
+ num_generations=NUM_GENERATIONS,
879
+ scale_rewards=SCALE_REWARDS,
880
+ max_completion_length=MAX_COMPLETION_LENGTH,
881
+ max_steps=3,
882
+ temperature=TEMPERATURE,
883
+ num_train_epochs=NUM_EPOCHS,
884
+ per_device_train_batch_size=BATCH_SIZE,
885
+ gradient_accumulation_steps=GRAD_ACCUM,
886
+ learning_rate=LEARNING_RATE,
887
+ warmup_ratio=0.1,
888
+ lr_scheduler_type="cosine",
889
+ fp16=False,
890
+ bf16=True,
891
+ logging_steps=1,
892
+ disable_tqdm=True,
893
+ logging_first_step=True,
894
+ save_steps=999,
895
+ report_to="none",
896
+ max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
897
+ seed=42,
898
+ remove_unused_columns=False,
899
+ )
900
+
901
+ probe_trainer = GRPOTrainer(
902
+ model=model,
903
+ reward_funcs=commerce_reward_fn,
904
+ args=probe_config,
905
+ train_dataset=dataset,
906
+ processing_class=tokenizer,
907
+ )
908
+
909
+ t0 = time.time()
910
+ result = probe_trainer.train()
911
+ elapsed = time.time() - t0
912
+
913
+ print(f"\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)")
914
+ print(f" Train loss: {result.training_loss:.6f}")
915
+ print(f" Estimated full run ({MAX_STEPS} steps): {elapsed/3 * MAX_STEPS / 3600:.1f}h")
916
+
917
+ if abs(result.training_loss) < 1e-6:
918
+ print(" ⚠️ Loss is near-zero — reward variance may be insufficient")
919
+ else:
920
+ print(" ✓ Loss is non-zero — GRPO has gradient signal")
921
+
922
+ del probe_trainer
923
+ gc.collect(); torch.cuda.empty_cache()
924
+ ```
925
+
926
+ ---
927
+
928
+ ## Cell 11: Full Training Run v3
929
+
930
+ ```python
931
+ import wandb
932
+
933
+ _wandb_key = os.environ.get("WANDB_API_KEY", "").strip()
934
+ if not _wandb_key:
935
+ raise EnvironmentError("WANDB_API_KEY is not set.")
936
+ wandb.login(key=_wandb_key, relogin=True)
937
+ print(f"✓ W&B authenticated")
938
+ ```
939
+
940
+ ```python
941
+ import shutil
942
+ import torch
943
+ from transformers import TrainerCallback
944
+ from trl import GRPOConfig, GRPOTrainer
945
+
946
+ wandb.init(
947
+ project=WANDB_PROJECT,
948
+ name=f"grpo-v3-l4-{time.strftime('%Y%m%d-%H%M')}",
949
+ config={
950
+ "model_id": MODEL_ID,
951
+ "version": "v3",
952
+ "temperature": TEMPERATURE,
953
+ "max_completion_length": MAX_COMPLETION_LENGTH,
954
+ "num_generations": NUM_GENERATIONS,
955
+ "learning_rate": LEARNING_RATE,
956
+ "beta": BETA,
957
+ "batch_size": BATCH_SIZE,
958
+ "grad_accum": GRAD_ACCUM,
959
+ "max_steps": MAX_STEPS,
960
+ "scale_rewards": SCALE_REWARDS,
961
+ "save_steps": SAVE_STEPS,
962
+ "eval_steps": EVAL_STEPS,
963
+ "eval_max_samples": EVAL_MAX_SAMPLES,
964
+ "eval_max_tokens": EVAL_MAX_TOKENS,
965
+ "eval_temperature": EVAL_TEMPERATURE,
966
+ "patience": EARLY_STOPPING_PATIENCE,
967
+ "delta": EARLY_STOPPING_DELTA,
968
+ "train_prompts": len(train_dataset),
969
+ "eval_prompts": len(eval_dataset),
970
+ "zero_adv_noise_std": ZERO_ADV_NOISE_STD,
971
+ "general_mix_ratio": GENERAL_MIX_RATIO,
972
+ "_ref_temperature": "Skywork-OR1 (2505.22312)",
973
+ "_ref_completion_length": "Dr. GRPO (2503.20783)",
974
+ "_ref_staged_rewards": "Reasoning-SQL (2503.23157)",
975
+ "_ref_zero_adv": "Skywork-OR1 (2505.22312)",
976
+ },
977
+ )
978
+ print(f"✓ W&B run: {wandb.run.url}")
979
+
980
+ FRESH = True
981
+ resume_from = None
982
+ if FRESH and CHECKPOINT_DIR.exists():
983
+ print("FRESH: deleting old checkpoints...")
984
+ shutil.rmtree(CHECKPOINT_DIR)
985
+ elif CHECKPOINT_DIR.exists():
986
+ checkpoints = sorted(
987
+ [d for d in CHECKPOINT_DIR.iterdir()
988
+ if d.is_dir() and d.name.startswith("checkpoint-")],
989
+ key=lambda d: int(d.name.split("-")[-1]),
990
+ )
991
+ if checkpoints:
992
+ resume_from = str(checkpoints[-1])
993
+ print(f"Resuming from: {resume_from}")
994
+
995
+
996
+ class UnslothGRPOTrainer(GRPOTrainer):
997
+ """Wraps generation with Unsloth for_inference()/for_training()."""
998
+ def _generate(self, prompts, images):
999
+ FastLanguageModel.for_inference(self.model)
1000
+ try:
1001
+ result = super()._generate(prompts, images)
1002
+ finally:
1003
+ FastLanguageModel.for_training(self.model)
1004
+ return result
1005
+
1006
+
1007
+ class EvalRewardCallback(TrainerCallback):
1008
+ """v3: deterministic eval, per-task breakdown, patience=15."""
1009
+ def __init__(self, eval_records, reward_fn, patience=EARLY_STOPPING_PATIENCE,
1010
+ delta=EARLY_STOPPING_DELTA):
1011
+ self.eval_records = eval_records
1012
+ self.reward_fn = reward_fn
1013
+ self.patience = patience
1014
+ self.delta = delta
1015
+ self.best_reward = -float("inf")
1016
+ self.no_improve_count = 0
1017
+
1018
+ def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
1019
+ if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:
1020
+ return control
1021
+ tokenizer = processing_class
1022
+ if tokenizer is None:
1023
+ print("[EvalRewardCallback] WARNING: tokenizer is None, skipping eval")
1024
+ return control
1025
+
1026
+ mean_reward, task_rewards = self._run_eval(model, tokenizer, args)
1027
+ improved = mean_reward > self.best_reward + self.delta
1028
+ status = "↑ improved" if improved else f"↔ no gain ({self.no_improve_count + 1}/{self.patience})"
1029
+
1030
+ log_dict = {
1031
+ "eval/mean_reward": mean_reward,
1032
+ "eval/best_reward": max(self.best_reward, mean_reward),
1033
+ "eval/no_improve_count": self.no_improve_count,
1034
+ }
1035
+ for task, rewards in task_rewards.items():
1036
+ if rewards:
1037
+ log_dict[f"eval/{task}_reward"] = sum(rewards) / len(rewards)
1038
+ wandb.log(log_dict, step=state.global_step)
1039
+
1040
+ print(f"\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}")
1041
+ for task, rewards in task_rewards.items():
1042
+ if rewards:
1043
+ print(f" {task}: {sum(rewards)/len(rewards):.3f} (n={len(rewards)})")
1044
+
1045
+ if improved:
1046
+ self.best_reward = mean_reward
1047
+ self.no_improve_count = 0
1048
+ else:
1049
+ self.no_improve_count += 1
1050
+ if self.no_improve_count >= self.patience:
1051
+ print(f"[EarlyStopping] No improvement ≥ {self.delta} for {self.patience} consecutive evals. Halting.")
1052
+ wandb.log({"early_stop/step": state.global_step}, step=state.global_step)
1053
+ control.should_training_stop = True
1054
+ return control
1055
+
1056
+ def _run_eval(self, model, tokenizer, args):
1057
+ FastLanguageModel.for_inference(model)
1058
+ rewards = []
1059
+ task_rewards = {}
1060
+ subset = self.eval_records[:EVAL_MAX_SAMPLES]
1061
+ for record in subset:
1062
+ msgs = record["prompt"]
1063
+ text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
1064
+ inputs = tokenizer(text, return_tensors="pt", truncation=True,
1065
+ max_length=args.max_prompt_length).to(model.device)
1066
+ with torch.no_grad():
1067
+ out = model.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS,
1068
+ temperature=EVAL_TEMPERATURE, do_sample=True)
1069
+ resp = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
1070
+ r = self.reward_fn([resp], [text])[0]
1071
+ rewards.append(r)
1072
+ user_text = " ".join(m.get("content", "") for m in msgs if m.get("role") == "user")
1073
+ task = _classify_task_type(user_text)
1074
+ task_rewards.setdefault(task, []).append(r)
1075
+ FastLanguageModel.for_training(model)
1076
+ mean = sum(rewards) / len(rewards) if rewards else 0.0
1077
+ return mean, task_rewards
1078
+
1079
+
1080
+ class EntropyMonitorCallback(TrainerCallback):
1081
+ """v3 NEW: Monitor entropy collapse indicators (Skywork-OR1 §4)."""
1082
+ def __init__(self):
1083
+ self.consecutive_ceiling_hits = 0
1084
+
1085
+ def on_log(self, args, state, control, logs=None, **kwargs):
1086
+ if not logs:
1087
+ return
1088
+ step = state.global_step
1089
+ monitor = {}
1090
+ comp_len = logs.get("completion_length", 0)
1091
+ if comp_len > 0:
1092
+ ratio = comp_len / MAX_COMPLETION_LENGTH
1093
+ monitor["monitor/completion_ratio"] = ratio
1094
+ if ratio > 0.95:
1095
+ self.consecutive_ceiling_hits += 1
1096
+ if self.consecutive_ceiling_hits >= 3:
1097
+ print(f"⚠️ Step {step}: Completion ceiling hit {self.consecutive_ceiling_hits} consecutive times.")
1098
+ else:
1099
+ self.consecutive_ceiling_hits = 0
1100
+ reward_std = logs.get("reward_std", logs.get("rewards/commerce_reward_fn/std", 0))
1101
+ if reward_std is not None:
1102
+ monitor["monitor/reward_std"] = reward_std
1103
+ if reward_std < 0.01:
1104
+ print(f"⚠️ Step {step}: reward_std={reward_std:.4f} — near-zero variance")
1105
+ clip_high = logs.get("clip_ratio/high_mean", 0)
1106
+ clip_low = logs.get("clip_ratio/low_mean", 0)
1107
+ if clip_high is not None and clip_low is not None:
1108
+ total_clip = clip_high + abs(clip_low)
1109
+ monitor["monitor/total_clip_ratio"] = total_clip
1110
+ if total_clip > 0.01 and step > 10:
1111
+ print(f"✓ Step {step}: clip_ratio={total_clip:.3f} — policy is updating")
1112
+ if monitor and wandb.run:
1113
+ wandb.log(monitor, step=step)
1114
+
1115
+
1116
+ FastLanguageModel.for_training(model)
1117
+
1118
+ grpo_config = GRPOConfig(
1119
+ output_dir=str(CHECKPOINT_DIR),
1120
+ num_generations=NUM_GENERATIONS,
1121
+ scale_rewards=SCALE_REWARDS,
1122
+ max_completion_length=MAX_COMPLETION_LENGTH,
1123
+ temperature=TEMPERATURE,
1124
+ max_steps=MAX_STEPS,
1125
+ num_train_epochs=NUM_EPOCHS,
1126
+ per_device_train_batch_size=BATCH_SIZE,
1127
+ gradient_accumulation_steps=GRAD_ACCUM,
1128
+ learning_rate=LEARNING_RATE,
1129
+ warmup_ratio=0.1,
1130
+ lr_scheduler_type="cosine",
1131
+ fp16=False,
1132
+ bf16=True,
1133
+ logging_steps=1,
1134
+ logging_first_step=True,
1135
+ disable_tqdm=True,
1136
+ save_steps=SAVE_STEPS,
1137
+ save_total_limit=SAVE_TOTAL_LIMIT,
1138
+ save_only_model=True, # NOTE: no optimizer state saved — set False if you need resume_from_checkpoint
1139
+ eval_steps=EVAL_STEPS,
1140
+ report_to="wandb",
1141
+ max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
1142
+ seed=42,
1143
+ remove_unused_columns=False,
1144
+ **({"use_vllm": True, "vllm_mode": "colocate",
1145
+ "vllm_enable_sleep_mode": True} if USE_VLLM else {}),
1146
+ )
1147
+
1148
+ eval_cb = EvalRewardCallback(eval_records=list(eval_dataset), reward_fn=commerce_reward_fn)
1149
+ entropy_cb = EntropyMonitorCallback()
1150
+
1151
+ TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer
1152
+ trainer = TrainerClass(
1153
+ model=model,
1154
+ reward_funcs=commerce_reward_fn,
1155
+ args=grpo_config,
1156
+ train_dataset=train_dataset,
1157
+ processing_class=tokenizer,
1158
+ callbacks=[eval_cb, entropy_cb],
1159
+ )
1160
+
1161
+ print(f"{'='*70}")
1162
+ print(f"GRPO v3 Training — Ready to Launch")
1163
+ print(f"{'='*70}")
1164
+ print(f" Trainer: {TrainerClass.__name__}")
1165
+ print(f" Max steps: {MAX_STEPS}")
1166
+ print(f" Temperature: {TEMPERATURE} (v2 was 0.8)")
1167
+ print(f" Completion: {MAX_COMPLETION_LENGTH} tokens (v2 was 2048)")
1168
+ print(f" Generations: {NUM_GENERATIONS} per prompt (v2 was 8)")
1169
+ print(f" Learning rate: {LEARNING_RATE} (v2 was 5e-7)")
1170
+ print(f" Save every: {SAVE_STEPS} steps (keep {SAVE_TOTAL_LIMIT})")
1171
+ print(f" Eval every: {EVAL_STEPS} steps ({EVAL_MAX_SAMPLES} samples × {EVAL_MAX_TOKENS} tok)")
1172
+ print(f" Patience: {EARLY_STOPPING_PATIENCE} evals ({EARLY_STOPPING_PATIENCE * EVAL_STEPS} steps)")
1173
+ print(f" Resume: {resume_from is not None}")
1174
+ print(f"{'='*70}")
1175
+
1176
+ t_start = time.time()
1177
+ result = trainer.train(resume_from_checkpoint=resume_from)
1178
+ elapsed = time.time() - t_start
1179
+
1180
+ wandb.log({
1181
+ "train/final_loss": result.training_loss,
1182
+ "train/duration_hours": elapsed / 3600,
1183
+ "train/total_steps": result.global_step,
1184
+ "eval/best_reward_final": eval_cb.best_reward,
1185
+ })
1186
+ wandb.finish()
1187
+
1188
+ print(f"\n{'='*70}")
1189
+ print(f"GRPO v3 Training Complete")
1190
+ print(f" Loss: {result.training_loss:.6f}")
1191
+ print(f" Steps: {result.global_step}")
1192
+ print(f" Duration: {elapsed/3600:.1f}h")
1193
+ print(f" Best eval R: {eval_cb.best_reward:.4f}")
1194
+ print(f" Trainer: {TrainerClass.__name__}")
1195
+ print(f"{'='*70}")
1196
+ ```
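+
+ With SAVE_TOTAL_LIMIT=5 only the last five checkpoints survive, and the best eval step is not recorded on disk. If the final weights regressed past the best eval, an earlier checkpoint can be reloaded by hand — a sketch assuming the checkpoint dirs hold adapter weights (which `save_only_model=True` implies):
+
+ ```python
+ # List surviving checkpoints, oldest → newest.
+ ckpts = sorted(
+     (d for d in CHECKPOINT_DIR.iterdir()
+      if d.is_dir() and d.name.startswith("checkpoint-")),
+     key=lambda d: int(d.name.split("-")[-1]),
+ )
+ print("kept:", [c.name for c in ckpts])
+
+ # Pick the step closest to the best eval (from the [EvalReward] logs), e.g.:
+ # model, tokenizer = FastLanguageModel.from_pretrained(
+ #     model_name=str(ckpts[-1]), max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True,
+ # )
+ ```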
1197
+
1198
+ ---
1199
+
1200
+ ## Cell 12: Save Adapter
1201
+
1202
+ ```python
1203
+ GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
1204
+ model.save_pretrained(str(GRPO_ADAPTER_DIR))
1205
+ tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))
1206
+
1207
+ summary = {
1208
+ "model_id": MODEL_ID,
1209
+ "sft_adapter": str(SFT_ADAPTER_DIR),
1210
+ "method": "GRPO",
1211
+ "version": "v3",
1212
+ "train_loss": result.training_loss,
1213
+ "best_eval_reward": eval_cb.best_reward,
1214
+ "num_prompts": len(train_dataset),
1215
+ "num_generations": NUM_GENERATIONS,
1216
+ "scale_rewards": SCALE_REWARDS,
1217
+ "temperature": TEMPERATURE,
1218
+ "learning_rate": LEARNING_RATE,
1219
+ "beta": BETA,
1220
+ "max_completion_length": MAX_COMPLETION_LENGTH,
1221
+ "max_steps": MAX_STEPS,
1222
+ "actual_steps": result.global_step,
1223
+ "epochs": NUM_EPOCHS,
1224
+ "max_seq_length": MAX_SEQ_LENGTH,
1225
+ "duration_seconds": round(elapsed),
1226
+ "gpu": "L4",
1227
+ "platform": "vertex-ai-workbench",
1228
+ "v3_fixes": [
1229
+ "temperature=1.0 (Skywork-OR1)",
1230
+ "max_completion_length=4096 (Dr. GRPO)",
1231
+ "learning_rate=2e-6 (4x v2)",
1232
+ "beta=0.0 (Dr. GRPO)",
1233
+ "staged rewards (Reasoning-SQL)",
1234
+ "zero-advantage noise (Skywork-OR1)",
1235
+ "entropy monitoring callback",
1236
+ ],
1237
+ }
1238
+ with open(GRPO_ADAPTER_DIR / "training_summary.json", "w") as f:
1239
+ json.dump(summary, f, indent=2)
1240
+
1241
+ print(f"✓ Adapter saved to {GRPO_ADAPTER_DIR}")
1242
+ print(f" Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir() if f.is_file()]}")
1243
+ ```
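+
+ A quick round-trip check that the saved artifacts are complete — reading the summary back rather than reloading the adapter (which would double VRAM in-session); the expected file names are an assumption based on standard PEFT output:
+
+ ```python
+ # Verify the summary parses and the usual PEFT adapter files exist.
+ with open(GRPO_ADAPTER_DIR / "training_summary.json") as f:
+     loaded = json.load(f)
+ assert loaded["version"] == "v3" and loaded["beta"] == 0.0
+ expected = {"adapter_config.json", "adapter_model.safetensors"}  # assumed names
+ missing = expected - {p.name for p in GRPO_ADAPTER_DIR.iterdir()}
+ print(f"✓ summary ok; missing files: {missing or 'none'}")
+ ```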
1244
+
1245
+ ---
1246
+
1247
+ ## Cell 13: Validation
1248
+
1249
+ ```python
1250
+ FastLanguageModel.for_inference(model)
1251
+
1252
+ system_msg = {"role": "system", "content": SYSTEM_PT}
1253
+
1254
+ test_prompts = [
1255
+ {"role": "user", "content": (
1256
+ "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
1257
+ "nota=2/5 | status=delivered\ntítulo: decepcionado\n"
1258
+ "texto: Produto veio com defeito e o vendedor não respondeu.\n\n"
1259
+ "Retorne um objeto JSON com exatamente estas chaves:\n"
1260
+ "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
1261
+ "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
1262
+ )},
1263
+ {"role": "user", "content": (
1264
+ "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
1265
+ "nota=5/5 | status=delivered\ntítulo: adorei!\n"
1266
+ "texto: Entrega rápida e produto exatamente como descrito. Recomendo!\n\n"
1267
+ "Retorne um objeto JSON com exatamente estas chaves:\n"
1268
+ "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
1269
+ "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
1270
+ )},
1271
+ {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
1272
+ {"role": "user", "content": "Analise a retenção de clientes afetados por product_quality."},
1273
+ {"role": "user", "content": (
1274
+ "Perfil do cliente:\n- Estado: MG\n- Valor do pedido: R$150\n"
1275
+ "- Reclamação: produto com defeito\n- Nota: 1.0/5\n\n"
1276
+ "Este cliente deve receber uma notificação de reengajamento?"
1277
+ )},
1278
+ {"role": "user", "content": "Compare a satisfação de clientes em SP vs RJ."},
1279
+ {"role": "user", "content": (
1280
+ "Crie uma notificação push de reengajamento para um cliente em SP "
1281
+ "que reclamou de atraso na entrega. Nota: 2/5."
1282
+ )},
1283
+ ]
1284
+
1285
+ print("=" * 70)
1286
+ print("GRPO v3 Validation")
1287
+ print("=" * 70)
1288
+
1289
+ v3_rewards = []
1290
+ for i, prompt in enumerate(test_prompts):
1291
+ messages = [system_msg, prompt]
1292
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1293
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
1294
+
1295
+ outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)
1296
+ gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
1297
+ response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
1298
+
1299
+ reward = commerce_reward_fn([response], [text])[0]
1300
+ v3_rewards.append(reward)
1301
+ answer = strip_think(response)
1302
+ task = _classify_task_type(prompt["content"])
1303
+ hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH
1304
+
1305
+ print(f"\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---")
1306
+ print(f"Prompt: {prompt['content'][:80]}...")
1307
+ print(f"Answer: {answer[:400]}")
1308
+
1309
+ print(f"\n{'='*70}")
1310
+ print(f"v3 Validation Summary")
1311
+ print(f"{'='*70}")
1312
+ print(f" Mean reward: {sum(v3_rewards)/len(v3_rewards):.3f}")
1313
+ print(f" Min: {min(v3_rewards):.3f}")
1314
+ print(f" Max: {max(v3_rewards):.3f}")
1315
+ print()
1316
+ print(f" Comparison to baselines:")
1317
+ print(f" SFT calibration (Cell 7): mean=0.38")
1318
+ print(f" GRPO v2 validation: mean=0.54")
1319
+ print(f" GRPO v3 validation: mean={sum(v3_rewards)/len(v3_rewards):.3f}")
1320
+ v3_vs_v2 = (sum(v3_rewards)/len(v3_rewards) - 0.54) / 0.54 * 100
1321
+ print(f" v3 vs v2: {v3_vs_v2:+.1f}%")
1322
+ ```