rtferraz committed · verified
Commit b110818 · 1 Parent(s): 734569e

Delete grpo_vertex_v3.md

We're going to use the .ipynb version only.

Files changed (1):
  1. grpo_vertex_v3.md +0 -1322

grpo_vertex_v3.md DELETED
@@ -1,1322 +0,0 @@
# Tucano2 Commerce — GRPO Training v3 (Vertex AI Workbench / L4)

**v3 changes over v2 — grounded in published research:**

| Change | v2 Value | v3 Value | Paper Reference |
|--------|----------|----------|-----------------|
| Temperature | 0.8 | **1.0** | Skywork-OR1 (2505.22312) §4: τ=1.0 gives 5-8% better results, delays entropy collapse |
| Completion length | 2048 | **4096** | Dr. GRPO (2503.20783) §3.1: length bias inflates wrong answers → ceiling hit blocks learning |
| Num generations | 8 | **4** | VRAM tradeoff: 4×4096 ≈ 8×2048. MC-GRPO (2601.22582): G=4 works with noise mitigation |
| Learning rate | 5e-7 | **2e-6** | Dr. GRPO Appendix G: LR=1e-6; Reasoning-SQL: LR=1e-6. v2 clip_ratio=0 → room to push 2-4× |
| β (KL penalty) | implicit | **0.0** | Dr. GRPO §3.2: β=0 optimal for rule-based rewards |
| Training data | 300 | **ALL (~1400)** | Skywork-OR1 §3.1: small prompt sets → model memorizes → entropy collapse |
| Reward functions | single composite | **staged (format→partial→task)** | Reasoning-SQL (2503.23157) §3.2: format rewards converge first, enable task learning |
| Zero-advantage groups | included | **filtered with noise injection** | Skywork-OR1 §3.1: zero-std groups destabilize training |
| Entropy monitoring | none | **EntropyMonitorCallback** | Skywork-OR1 §4: early detection prevents collapse |
| Early stopping patience | 10 | **15** | More runway for longer completions |
| Save total limit | 3 | **5** | Keep more checkpoints — v2 lost the best one |
| Eval temperature | 0.7 | **0.1** | Deterministic eval = less noisy signal |
| General reasoning mix | none | **30% (optional)** | Cocktail Effect (2410.01109): multi-task mix boosts domain performance 2-15% |
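
Several rows above cite Dr. GRPO's bias analysis. In Cell 3 this lands as `SCALE_REWARDS = False`, which drops the per-group std normalization from the advantage computation. A minimal sketch with toy reward values (hypothetical numbers, for illustration only):

```python
# Toy illustration (hypothetical reward values) of the Dr. GRPO advantage fix.
import statistics

group = [0.2, 0.5, 0.5, 0.8]  # rewards for the G=4 completions of one prompt
mu = statistics.mean(group)
sigma = statistics.pstdev(group)

adv_vanilla = [(r - mu) / (sigma + 1e-4) for r in group]  # vanilla GRPO: std-normalized
adv_dr_grpo = [r - mu for r in group]                     # scale_rewards=False: mean-centered only

# Low-variance groups get their advantages inflated by the std division;
# dropping it removes that per-prompt difficulty bias.
print(adv_vanilla, adv_dr_grpo)
```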

**Prerequisites:**

- Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`
- Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`
- **NEW:** Optional `data/pairs/general_reasoning.jsonl` for the 30% general data mix

**Hardware:** L4 (24GB), PyTorch kernel, bf16 supported
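
A quick pre-flight check that these uploads are in place (a minimal sketch; paths assume the Cell 3 config below):

```python
# Pre-flight: verify prerequisite files before burning GPU time.
# Paths assume the Cell 3 config below; adjust if your layout differs.
from pathlib import Path

BASE = Path("/home/jupyter/tucano2")
for p in [BASE / "data" / "pairs" / "train.jsonl",
          BASE / "models" / "tucano2-commerce-sft",
          BASE / "data" / "pairs" / "general_reasoning.jsonl"]:  # optional mix file
    print(f"{'✓' if p.exists() else '✗ MISSING'} {p}")
```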

---

## Cell 1: Dependencies

Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:

```python
# Cell 1a — Nuke everything ML-related
!pip uninstall -y torch torchvision torchaudio \
    unsloth unsloth-zoo \
    trl transformers peft accelerate \
    bitsandbytes vllm vllm-flash-attn \
    datasets tokenizers safetensors huggingface-hub \
    wandb xformers triton \
    cuda-bindings cuda-python \
    sentencepiece protobuf \
    2>/dev/null
```

```python
# Cell 1b — Kill any stragglers
!pip freeze | grep -iE "torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate" | xargs pip uninstall -y 2>/dev/null
```

```python
# Cell 1c — Purge cache
!pip cache purge
```

**⚠️ Restart kernel again**, then:

```python
# Cell 1d — Clean install, correct order
!pip install "unsloth"
```

```python
# Cell 1e — Pin TRL (Unsloth may pull a different version)
!pip install "trl==0.24.0" --no-deps
```

```python
# Cell 1f — Extra deps
!pip install "rich" "wandb"
```

---

## Cell 2: Hello World — GPU + Unsloth Verification

```python
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"bf16 support: {torch.cuda.is_bf16_supported()}")

from unsloth import FastLanguageModel
print("\n✓ Unsloth loaded successfully")

import trl
print(f"✓ TRL version: {trl.__version__}")

import transformers
print(f"✓ Transformers version: {transformers.__version__}")
```

---

## Cell 3: Config + Constants

```python
import os
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

import json
import re
import time
import random
import gc
from pathlib import Path

# ══════════════════════════════════════════════════════════════════════════════
# v3 CONFIG — Every change is annotated with paper reference
# ══════════════════════════════════════════════════════════════════════════════

MODEL_ID = "Polygl0t/Tucano2-qwen-3.7B-Think"
MAX_SEQ_LENGTH = 8192  # v3: increased from 4096 — model supports 32k, we need room for 4096 completion + prompt

# ── Paths ─────────────────────────────────────────────────────────────────────
DATA_DIR = Path("/home/jupyter/tucano2/data")
MODELS_DIR = Path("/home/jupyter/tucano2/models")
SFT_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-sft"
GRPO_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-grpo-v3"  # v3: separate dir from v2
CHECKPOINT_DIR = GRPO_ADAPTER_DIR / "checkpoints"

# ── Training data ─────────────────────────────────────────────────────────────
GRPO_PROMPTS = None      # v3: None = use ALL available prompts (was 300 subset in v2)
GENERAL_MIX_RATIO = 0.0  # v3: set to 0.3 if general_reasoning.jsonl exists (Cocktail Effect paper)

# ── Valid enums for reward scoring (unchanged from v2) ────────────────────────
VALID_SENTIMENTS = {"positive", "negative", "neutral"}
VALID_CATEGORIES = {
    "delivery_delay", "product_quality", "product_not_received",
    "wrong_product", "seller_communication", "app_issue",
    "price_value", "other", "none",
}
VALID_CHURN = {"low", "medium", "high"}
VALID_REPEAT = {"yes", "no", "maybe"}
EXTRACTION_FIELDS = [
    "sentiment", "sentiment_score", "churn_risk", "delivery_issue",
    "product_issue", "seller_issue", "main_complaint",
    "complaint_category", "repeat_intent", "would_recommend",
]

SYSTEM_PT = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro."
)

# ══════════════════════════════════════════════════════════════════════════════
# TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)
# ══════════════════════════════════════════════════════════════════════════════

# ── Core GRPO params ──────────────────────────────────────────────────────────
BATCH_SIZE = 4
GRAD_ACCUM = 1  # v3: reduced from 2. Effective batch = 4×1 = 4 (was 8)
                # With G=4: steps = prompts × 4 / 4 = prompts per epoch
NUM_GENERATIONS = 4  # v3: reduced from 8 — VRAM tradeoff for longer completions
                     # MC-GRPO (2601.22582): G=4 works if noise is mitigated
SCALE_REWARDS = False  # Dr. GRPO (2503.20783): remove std normalization bias

# ── v3 CRITICAL FIXES ────────────────────────────────────────────────────────

# FIX 1: Temperature — prevent entropy collapse
# v2 had 0.8. All published GRPO papers use 1.0.
# Skywork-OR1 (2505.22312) ablation: τ=1.0 vs τ=0.6 → 5-8% better test performance
TEMPERATURE = 1.0

# FIX 2: Completion length — remove the ceiling
# v2: every single completion hit the 2048 ceiling. Model couldn't finish reasoning.
# Dr. GRPO (2503.20783) §3.1: GRPO length bias inflates wrong answers → the ceiling kills the gradient
MAX_COMPLETION_LENGTH = 4096

# FIX 3: Learning rate — more aggressive
# v2: clip_ratio=0 on all steps → updates were too small to matter
# Dr. GRPO Appendix G: LR=1e-6 (constant). Reasoning-SQL: LR=1e-6 with cosine.
# We go 2× since v2 showed zero clipping (model can absorb stronger push)
LEARNING_RATE = 2e-6

# FIX 4: β = 0 (no KL penalty)
# Dr. GRPO (2503.20783) §3.2: KL penalty is unnecessary for rule-based rewards
# v2 used implicit KL through default β — we explicitly disable it
BETA = 0.0

# ── Training schedule ─────────────────────────────────────────────────────────
NUM_EPOCHS = 1
MAX_STEPS = 500  # v3: increased for expanded data; early stopping will halt if needed
                 # With ~1400 prompts × 4 gen / (4 batch × 1 accum) = 1400 steps/epoch
                 # MAX_STEPS=500 < 1 epoch — early stopping or manual extension

# ── Checkpoint + Eval + Early-Stop ────────────────────────────────────────────
EVAL_SPLIT_RATIO = 0.15
EVAL_STEPS = 10
EARLY_STOPPING_PATIENCE = 15  # v3: increased from 10 — gives 150 steps of runway
EARLY_STOPPING_DELTA = 0.005  # v3: reduced from 0.01 — more sensitive to small gains
SAVE_STEPS = 10               # v3: more frequent (was 15) — never lose best checkpoint again
SAVE_TOTAL_LIMIT = 5          # v3: keep more checkpoints (was 3 — lost best in v2)
WANDB_PROJECT = "tucano2-commerce"

# ── Eval callback ─────────────────────────────────────────────────────────────
EVAL_MAX_SAMPLES = 5
EVAL_MAX_TOKENS = 4096  # v3: match training max_completion_length (was 2048)
EVAL_TEMPERATURE = 0.1  # v3: deterministic eval for less noisy signal (was 0.7)

# ── Backend ───────────────────────────────────────────────────────────────────
USE_VLLM = False

# ── v3: Zero-advantage noise injection ────────────────────────────────────────
# Skywork-OR1 (2505.22312) §3.1: zero-std groups destabilize GRPO training
# When all G completions get identical rewards, the advantage is undefined.
# Noise injection breaks ties without corrupting the signal.
ZERO_ADV_NOISE_STD = 0.005  # Small gaussian noise added to zero-variance groups

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ── Version assertion ─────────────────────────────────────────────────────────
import trl as _trl
assert _trl.__version__ == "0.24.0", (
    f"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\n"
    "Verify that GRPOTrainer._generate() still exists before proceeding."
)

print("✓ v3 Config loaded")
print(f"  SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})")
print(f"  Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})")
print(f"  Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}")
print(f"  GRPO: G={NUM_GENERATIONS}, temp={TEMPERATURE}, LR={LEARNING_RATE}, β={BETA}")
print(f"  Completion: max={MAX_COMPLETION_LENGTH} (v2 was 2048)")
print(f"  ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}")
print(f"✓ TRL {_trl.__version__} verified")

# ══════════════════════════════════════════════════════════════════════════════
# v3 VRAM BUDGET (L4 24GB)
# ══════════════════════════════════════════════════════════════════════════════
# Model (NF4):           ~3.5 GB
# KV Cache (8192 seq):   ~3.0 GB
# Activations:           ~4.0 GB
# Optimizer states:      ~3.0 GB
# Generations (4×4096):  ~8.0 GB
# ─────────────────────────────────
# Estimated total:      ~21.5 GB
# Headroom:              ~2.5 GB
#
# If OOM: reduce MAX_COMPLETION_LENGTH to 3072 first, then 2560.
# Do NOT reduce NUM_GENERATIONS below 4 — GRPO needs variance.
# ══════════════════════════════════════════════════════════════════════════════
```
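
To compare the static budget above with the live card, a quick snapshot (assumes `torch` imported in Cell 2 is still in the kernel):

```python
# Live VRAM snapshot to sanity-check the static budget above.
free_b, total_b = torch.cuda.mem_get_info()  # bytes
print(f"Free: {free_b/1e9:.1f} GB / {total_b/1e9:.1f} GB total")
```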

---

## Cell 4: Load SFT Adapter

```python
print("Loading SFT adapter...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=str(SFT_ADAPTER_DIR),
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load chat template from base model (SFT adapter doesn't save it)
from transformers import AutoTokenizer
base_tok = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.chat_template = base_tok.chat_template
del base_tok

# v2: Force KV cache — Unsloth patching may reset this
model.config.use_cache = True
model.generation_config.use_cache = True

print(f"✓ Model loaded on {model.device}")
print(f"  use_cache: {model.config.use_cache}")
print(f"  Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M")
print(f"  Chat template: {tokenizer.chat_template[:50]}...")
```
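
Optionally, verify the LoRA adapter actually attached (a sketch that just scans module names, since exact PEFT attribute names vary across versions):

```python
# Optional: verify LoRA modules are present on the loaded model.
n_lora = sum(1 for name, _ in model.named_modules() if "lora" in name.lower())
print(f"LoRA modules found: {n_lora}")
if n_lora == 0:
    print("⚠️ No LoRA modules — check that SFT_ADAPTER_DIR points at the adapter, not the base model")
```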

---

## Cell 5: Single Inference Test

**Gate:** Does the model close `</think>` and produce an answer within 4096 tokens?

```python
FastLanguageModel.for_inference(model)

test_msgs = [
    {"role": "system", "content": SYSTEM_PT},
    {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
]
text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

t0 = time.time()
outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
elapsed = time.time() - t0

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]

print(f"Generation time: {elapsed:.1f}s ({gen_tokens} tokens, {gen_tokens/elapsed:.1f} tok/s)")
print(f"Response length: {len(response)} chars, {gen_tokens} tokens")
print(f"Hit ceiling: {gen_tokens >= MAX_COMPLETION_LENGTH}")  # v3: should NOT hit ceiling with 4096
print(f"closed_think: {'</think>' in response}")
print(f"\n{'='*60}")
print(response[:800])
```

---

## Cell 5b: KV Cache Diagnostic

```python
import time
FastLanguageModel.for_inference(model)

_kv_msgs = [{"role": "system", "content": SYSTEM_PT},
            {"role": "user", "content": "Qual a categoria de reclamação mais frequente?"}]
_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)
_kv_inputs = tokenizer(_kv_text, return_tensors="pt").to(model.device)

_token_times, _past, _generated = [], None, _kv_inputs["input_ids"]
with torch.no_grad():
    for _step in range(50):
        _t0 = time.time()
        seq_len = _generated.shape[1]
        if _past is None:
            _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)
        else:
            _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)
        _out = model(
            input_ids=_generated[:, -1:] if _past is not None else _generated,
            position_ids=_position_ids,
            attention_mask=torch.ones(1, seq_len, device=model.device),
            past_key_values=_past,
            use_cache=True,
            return_dict=True,
        )
        _past = _out.past_key_values
        _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        _generated = torch.cat([_generated, _next], dim=1)
        _token_times.append(time.time() - _t0)

_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)
print(f"First 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}")
print(f"Last 5 tok  : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}")
print(f"Ratio last/first: {_ratio:.1f}x")
if _ratio < 3:
    print("✓ KV cache is working correctly")
elif _ratio < 6:
    print("⚠ KV cache may be degraded — check model.config.use_cache")
else:
    print("✗ KV cache BROKEN — GRPO generation will be catastrophically slow.")

del _past, _generated, _kv_inputs, _token_times, _out
gc.collect()
if torch.cuda.is_available(): torch.cuda.empty_cache()
```

---

## Cell 6: Reward Functions v3

**v3 changes:**

- Staged reward design: format → partial content → full task (Reasoning-SQL, 2503.23157)
- Zero-advantage noise injection (Skywork-OR1, 2505.22312)
- Extraction reward redesigned for completion-length-friendly scoring

```python
def strip_think(text: str) -> str:
    """Remove <think>...</think> block, return the answer portion."""
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()


def has_think_block(text: str) -> bool:
    """Check if text contains a non-empty <think> block."""
    return bool(re.search(r"<think>.+</think>", text, flags=re.DOTALL))


def _classify_task_type(prompt_text: str) -> str:
    """Classify prompt into task type by keywords."""
    p = prompt_text.lower()
    if "retorne um objeto json" in p or "extraia dados" in p:
        return "extraction"
    elif "notificação push" in p or "notificação de reengajamento" in p:
        return "push"
    elif "perfil do cliente" in p:
        return "insights"
    else:
        return "sql_qa"


def _json_similarity(text: str) -> float:
    """Rough heuristic: how JSON-like is this text? 0.0 to 1.0."""
    text = text.strip()
    if not text:
        return 0.0
    score = 0.0
    if text.startswith("{") and text.endswith("}"):
        score += 0.5
    if '"' in text:
        score += 0.2
    if ":" in text:
        score += 0.2
    if "," in text:
        score += 0.1
    return min(score, 1.0)


def _string_similarity(a: str, b: str) -> float:
    """Simple Jaccard-like similarity for short strings. 0.0 to 1.0."""
    if not a or not b:
        return 0.0
    a_set = set(a.split())
    b_set = set(b.split())
    intersection = len(a_set & b_set)
    union = len(a_set | b_set)
    return intersection / union if union > 0 else 0.0


# ══════════════════════════════════════════════════════════════════════════════
# v3 STAGED REWARD DESIGN
# Reference: Reasoning-SQL (2503.23157) §3.2
#
# Each reward function scores THREE stages independently:
#   Stage 1 — FORMAT  (0.0–0.2): Is the output well-structured?
#   Stage 2 — PARTIAL (0.0–0.3): Are some content elements correct?
#   Stage 3 — TASK    (0.0–0.5): Is the full task completed correctly?
#
# Format rewards converge first (easy to learn), which stabilizes training
# and enables the model to then learn harder task-specific skills.
# ══════════════════════════════════════════════════════════════════════════════


def reward_extraction(completion: str) -> float:
    """Staged reward for structured extraction (max 1.0)."""
    answer = strip_think(completion)

    # ── Stage 1: FORMAT (max 0.2) ─────────────────────────────────────────────
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.1  # Used reasoning

    try:
        data = json.loads(answer)
        if isinstance(data, dict):
            r_format += 0.1  # Valid JSON object
    except (json.JSONDecodeError, TypeError):
        r_format += 0.05 * _json_similarity(answer)
        return min(r_format, 0.2)

    if not isinstance(data, dict):
        return min(r_format, 0.2)

    # ── Stage 2: PARTIAL CONTENT (max 0.3) ────────────────────────────────────
    r_partial = 0.0

    present = sum(1 for f in EXTRACTION_FIELDS if f in data)
    r_partial += 0.15 * (present / len(EXTRACTION_FIELDS))

    type_checks = 0
    type_total = 0
    for field in EXTRACTION_FIELDS:
        if field not in data:
            continue
        type_total += 1
        val = data[field]
        if field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
            if isinstance(val, bool):
                type_checks += 1
        elif field in ("sentiment_score",):
            if isinstance(val, (int, float)):
                type_checks += 1
        elif field in ("main_complaint", "sentiment", "complaint_category", "churn_risk", "repeat_intent"):
            if isinstance(val, str):
                type_checks += 1
    if type_total > 0:
        r_partial += 0.15 * (type_checks / type_total)

    # ── Stage 3: FULL TASK (max 0.5) ─────────────────────────────────────────
    r_task = 0.0
    cat_checks = 0
    cat_total = 0

    checks = [
        ("sentiment", lambda v: v in VALID_SENTIMENTS),
        ("complaint_category", lambda v: v in VALID_CATEGORIES),
        ("churn_risk", lambda v: v in VALID_CHURN),
        ("repeat_intent", lambda v: v in VALID_REPEAT),
        ("sentiment_score", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),
    ]
    for field, validator in checks:
        cat_total += 1
        if field in data and validator(data[field]):
            cat_checks += 1

    for bool_field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
        cat_total += 1
        if bool_field in data and isinstance(data[bool_field], bool):
            cat_checks += 1

    if cat_total > 0:
        r_task += 0.35 * (cat_checks / cat_total)

    if "main_complaint" in data and isinstance(data["main_complaint"], str):
        complaint = data["main_complaint"].strip()
        if len(complaint) > 10:
            r_task += 0.15

    return min(r_format + r_partial + r_task, 1.0)


def reward_sql_qa(completion: str) -> float:
    """Staged reward for SQL Q&A (max 1.0)."""
    answer = strip_think(completion)

    # ── Stage 1: FORMAT (max 0.2)
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.1
    if "```" in answer or re.search(r"SELECT|FROM", answer, re.IGNORECASE):
        r_format += 0.1

    # ── Stage 2: PARTIAL (max 0.3)
    r_partial = 0.0
    sql_keywords = r"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN|HAVING"
    matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))
    r_partial += min(0.15, 0.03 * matches)
    numbers = re.findall(r"\d+(?:[.,]\d+)?", answer)
    r_partial += min(0.15, 0.03 * len(numbers))

    # ── Stage 3: TASK (max 0.5)
    r_task = 0.0
    length = len(answer)
    if 50 <= length <= 600:
        r_task += 0.25
    elif length > 0:
        r_task += 0.25 * max(0, 1 - abs(length - 325) / 275)
    explanation_markers = ["para ", "porque", "resultado", "mostra", "indica", "análise"]
    expl_matches = sum(1 for w in explanation_markers if w in answer.lower())
    r_task += min(0.25, 0.05 * expl_matches)

    return min(r_format + r_partial + r_task, 1.0)


def reward_insights(completion: str) -> float:
    """Staged reward for insights (max 1.0)."""
    answer = strip_think(completion)

    # ── Stage 1: FORMAT (max 0.2)
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.1
    structure_marks = len(re.findall(r"^[-•*]\s|^\d+[.)]\s|^#{1,3}\s", answer, re.MULTILINE))
    r_format += min(0.1, 0.02 * structure_marks)

    # ── Stage 2: PARTIAL (max 0.3)
    r_partial = 0.0
    length = len(answer)
    if 100 <= length <= 1200:
        r_partial += 0.15
    elif length > 0:
        r_partial += 0.15 * max(0, 1 - abs(length - 650) / 550)
    pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua|cliente|produto", answer, re.IGNORECASE)
    r_partial += min(0.15, 0.01 * len(pt_markers))

    # ── Stage 3: TASK (max 0.5)
    r_task = 0.0
    action_words = ["recomend", "implement", "melhor", "reduzir", "aumentar",
                    "priorizar", "investir", "otimizar", "estratégi", "suger",
                    "consider", "ação", "plano"]
    matches = sum(1 for w in action_words if w in answer.lower())
    r_task += min(0.3, 0.06 * matches)
    data_refs = len(re.findall(r"\d+%|R\$\s*\d|média|percentual|comparad|taxa", answer, re.IGNORECASE))
    r_task += min(0.2, 0.04 * data_refs)

    return min(r_format + r_partial + r_task, 1.0)


def reward_push(completion: str) -> float:
    """Staged reward for push notifications (max 1.0)."""
    answer = strip_think(completion)
    if not answer:
        return 0.0

    # ── Stage 1: FORMAT (max 0.2)
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.05
    length = len(answer)
    if length <= 160:
        r_format += 0.15
    elif length <= 300:
        r_format += 0.1
    else:
        r_format += 0.05

    # ── Stage 2: PARTIAL (max 0.3)
    r_partial = 0.0
    pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua", answer, re.IGNORECASE)
    r_partial += min(0.15, 0.02 * len(pt_markers))
    if re.search(r"[!?]|[\U0001F600-\U0001F64F]|[\U0001F300-\U0001F5FF]", answer):
        r_partial += 0.05
    if len(answer.split()) >= 5:
        r_partial += 0.1

    # ── Stage 3: TASK (max 0.5)
    r_task = 0.0
    if length <= 120:
        r_task += 0.25
    else:
        r_task += 0.25 * max(0, 1 - (length - 120) / 120)
    generic_phrases = [
        "olá! como podemos ajudar", "obrigado pela sua compra",
        "seu pedido foi confirmado", "agradecemos sua preferência",
    ]
    max_similarity = max(_string_similarity(answer.lower(), g) for g in generic_phrases)
    r_task += 0.25 * (1 - max_similarity)

    return min(r_format + r_partial + r_task, 1.0)


def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:
    """
    Master reward function v3: dispatches by task type + zero-advantage noise.
    """
    rewards = []
    for completion, prompt in zip(completions, prompts):
        if isinstance(completion, list):
            comp_text = completion[-1]["content"] if completion else ""
        else:
            comp_text = str(completion)

        if isinstance(prompt, list):
            prompt_text = " ".join(m.get("content", "") for m in prompt)
        else:
            prompt_text = str(prompt)

        task = _classify_task_type(prompt_text)

        if task == "extraction":
            rewards.append(reward_extraction(comp_text))
        elif task == "sql_qa":
            rewards.append(reward_sql_qa(comp_text))
        elif task == "insights":
            rewards.append(reward_insights(comp_text))
        elif task == "push":
            rewards.append(reward_push(comp_text))
        else:
            r = 0.15 if has_think_block(comp_text) else 0.0
            r += 0.2 if comp_text.strip() else 0.0
            rewards.append(r)

    # ── v3: Zero-advantage noise injection ────────────────────────────────────
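    # NOTE: this grouping assumes rewards arrive in consecutive blocks of
    # NUM_GENERATIONS per prompt (TRL's GRPOTrainer repeats each prompt
    # G times in order, so each slice below is one group).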
    if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:
        for i in range(0, len(rewards), NUM_GENERATIONS):
            group = rewards[i:i+NUM_GENERATIONS]
            if len(group) < 2:
                continue
            if max(group) - min(group) < 0.001:
                for j in range(i, min(i+NUM_GENERATIONS, len(rewards))):
                    rewards[j] += random.gauss(0, ZERO_ADV_NOISE_STD)

    return rewards


print("✓ v3 Reward functions defined (staged: format → partial → task)")
```
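
A GPU-free sanity check of the staging behavior, using toy completions written for this test (not drawn from the dataset):

```python
# Toy completions (hypothetical) to confirm the staged scoring behaves:
# a well-formed extraction should outscore a format-only one.
good = ('<think>nota baixa, defeito</think>'
        '{"sentiment": "negative", "sentiment_score": 2, "churn_risk": "high", '
        '"delivery_issue": false, "product_issue": true, "seller_issue": true, '
        '"main_complaint": "produto veio com defeito", '
        '"complaint_category": "product_quality", "repeat_intent": "no", '
        '"would_recommend": false}')
bad = "<think>hmm</think>not json at all"
toy_prompt = "Extraia dados desta avaliação. Retorne um objeto JSON com estas chaves: ..."

r_good, r_bad = commerce_reward_fn([good, bad], [toy_prompt, toy_prompt])
print(f"good={r_good:.2f}  bad={r_bad:.2f}")
assert r_good > r_bad, "staged reward should separate full-task from format-only completions"
```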

---

## Cell 7: Reward Calibration

**Gate:** Verify reward variance > 0. Compare v3 scoring to v2 calibration (mean=0.38).

```python
train_path = DATA_DIR / "pairs" / "train.jsonl"

by_type = {"extraction": [], "sql_qa": [], "insights": [], "push": []}
with open(train_path) as f:
    for line in f:
        row = json.loads(line)
        convs = row["conversations"]
        prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
        if not prompt_msgs:
            continue
        user_text = " ".join(m["content"] for m in prompt_msgs if m["role"] == "user")
        task = _classify_task_type(user_text)
        by_type[task].append(prompt_msgs)

print(f"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}")

rng = random.Random(42)
cal_samples = []
for task_type in ["extraction", "extraction", "sql_qa", "sql_qa", "insights", "insights", "push", "push"]:
    cal_samples.append(rng.choice(by_type[task_type]))

FastLanguageModel.for_inference(model)
print(f"\nReward calibration v3 ({len(cal_samples)} samples):")
print("-" * 70)

cal_rewards = []
for i, msgs in enumerate(cal_samples):
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]

    r = commerce_reward_fn([response], [text])[0]
    cal_rewards.append(r)
    hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH
    has_answer = "</think>" in response
    answer_preview = strip_think(response)[:100] if has_answer else "[stuck in <think>]"
    task = _classify_task_type(text)
    print(f"  [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'⚠️ HIT' if hit_ceiling else 'ok':6s} | {answer_preview}")

print(f"\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}")
print(f"v2 calibration was: Mean=0.38, Min=0.02, Max=0.70")
print(f"Variance > 0: {len(set(cal_rewards)) > 1}")
```
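
Optionally, turn the printed variance check into a hard gate (a sketch reusing `cal_rewards` from the cell above):

```python
# Hard gate: with zero reward variance, GRPO advantages are all zero and
# training would burn GPU hours with no gradient signal.
assert len(set(round(r, 3) for r in cal_rewards)) > 1, \
    "Zero reward variance — revisit the reward functions before training."
```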

---

## Cell 8: Dataset Preparation v3

```python
from datasets import Dataset

def prepare_grpo_datasets_v3(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO,
                             general_mix=GENERAL_MIX_RATIO, seed=42):
    rng = random.Random(seed)

    train_pools = {}
    eval_records = []
    for task, pool in by_type.items():
        shuffled = pool.copy()
        rng.shuffle(shuffled)
        n_eval = max(1, int(len(shuffled) * eval_ratio))
        eval_records.extend(shuffled[:n_eval])
        train_pools[task] = shuffled[n_eval:]

    if n_prompts is None:
        train_records = []
        for task, pool in train_pools.items():
            train_records.extend(pool)
        rng.shuffle(train_records)
    else:
        targets = {
            "extraction": int(n_prompts * 0.4),
            "sql_qa": int(n_prompts * 0.4),
            "insights": int(n_prompts * 0.1),
            "push": int(n_prompts * 0.1),
        }
        train_records = []
        for task, target_n in targets.items():
            pool = train_pools[task]
            n = min(target_n, len(pool))
            train_records.extend(rng.sample(pool, n))
        rng.shuffle(train_records)

    general_path = DATA_DIR / "pairs" / "general_reasoning.jsonl"
    if general_mix > 0 and general_path.exists():
        general_records = []
        with open(general_path) as f:
            for line in f:
                row = json.loads(line)
                convs = row["conversations"]
                prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
                if prompt_msgs:
                    general_records.append(prompt_msgs)
        n_general = int(len(train_records) * general_mix / (1 - general_mix))
        n_general = min(n_general, len(general_records))
        if n_general > 0:
            train_records.extend(rng.sample(general_records, n_general))
            rng.shuffle(train_records)
            print(f"  Cocktail Effect: added {n_general} general reasoning samples ({general_mix:.0%} mix)")
    elif general_mix > 0:
        print(f"  ⚠️ general_reasoning.jsonl not found — skipping mix")

    task_dist = {}
    for record in train_records:
        user_text = " ".join(m["content"] for m in record if m["role"] == "user")
        task = _classify_task_type(user_text)
        task_dist[task] = task_dist.get(task, 0) + 1

    n_domain = len(train_records)
    steps_per_epoch = n_domain * NUM_GENERATIONS // (BATCH_SIZE * GRAD_ACCUM)

    print(f"v3 Dataset split (eval_ratio={eval_ratio}):")
    print(f"  train : {n_domain} prompts")
    print(f"  eval  : {len(eval_records)} prompts")
    print(f"  distribution: {', '.join(f'{k}={v}' for k, v in sorted(task_dist.items()))}")
    print(f"  steps/epoch: {n_domain} × {NUM_GENERATIONS} / ({BATCH_SIZE} × {GRAD_ACCUM}) = {steps_per_epoch}")
    print(f"  MAX_STEPS={MAX_STEPS} → {'< 1 epoch' if MAX_STEPS < steps_per_epoch else f'{MAX_STEPS/steps_per_epoch:.1f} epochs'}")

    train_ds = Dataset.from_list([{"prompt": msgs} for msgs in train_records])
    eval_ds = Dataset.from_list([{"prompt": msgs} for msgs in eval_records])
    return train_ds, eval_ds


train_dataset, eval_dataset = prepare_grpo_datasets_v3()
dataset = train_dataset
print(f"\n✓ v3 Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}")
```

---

## Cell 9: Smoke Test

**Gate:** Runs 1 step without OOM at the new completion length (4096).

```python
from trl import GRPOConfig, GRPOTrainer

FastLanguageModel.for_training(model)

smoke_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR / "smoke"),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=1,
    num_train_epochs=1,
    temperature=TEMPERATURE,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LEARNING_RATE,
    fp16=False,
    bf16=True,
    logging_steps=1,
    save_steps=999,
    report_to="none",
    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
    seed=42,
    remove_unused_columns=False,
)

smoke_trainer = GRPOTrainer(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=smoke_config,
    train_dataset=dataset,
    processing_class=tokenizer,  # TRL 0.24: GRPOTrainer takes processing_class, not tokenizer
)

t0 = time.time()
smoke_trainer.train()
step_time = time.time() - t0

print(f"\n✓ Smoke test passed!")
print(f"  Step time (grad_accum=1): {step_time:.0f}s")
print(f"  Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s")
print(f"  VRAM peak: {torch.cuda.max_memory_allocated()/1e9:.1f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

vram_used = torch.cuda.max_memory_allocated() / 1e9
vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
if vram_used > vram_total * 0.95:
    print(f"\n⚠️ VRAM at {vram_used/vram_total:.0%} — dangerously close to OOM")
    print(f"  Option 1: Reduce MAX_COMPLETION_LENGTH to 3072")
    print(f"  Option 2: Reduce BATCH_SIZE to 2 (increase GRAD_ACCUM to 2)")

del smoke_trainer
gc.collect(); torch.cuda.empty_cache()
```

---

## Cell 10: Probe Run (3 steps)

```python
FastLanguageModel.for_training(model)

probe_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR / "probe"),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=3,
    temperature=TEMPERATURE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=True,
    logging_steps=1,
    disable_tqdm=True,
    logging_first_step=True,
    save_steps=999,
    report_to="none",
    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
    seed=42,
    remove_unused_columns=False,
)

probe_trainer = GRPOTrainer(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=probe_config,
    train_dataset=dataset,
    processing_class=tokenizer,  # TRL 0.24: processing_class (matches Cell 11)
)

t0 = time.time()
result = probe_trainer.train()
elapsed = time.time() - t0

print(f"\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)")
print(f"  Train loss: {result.training_loss:.6f}")
print(f"  Estimated full run ({MAX_STEPS} steps): {elapsed/3 * MAX_STEPS / 3600:.1f}h")

if abs(result.training_loss) < 1e-6:
    print("  ⚠️ Loss is near-zero — reward variance may be insufficient")
else:
    print("  ✓ Loss is non-zero — GRPO has gradient signal")

del probe_trainer
gc.collect(); torch.cuda.empty_cache()
```

---

## Cell 11: Full Training Run v3

```python
import wandb

_wandb_key = os.environ.get("WANDB_API_KEY", "").strip()
if not _wandb_key:
    raise EnvironmentError("WANDB_API_KEY is not set.")
wandb.login(key=_wandb_key, relogin=True)
print(f"✓ W&B authenticated")
```

```python
import shutil
import torch
from transformers import TrainerCallback
from trl import GRPOConfig, GRPOTrainer

wandb.init(
    project=WANDB_PROJECT,
    name=f"grpo-v3-l4-{time.strftime('%Y%m%d-%H%M')}",
    config={
        "model_id": MODEL_ID,
        "version": "v3",
        "temperature": TEMPERATURE,
        "max_completion_length": MAX_COMPLETION_LENGTH,
        "num_generations": NUM_GENERATIONS,
        "learning_rate": LEARNING_RATE,
        "beta": BETA,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "max_steps": MAX_STEPS,
        "scale_rewards": SCALE_REWARDS,
        "save_steps": SAVE_STEPS,
        "eval_steps": EVAL_STEPS,
        "eval_max_samples": EVAL_MAX_SAMPLES,
        "eval_max_tokens": EVAL_MAX_TOKENS,
        "eval_temperature": EVAL_TEMPERATURE,
        "patience": EARLY_STOPPING_PATIENCE,
        "delta": EARLY_STOPPING_DELTA,
        "train_prompts": len(train_dataset),
        "eval_prompts": len(eval_dataset),
        "zero_adv_noise_std": ZERO_ADV_NOISE_STD,
        "general_mix_ratio": GENERAL_MIX_RATIO,
        "_ref_temperature": "Skywork-OR1 (2505.22312)",
        "_ref_completion_length": "Dr. GRPO (2503.20783)",
        "_ref_staged_rewards": "Reasoning-SQL (2503.23157)",
        "_ref_zero_adv": "Skywork-OR1 (2505.22312)",
    },
)
print(f"✓ W&B run: {wandb.run.url}")

FRESH = True
resume_from = None
if FRESH and CHECKPOINT_DIR.exists():
    print("FRESH: deleting old checkpoints...")
    shutil.rmtree(CHECKPOINT_DIR)
elif CHECKPOINT_DIR.exists():
    checkpoints = sorted(
        [d for d in CHECKPOINT_DIR.iterdir()
         if d.is_dir() and d.name.startswith("checkpoint-")],
        key=lambda d: int(d.name.split("-")[-1]),
    )
    if checkpoints:
        resume_from = str(checkpoints[-1])
        print(f"Resuming from: {resume_from}")


class UnslothGRPOTrainer(GRPOTrainer):
    """Wraps generation with Unsloth for_inference()/for_training()."""
    def _generate(self, prompts, images):
        FastLanguageModel.for_inference(self.model)
        try:
            result = super()._generate(prompts, images)
        finally:
            FastLanguageModel.for_training(self.model)
        return result


class EvalRewardCallback(TrainerCallback):
    """v3: deterministic eval, per-task breakdown, patience=15."""
    def __init__(self, eval_records, reward_fn, patience=EARLY_STOPPING_PATIENCE,
                 delta=EARLY_STOPPING_DELTA):
        self.eval_records = eval_records
        self.reward_fn = reward_fn
        self.patience = patience
        self.delta = delta
        self.best_reward = -float("inf")
        self.no_improve_count = 0

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:
            return control
        tokenizer = processing_class
        if tokenizer is None:
            print("[EvalRewardCallback] WARNING: tokenizer is None, skipping eval")
            return control

        mean_reward, task_rewards = self._run_eval(model, tokenizer, args)
        improved = mean_reward > self.best_reward + self.delta
        status = "↑ improved" if improved else f"↔ no gain ({self.no_improve_count + 1}/{self.patience})"

        log_dict = {
            "eval/mean_reward": mean_reward,
            "eval/best_reward": max(self.best_reward, mean_reward),
            "eval/no_improve_count": self.no_improve_count,
        }
        for task, rewards in task_rewards.items():
            if rewards:
                log_dict[f"eval/{task}_reward"] = sum(rewards) / len(rewards)
        wandb.log(log_dict, step=state.global_step)

        print(f"\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}")
        for task, rewards in task_rewards.items():
            if rewards:
                print(f"  {task}: {sum(rewards)/len(rewards):.3f} (n={len(rewards)})")

        if improved:
            self.best_reward = mean_reward
            self.no_improve_count = 0
        else:
            self.no_improve_count += 1
            if self.no_improve_count >= self.patience:
                print(f"[EarlyStopping] No improvement ≥ {self.delta} for {self.patience} consecutive evals. Halting.")
                wandb.log({"early_stop/step": state.global_step}, step=state.global_step)
                control.should_training_stop = True
        return control

    def _run_eval(self, model, tokenizer, args):
        FastLanguageModel.for_inference(model)
        rewards = []
        task_rewards = {}
        subset = self.eval_records[:EVAL_MAX_SAMPLES]
        for record in subset:
            msgs = record["prompt"]
            text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors="pt", truncation=True,
                               max_length=args.max_prompt_length).to(model.device)
            with torch.no_grad():
                out = model.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS,
                                     temperature=EVAL_TEMPERATURE, do_sample=True)
            resp = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            r = self.reward_fn([resp], [text])[0]
            rewards.append(r)
            user_text = " ".join(m.get("content", "") for m in msgs if m.get("role") == "user")
            task = _classify_task_type(user_text)
            task_rewards.setdefault(task, []).append(r)
        FastLanguageModel.for_training(model)
        mean = sum(rewards) / len(rewards) if rewards else 0.0
        return mean, task_rewards


class EntropyMonitorCallback(TrainerCallback):
    """v3 NEW: Monitor entropy collapse indicators (Skywork-OR1 §4)."""
    def __init__(self):
        self.consecutive_ceiling_hits = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        step = state.global_step
        monitor = {}
        comp_len = logs.get("completion_length", 0)
        if comp_len > 0:
            ratio = comp_len / MAX_COMPLETION_LENGTH
            monitor["monitor/completion_ratio"] = ratio
            if ratio > 0.95:
                self.consecutive_ceiling_hits += 1
                if self.consecutive_ceiling_hits >= 3:
                    print(f"⚠️ Step {step}: Completion ceiling hit {self.consecutive_ceiling_hits} consecutive times.")
            else:
                self.consecutive_ceiling_hits = 0
        reward_std = logs.get("reward_std", logs.get("rewards/commerce_reward_fn/std", 0))
        if reward_std is not None:
            monitor["monitor/reward_std"] = reward_std
            if reward_std < 0.01:
                print(f"⚠️ Step {step}: reward_std={reward_std:.4f} — near-zero variance")
        clip_high = logs.get("clip_ratio/high_mean", 0)
        clip_low = logs.get("clip_ratio/low_mean", 0)
        if clip_high is not None and clip_low is not None:
            total_clip = clip_high + abs(clip_low)
            monitor["monitor/total_clip_ratio"] = total_clip
            if total_clip > 0.01 and step > 10:
                print(f"✓ Step {step}: clip_ratio={total_clip:.3f} — policy is updating")
        if monitor and wandb.run:
            wandb.log(monitor, step=step)


FastLanguageModel.for_training(model)

grpo_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    temperature=TEMPERATURE,
    beta=BETA,  # v3: explicitly disable the KL penalty (Dr. GRPO) — was only implied before
    max_steps=MAX_STEPS,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=True,
    logging_steps=1,
    logging_first_step=True,
    disable_tqdm=True,
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,
    save_only_model=True,
    eval_steps=EVAL_STEPS,
    report_to="wandb",
    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
    seed=42,
    remove_unused_columns=False,
    **({"use_vllm": True, "vllm_mode": "colocate",
        "vllm_enable_sleep_mode": True} if USE_VLLM else {}),
)

eval_cb = EvalRewardCallback(eval_records=list(eval_dataset), reward_fn=commerce_reward_fn)
entropy_cb = EntropyMonitorCallback()

TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer
trainer = TrainerClass(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=grpo_config,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    callbacks=[eval_cb, entropy_cb],
)

print(f"{'='*70}")
print(f"GRPO v3 Training — Ready to Launch")
print(f"{'='*70}")
print(f"  Trainer:       {TrainerClass.__name__}")
print(f"  Max steps:     {MAX_STEPS}")
print(f"  Temperature:   {TEMPERATURE} (v2 was 0.8)")
print(f"  Completion:    {MAX_COMPLETION_LENGTH} tokens (v2 was 2048)")
print(f"  Generations:   {NUM_GENERATIONS} per prompt (v2 was 8)")
print(f"  Learning rate: {LEARNING_RATE} (v2 was 5e-7)")
print(f"  Save every:    {SAVE_STEPS} steps (keep {SAVE_TOTAL_LIMIT})")
print(f"  Eval every:    {EVAL_STEPS} steps ({EVAL_MAX_SAMPLES} samples × {EVAL_MAX_TOKENS} tok)")
print(f"  Patience:      {EARLY_STOPPING_PATIENCE} evals ({EARLY_STOPPING_PATIENCE * EVAL_STEPS} steps)")
print(f"  Resume:        {resume_from is not None}")
print(f"{'='*70}")

t_start = time.time()
result = trainer.train(resume_from_checkpoint=resume_from)
elapsed = time.time() - t_start

wandb.log({
    "train/final_loss": result.training_loss,
    "train/duration_hours": elapsed / 3600,
    "train/total_steps": result.global_step,
    "eval/best_reward_final": eval_cb.best_reward,
})
wandb.finish()

print(f"\n{'='*70}")
print(f"GRPO v3 Training Complete")
print(f"  Loss:        {result.training_loss:.6f}")
print(f"  Steps:       {result.global_step}")
print(f"  Duration:    {elapsed/3600:.1f}h")
print(f"  Best eval R: {eval_cb.best_reward:.4f}")
print(f"  Trainer:     {TrainerClass.__name__}")
print(f"{'='*70}")
```

---

## Cell 12: Save Adapter

```python
GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(GRPO_ADAPTER_DIR))
tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))

summary = {
    "model_id": MODEL_ID,
    "sft_adapter": str(SFT_ADAPTER_DIR),
    "method": "GRPO",
    "version": "v3",
    "train_loss": result.training_loss,
    "best_eval_reward": eval_cb.best_reward,
    "num_prompts": len(train_dataset),
    "num_generations": NUM_GENERATIONS,
    "scale_rewards": SCALE_REWARDS,
    "temperature": TEMPERATURE,
    "learning_rate": LEARNING_RATE,
    "beta": BETA,
    "max_completion_length": MAX_COMPLETION_LENGTH,
    "max_steps": MAX_STEPS,
    "actual_steps": result.global_step,
    "epochs": NUM_EPOCHS,
    "max_seq_length": MAX_SEQ_LENGTH,
    "duration_seconds": round(elapsed),
    "gpu": "L4",
    "platform": "vertex-ai-workbench",
    "v3_fixes": [
        "temperature=1.0 (Skywork-OR1)",
        "max_completion_length=4096 (Dr. GRPO)",
        "learning_rate=2e-6 (4x v2)",
        "beta=0.0 (Dr. GRPO)",
        "staged rewards (Reasoning-SQL)",
        "zero-advantage noise (Skywork-OR1)",
        "entropy monitoring callback",
    ],
}
with open(GRPO_ADAPTER_DIR / "training_summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(f"✓ Adapter saved to {GRPO_ADAPTER_DIR}")
print(f"  Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir() if f.is_file()]}")
```

---

## Cell 13: Validation

```python
FastLanguageModel.for_inference(model)

system_msg = {"role": "system", "content": SYSTEM_PT}

test_prompts = [
    {"role": "user", "content": (
        "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
        "nota=2/5 | status=delivered\ntítulo: decepcionado\n"
        "texto: Produto veio com defeito e o vendedor não respondeu.\n\n"
        "Retorne um objeto JSON com exatamente estas chaves:\n"
        "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
        "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
    )},
    {"role": "user", "content": (
        "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
        "nota=5/5 | status=delivered\ntítulo: adorei!\n"
        "texto: Entrega rápida e produto exatamente como descrito. Recomendo!\n\n"
        "Retorne um objeto JSON com exatamente estas chaves:\n"
        "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
        "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
    )},
    {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
    {"role": "user", "content": "Analise a retenção de clientes afetados por product_quality."},
    {"role": "user", "content": (
        "Perfil do cliente:\n- Estado: MG\n- Valor do pedido: R$150\n"
        "- Reclamação: produto com defeito\n- Nota: 1.0/5\n\n"
        "Este cliente deve receber uma notificação de reengajamento?"
    )},
    {"role": "user", "content": "Compare a satisfação de clientes em SP vs RJ."},
    {"role": "user", "content": (
        "Crie uma notificação push de reengajamento para um cliente em SP "
        "que reclamou de atraso na entrega. Nota: 2/5."
    )},
]

print("=" * 70)
print("GRPO v3 Validation")
print("=" * 70)

v3_rewards = []
for i, prompt in enumerate(test_prompts):
    messages = [system_msg, prompt]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)
    gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    reward = commerce_reward_fn([response], [text])[0]
    v3_rewards.append(reward)
    answer = strip_think(response)
    task = _classify_task_type(prompt["content"])
    hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH

    print(f"\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---")
    print(f"Prompt: {prompt['content'][:80]}...")
    print(f"Answer: {answer[:400]}")

print(f"\n{'='*70}")
print(f"v3 Validation Summary")
print(f"{'='*70}")
print(f"  Mean reward: {sum(v3_rewards)/len(v3_rewards):.3f}")
print(f"  Min: {min(v3_rewards):.3f}")
print(f"  Max: {max(v3_rewards):.3f}")
print()
print(f"  Comparison to baselines:")
print(f"    SFT calibration (Cell 7): mean=0.38")
print(f"    GRPO v2 validation:       mean=0.54")
print(f"    GRPO v3 validation:       mean={sum(v3_rewards)/len(v3_rewards):.3f}")
v3_vs_v2 = (sum(v3_rewards)/len(v3_rewards) - 0.54) / 0.54 * 100
print(f"  v3 vs v2: {v3_vs_v2:+.1f}%")
```