rtferraz committed · Commit 042d2b9 · verified · 1 Parent(s): b47b36b

Create grpo_vertex_v2_ipynb.md


Current training implementation, full run results below:
```json
{
  "_runtime": 53739,
  "_step": 1261,
  "_timestamp": 1776946956.167131,
  "_wandb.runtime": 53739,
  "eval/best_reward_final": 0.125,
  "profiling/Time taken: UnslothGRPOTrainer._calculate_rewards": 0.00957904900133144,
  "profiling/Time taken: UnslothGRPOTrainer._prepare_inputs": 0.000010281000868417324,
  "profiling/Time taken: UnslothGRPOTrainer.commerce_reward_fn": 0.008595035003963858,
  "profiling/Time taken: UnslothGRPOTrainer.transformers.generate": 185.63205686000583,
  "total_flos": 0,
  "train/clip_ratio/high_max": 0,
  "train/clip_ratio/high_mean": 0,
  "train/clip_ratio/low_mean": 0,
  "train/clip_ratio/low_min": 0,
  "train/clip_ratio/region_mean": 0,
  "train/completion_length": 2048,
  "train/duration_hours": 14.927390168110527,
  "train/epoch": 0.7,
  "train/final_loss": -0.00020499005976965976,
  "train/frac_reward_zero_std": 0,
  "train/global_step": 210,
  "train/grad_norm": 0.02984152734279633,
  "train/kl": 0.004119975958019495,
  "train/learning_rate": 1.2752757044047826e-7,
  "train/loss": 0,
  "train/reward": 0.2850000262260437,
  "train/reward_std": 0.10184022039175034,
  "train/rewards/commerce_reward_fn/mean": 0.2850000262260437,
  "train/rewards/commerce_reward_fn/std": 0.10184022039175034,
  "train/total_steps": 210,
  "train_loss": -0.00020499005976965976,
  "train_runtime": 53734.2335,
  "train_samples_per_second": 0.045,
  "train_steps_per_second": 0.006
}
```
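
Some quick arithmetic derived from this summary (illustrative only, not logged by the run itself):

```python
# Derived from the run summary above — illustrative arithmetic only.
runtime_s, steps = 53734.2335, 210        # train_runtime / train/global_step
print(f"{runtime_s / steps:.0f} s/step")  # ≈ 256 s/step
print(f"{runtime_s / 3600:.1f} h wall")   # ≈ 14.9 h — matches train/duration_hours
```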

Files changed (1): grpo_vertex_v2_ipynb.md (+1521 −0)

# Tucano2 Commerce — GRPO Training v2 (Vertex AI Workbench / L4)

**v2 changes over v1:**
- `UnslothGRPOTrainer` — wraps generation with `for_inference()`/`for_training()` (~2-3× speedup)
- `processing_class=tokenizer` fix (the kwarg was silently dropped in v1)
- `reward_extraction` normalized to a max of 1.0 (was 2.0, which biased the gradient scale — see the sketch after this list)
- KV cache diagnostic cell (Cell 5b)
- Eval callback capped via `EVAL_MAX_SAMPLES` + `EVAL_MAX_TOKENS` (uncapped: 45 samples × 591 s ≈ 7.4 h per eval)
- `UNSLOTH_COMPILE_DISABLE=1` — prevents kernel recompilation on mode switches
- Optional `USE_VLLM=True` path for 10-20× generation speedup
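
Why the normalization matters: with `scale_rewards=False` (the Dr. GRPO setting used below), each group's advantages are raw deviations from the group mean, so a task whose reward tops out at 2.0 contributes roughly twice the gradient magnitude of one capped at 1.0. A minimal sketch with made-up numbers (not from the run):

```python
# GRPO advantage without std scaling: A_i = r_i - mean(r) within one generation group.
def advantages(rewards):
    mean = sum(rewards) / len(rewards)
    return [round(r - mean, 2) for r in rewards]

print(advantages([0.2, 0.6, 1.0, 0.4]))  # rewards capped at 1.0
print(advantages([0.4, 1.2, 2.0, 0.8]))  # same pattern on a 2.0 scale → advantages double
```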

Ported from `tucano2_pipeline/06_rlvr.py` (the Modal version).
Run incrementally: each cell is a gate — verify its output before moving to the next.

**Prerequisites:**
- Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`
- Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`

**Hardware:** L4 (24 GB), PyTorch kernel, bf16 supported

## Cell 1: Dependencies

Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:


```python
# Cell 1 — Nuke everything ML-related
!pip uninstall -y torch torchvision torchaudio \
    unsloth unsloth-zoo \
    trl transformers peft accelerate \
    bitsandbytes vllm vllm-flash-attn \
    datasets tokenizers safetensors huggingface-hub \
    wandb xformers triton \
    cuda-bindings cuda-python \
    sentencepiece protobuf \
    2>/dev/null
```


```python
# Cell 2 — Kill any stragglers (run twice if paranoid)
!pip freeze | grep -iE "torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate" | xargs pip uninstall -y 2>/dev/null
```

Found existing installation: torch_c_dlpack_ext 0.1.5
Uninstalling torch_c_dlpack_ext-0.1.5:
  Successfully uninstalled torch_c_dlpack_ext-0.1.5
Found existing installation: torchao 0.17.0
Uninstalling torchao-0.17.0:
  Successfully uninstalled torchao-0.17.0


```python
# Cell 3 — Purge cache
!pip cache purge
```

Files removed: 918 (11059.0 MB)
Directories removed: 0


Restart kernel again, then:


```python
# Cell 4 — Clean install, correct order
!pip install "unsloth"
```


```python
# Cell 5 — Pin TRL (Unsloth may pull a different version)
!pip install "trl==0.24.0" --no-deps
```

Requirement already satisfied: trl==0.24.0 in /opt/conda/envs/pytorch/lib/python3.10/site-packages (0.24.0)


```python
!pip install "rich" "wandb"
```

## Cell 2: Hello World — GPU + Unsloth verification


```python
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"bf16 support: {torch.cuda.is_bf16_supported()}")

from unsloth import FastLanguageModel
print("\n✓ Unsloth loaded successfully")

import trl
print(f"✓ TRL version: {trl.__version__}")

import transformers
print(f"✓ Transformers version: {transformers.__version__}")
```

CUDA available: True
GPU: NVIDIA L4
VRAM: 23.6 GB
bf16 support: True
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!

✓ Unsloth loaded successfully
✓ TRL version: 0.24.0
✓ Transformers version: 4.57.6

## Cell 3: Config + Constants

All config from `06_rlvr.py` — no Modal dependencies.


```python
import os
# ── v2: Disable Unsloth kernel recompilation between mode switches ─────────────
# Without this, for_inference() / for_training() trigger expensive Triton recompiles.
# NOTE: to be safe, set this before `from unsloth import ...` (Cell 2) — Unsloth may
# read the flag at import/patch time.
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

import json
import re
import time
import random
from pathlib import Path

# ── Config ─────────────────────────────────────────────────────────────────────
MODEL_ID = "Polygl0t/Tucano2-qwen-3.7B-Think"
MAX_SEQ_LENGTH = 4096  # context window (prompt + generation); model supports up to 32k

# Paths (Vertex AI Workbench — cwd is /home/jupyter/)
DATA_DIR = Path("/home/jupyter/tucano2/data")
MODELS_DIR = Path("/home/jupyter/tucano2/models")
SFT_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-sft"
GRPO_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-grpo"
CHECKPOINT_DIR = GRPO_ADAPTER_DIR / "checkpoints"

GRPO_PROMPTS = 300  # stratified subset size (120/120/30/30)

# Valid enum values for reward scoring
VALID_SENTIMENTS = {"positive", "negative", "neutral"}
VALID_CATEGORIES = {
    "delivery_delay", "product_quality", "product_not_received",
    "wrong_product", "seller_communication", "app_issue",
    "price_value", "other", "none",
}
VALID_CHURN = {"low", "medium", "high"}
VALID_REPEAT = {"yes", "no", "maybe"}
EXTRACTION_FIELDS = [
    "sentiment", "sentiment_score", "churn_risk", "delivery_issue",
    "product_issue", "seller_issue", "main_complaint",
    "complaint_category", "repeat_intent", "would_recommend",
]

SYSTEM_PT = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro."
)

# Training params (validated on L4, optimized for H100)
BATCH_SIZE = 4                # L4 survived batch=4 (unsloth auto-adjusted); H100 has 80GB headroom
GRAD_ACCUM = 2                # effective batch = 4 * 2 = 8
MAX_COMPLETION_LENGTH = 2048  # non-negotiable: model needs room to think + answer
NUM_GENERATIONS = 8           # was 4 → more samples = higher chance of variance for GRPO
SCALE_REWARDS = False         # Dr. GRPO fix: remove std normalization bias
LEARNING_RATE = 5e-7
NUM_EPOCHS = 2
TEMPERATURE = 0.8             # was 0.1 from model defaults
MAX_STEPS = 300               # -1 = full run (75 steps); set to e.g. 3 for a probe


# ── ADR: Checkpoint + Eval + Early-Stop params ────────────────────────────────
EVAL_SPLIT_RATIO = 0.15       # 15% of each task bucket held out for eval
EVAL_STEPS = 10               # run EvalRewardCallback every N steps
EARLY_STOPPING_PATIENCE = 10  # was 3 — gives 100 steps of runway before halting
EARLY_STOPPING_DELTA = 0.01   # min reward gain to count as "improvement"
SAVE_STEPS = 15               # checkpoint every ~1h on L4 (Spot VM safety)
SAVE_TOTAL_LIMIT = 3          # auto-prune old checkpoints, keep last 3
WANDB_PROJECT = "tucano2-commerce"

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

print("✓ Config loaded")
print(f" SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})")
print(f" Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})")
print(f" Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}")
print(f" Steps: {(GRPO_PROMPTS * NUM_EPOCHS) // (BATCH_SIZE * GRAD_ACCUM)} (full run)")
print(f" ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}, eval_split={EVAL_SPLIT_RATIO}")

# ══════════════════════════════════════════════════════════════════════════════
# v2 Performance & Safety Flags
# ══════════════════════════════════════════════════════════════════════════════

# ── Generation backend ───────────────────────────────────────────────────────
USE_VLLM = False  # Flip True if vllm is installed and VRAM allows
                  # Requires: pip install "vllm>=0.6.0"
                  # Enables vllm_mode="colocate" + vllm_enable_sleep_mode=True

# ── Eval callback safety caps ────────────────────────────────────────────────
# At 591s/sample: 45 eval samples = 7.4h PER EVALUATION PASS — breaks the training loop!
EVAL_MAX_SAMPLES = 5    # cap eval to the first N samples of eval_dataset (was 10)
EVAL_MAX_TOKENS = 2048  # was 256 (kept each eval pass under ~15 min); raised for a meaningful metric

# ── TRL version assertion (UnslothGRPOTrainer overrides _generate) ───────────
import trl as _trl
assert _trl.__version__ == "0.24.0", (
    f"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\n"
    "Verify that GRPOTrainer._generate() still exists before proceeding."
)
print(f"✓ TRL {_trl.__version__} verified")
```

✓ Config loaded
 SFT adapter: /home/jupyter/tucano2/models/tucano2-commerce-sft (exists: True)
 Train data: /home/jupyter/tucano2/data/pairs/train.jsonl (exists: True)
 Training: batch=4, grad_accum=2, eff_batch=8
 Steps: 75 (full run)
 ADR: save_steps=5, eval_steps=10, patience=10, eval_split=0.15
✓ TRL 0.24.0 verified

## Cell 4: Load SFT Adapter


```python
print("Loading SFT adapter...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=str(SFT_ADAPTER_DIR),
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load chat template from base model (SFT adapter doesn't save it)
from transformers import AutoTokenizer
base_tok = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.chat_template = base_tok.chat_template
del base_tok

# ── v2: Force KV cache — Unsloth patching may reset this ─────────────────────
model.config.use_cache = True
model.generation_config.use_cache = True

print(f"✓ Model loaded on {model.device}")
print(f" use_cache: {model.config.use_cache}")
print(f" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M")
print(f" Chat template: {tokenizer.chat_template[:50]}...")
```

Loading SFT adapter...
==((====))==  Unsloth 2026.4.6: Fast Qwen3 patching. Transformers: 4.57.6. vLLM: 0.19.1.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 21.951 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.35. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unsloth 2026.4.6 patched 36 layers with 0 QKV layers, 0 O layers and 0 MLP layers.

✓ Model loaded on cuda:0
 use_cache: True
 Params: 1976M
 Chat template: {#- Handle tool/function calling setup #}
{%- if t...

## Cell 5: Single Inference Test

**Gate:** Does the model close `</think>` and produce an answer within 2048 tokens?


```python
FastLanguageModel.for_inference(model)

test_msgs = [
    {"role": "system", "content": SYSTEM_PT},
    {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
]
text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

t0 = time.time()
outputs = model.generate(**inputs, max_new_tokens=2048, temperature=0.7, do_sample=True)
elapsed = time.time() - t0

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

print(f"Generation time: {elapsed:.1f}s")
print(f"Response length: {len(response)} chars")
print(f"closed_think: {'</think>' in response}")
print(f"\n{'='*60}")
print(response[:800])
```

Generation time: 66.8s
Response length: 3945 chars
closed_think: True

============================================================
<think>
Para responder à pergunta sobre quais tipos do conjunto geral têm uma classificação melhor que o outro, vamos analisar os dados fornecidos passo a passo:

### Dados Disponíveis:
1. **Classificação Geral:** 4.83/5 estrelas (N = 2750)
2. **Reclamações Gerais Não Específicas:** 4.68/5 estrelas (2750 participantes)
3. **Problemas com Remessa/Envio:** 4.51/5 estrelas (1733 participantes)
4. **Não Respondido Rapidamente Após Contato Inicial:** 4.63/5 estrelas (1263 participantes)
5. **Sem Reclamação** (5.00/5 estrelas 1090 participantes)
6. **Outras Reivindicações (múltiplas):** 4.76/5 estrelas (919 participantes)
7. **Fora Do Prazo De Garantia Ou Válvula:** 4.49/5 Estrelas (746 Participantes)
8. **Erro Na Entrega ou Extraído Inadvertidamente:** 4.43/5 ESTRELAS (437 PARTICIPANTES)
9. **D

## Cell 5b: KV Cache Diagnostic

**Gate:** Ratio < 3× → KV cache is working. Ratio > 5× → cache is broken (O(n²) full attention recompute at every token — training generation will be catastrophically slow).

If broken, check `model.config.use_cache` and try `UNSLOTH_COMPILE_DISABLE=1`.


```python
# ── KV Cache Diagnostic ───────────────────────────────────────────────────────
# Tests whether attention past_key_values are actually being used.
# O(n²) failure: token time grows linearly with sequence length → ratio >> 1.
import time
FastLanguageModel.for_inference(model)

_kv_msgs = [{"role": "system", "content": SYSTEM_PT},
            {"role": "user", "content": "Qual a categoria de reclamação mais frequente?"}]
_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)
_kv_inputs = tokenizer(_kv_text, return_tensors="pt").to(model.device)

_token_times, _past, _generated = [], None, _kv_inputs["input_ids"]
with torch.no_grad():
    for _step in range(50):
        _t0 = time.time()

        # Calculate current sequence length
        seq_len = _generated.shape[1]

        # Manually construct position_ids
        if _past is None:
            # First pass: positions for the whole prompt (e.g., [0, 1, 2, ...])
            _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)
        else:
            # Subsequent passes: only the position for the single new token
            _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)

        _out = model(
            input_ids=_generated[:, -1:] if _past else _generated,
            position_ids=_position_ids,  # <--- the otherwise-missing argument
            attention_mask=torch.ones(1, seq_len, device=model.device),
            past_key_values=_past,
            use_cache=True,
            return_dict=True,  # <--- forces an object return instead of a tuple
        )
        _past = _out.past_key_values
        _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        _generated = torch.cat([_generated, _next], dim=1)
        _token_times.append(time.time() - _t0)

_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)
print(f"First 5 tok : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}")
print(f"Last 5 tok  : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}")
print(f"Ratio last/first: {_ratio:.1f}x")
if _ratio < 3:
    print("✓ KV cache is working correctly")
elif _ratio < 6:
    print("⚠ KV cache may be degraded — check model.config.use_cache")
else:
    print("✗ KV cache BROKEN — O(n²) recompute. GRPO generation will be catastrophically slow.")
    print("  Try: model.config.use_cache = True; model.generation_config.use_cache = True")

# Clean up cache vars
del _past, _generated, _kv_inputs, _token_times, _out
import gc; gc.collect()
if torch.cuda.is_available(): torch.cuda.empty_cache()
```

First 5 tok : ['277ms', '95ms', '87ms', '86ms', '86ms']
Last 5 tok  : ['88ms', '88ms', '86ms', '86ms', '86ms']
Ratio last/first: 0.7x
✓ KV cache is working correctly

## Cell 6: Reward Functions

Copied verbatim from `06_rlvr.py`. Pure Python — no external dependencies.


```python
def strip_think(text: str) -> str:
    """Remove <think>...</think> block, return the answer portion."""
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()


def has_think_block(text: str) -> bool:
    """Check if text contains a non-empty <think> block."""
    return bool(re.search(r"<think>.+</think>", text, flags=re.DOTALL))


def _classify_task_type(prompt_text: str) -> str:
    """Classify prompt into task type by keywords."""
    p = prompt_text.lower()
    if "retorne um objeto json" in p or "extraia dados" in p:
        return "extraction"
    elif "notificação push" in p or "notificação de reengajamento" in p:
        return "push"
    elif "perfil do cliente" in p:
        return "insights"
    else:
        return "sql_qa"


def reward_extraction(completion: str) -> float:
    """Continuous reward for extraction tasks (max 1.0, normalized)."""
    score = 0.0
    answer = strip_think(completion)

    # +0.1 for <think> block (binary but small weight)
    if has_think_block(completion):
        score += 0.1

    # JSON parsing: partial credit
    try:
        data = json.loads(answer)
    except (json.JSONDecodeError, TypeError):
        # Partial credit for JSON-like structure
        score += 0.05 * _json_similarity(answer)
        return score

    if not isinstance(data, dict):
        score += 0.1  # at least it's valid JSON
        return score

    score += 0.3  # valid JSON object

    # Schema completeness: fractional credit per field
    present = sum(1 for f in EXTRACTION_FIELDS if f in data)
    score += 0.3 * (present / len(EXTRACTION_FIELDS))

    # Categorical correctness: fractional per field
    cat_checks = 0
    cat_total = 0

    checks = [
        ("sentiment", lambda v: v in VALID_SENTIMENTS),
        ("complaint_category", lambda v: v in VALID_CATEGORIES),
        ("churn_risk", lambda v: v in VALID_CHURN),
        ("repeat_intent", lambda v: v in VALID_REPEAT),
        ("sentiment_score", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),
    ]
    for field, validator in checks:
        cat_total += 1
        if field in data and validator(data[field]):
            cat_checks += 1

    for bool_field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
        cat_total += 1
        if bool_field in data and isinstance(data[bool_field], bool):
            cat_checks += 1

    if cat_total > 0:
        score += 0.3 * (cat_checks / cat_total)

    return min(score, 1.0)  # cap at 1.0


def _json_similarity(text: str) -> float:
    """Rough heuristic: how JSON-like is this text? 0.0 to 1.0."""
    text = text.strip()
    if not text:
        return 0.0
    score = 0.0
    if text.startswith("{") and text.endswith("}"):
        score += 0.5
    if '"' in text:
        score += 0.2
    if ":" in text:
        score += 0.2
    if "," in text:
        score += 0.1
    return min(score, 1.0)


def reward_sql_qa(completion: str) -> float:
    """Continuous reward for SQL Q&A (max 1.0)."""
    score = 0.0
    answer = strip_think(completion)

    if has_think_block(completion):
        score += 0.1

    # Numerical content: more numbers = more specific answer
    numbers = re.findall(r"\d+(?:[.,]\d+)?", answer)
    score += min(0.4, 0.1 * len(numbers))  # up to 0.4 for multiple numbers

    # Length: optimal is 100-400 chars. Penalize too short or too long.
    length = len(answer)
    if 50 <= length <= 500:
        score += 0.3
    elif length > 0:
        score += 0.3 * max(0, 1 - abs(length - 275) / 225)  # linear falloff

    # SQL keywords: evidence of actual query analysis
    sql_keywords = r"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN"
    matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))
    score += min(0.2, 0.05 * matches)

    return min(score, 1.0)


def reward_insights(completion: str) -> float:
    """Continuous reward for insights (max 1.0)."""
    score = 0.0
    answer = strip_think(completion)

    if has_think_block(completion):
        score += 0.1

    # Actionable language: count matches, not binary
    action_words = ["recomend", "implement", "melhor", "reduzir", "aumentar",
                    "priorizar", "investir", "otimizar", "estratégi"]
    matches = sum(1 for w in action_words if w in answer.lower())
    score += min(0.5, 0.1 * matches)

    # Length: 100-1000 chars optimal
    length = len(answer)
    if 100 <= length <= 1000:
        score += 0.3
    elif length > 0:
        score += 0.3 * max(0, 1 - abs(length - 550) / 450)

    # Structure: bullet points, numbered lists = organized thinking
    structure_marks = len(re.findall(r"^[-•*]\s|^\d+[.)]\s", answer, re.MULTILINE))
    score += min(0.1, 0.02 * structure_marks)

    return min(score, 1.0)


def reward_push(completion: str) -> float:
    """Continuous reward for push notifications (max 1.0)."""
    score = 0.0
    answer = strip_think(completion)

    if not answer:
        return 0.0

    # Length: ≤120 chars gets full credit, linear penalty above
    length = len(answer)
    if length <= 120:
        score += 0.5
    else:
        score += 0.5 * max(0, 1 - (length - 120) / 80)

    # Generic-ness: fuzzy penalty based on similarity to generic phrases
    generic_phrases = [
        "olá! como podemos ajudar",
        "obrigado pela sua compra",
        "seu pedido foi confirmado",
        "agradecemos sua preferência",
    ]
    max_similarity = max(
        _string_similarity(answer.lower(), g) for g in generic_phrases
    )
    score += 0.3 * (1 - max_similarity)  # less generic = higher score

    # Portuguese content: count PT-specific markers
    pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua", answer, re.IGNORECASE)
    score += min(0.2, 0.02 * len(pt_markers))

    return min(score, 1.0)


def _string_similarity(a: str, b: str) -> float:
    """Simple Jaccard-like similarity for short strings. 0.0 to 1.0."""
    if not a or not b:
        return 0.0
    a_set = set(a.split())
    b_set = set(b.split())
    intersection = len(a_set & b_set)
    union = len(a_set | b_set)
    return intersection / union if union > 0 else 0.0


def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:
    """Master reward function: dispatches by task type."""
    rewards = []
    for completion, prompt in zip(completions, prompts):
        if isinstance(completion, list):
            comp_text = completion[-1]["content"] if completion else ""
        else:
            comp_text = str(completion)

        if isinstance(prompt, list):
            prompt_text = " ".join(m.get("content", "") for m in prompt)
        else:
            prompt_text = str(prompt)

        task = _classify_task_type(prompt_text)

        if task == "extraction":
            rewards.append(reward_extraction(comp_text))
        elif task == "sql_qa":
            rewards.append(reward_sql_qa(comp_text))
        elif task == "insights":
            rewards.append(reward_insights(comp_text))
        elif task == "push":
            rewards.append(reward_push(comp_text))
        else:
            r = 0.2 if has_think_block(comp_text) else 0.0
            r += 0.3 if comp_text.strip() else 0.0
            rewards.append(r)

    return rewards


print("✓ Reward functions defined")
```

✓ Reward functions defined
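
As a quick, illustrative sanity check (not a cell from the original notebook), the dispatcher can be exercised directly once Cell 6 has run — the toy prompt and completion below are made up:

```python
# Hypothetical extraction prompt/completion pair — reward should land at the 1.0 cap.
toy_prompt = "Extraia dados. Retorne um objeto JSON com estas chaves: sentiment, ..."
toy_completion = (
    "<think>nota baixa, produto com defeito</think>"
    '{"sentiment": "negative", "sentiment_score": 2, "churn_risk": "high", '
    '"delivery_issue": false, "product_issue": true, "seller_issue": true, '
    '"main_complaint": "defeito", "complaint_category": "product_quality", '
    '"repeat_intent": "no", "would_recommend": false}'
)
print(commerce_reward_fn([toy_completion], [toy_prompt]))  # → [1.0]
```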

## Cell 7: Reward Calibration

**Gate:** Verify reward variance > 0 and most samples close `</think>`.


```python
# Load dataset for calibration
train_path = DATA_DIR / "pairs" / "train.jsonl"

by_type = {"extraction": [], "sql_qa": [], "insights": [], "push": []}
with open(train_path) as f:
    for line in f:
        row = json.loads(line)
        convs = row["conversations"]
        prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
        if not prompt_msgs:
            continue
        user_text = " ".join(m["content"] for m in prompt_msgs if m["role"] == "user")
        task = _classify_task_type(user_text)
        by_type[task].append(prompt_msgs)

print(f"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}")

# Pick 5 diverse samples for calibration
rng = random.Random(42)
cal_samples = []
for task_type in ["extraction", "sql_qa", "insights", "push", "sql_qa"]:
    cal_samples.append(rng.choice(by_type[task_type]))

# Run calibration
FastLanguageModel.for_inference(model)
print("\nReward calibration (5 samples):")
print("-" * 60)

cal_rewards = []
for i, msgs in enumerate(cal_samples):
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    r = commerce_reward_fn([response], [text])[0]
    cal_rewards.append(r)
    has_answer = "</think>" in response
    answer_preview = strip_think(response)[:120] if has_answer else "[stuck in <think>]"
    print(f" Sample {i+1}: reward={r:.2f} | closed_think={has_answer} | answer: {answer_preview}")

print(f"\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}")
print(f"Variance > 0: {len(set(cal_rewards)) > 1}")
```

Prompts by type: extraction=659, sql_qa=655, insights=114, push=222

Reward calibration (5 samples):
------------------------------------------------------------
 Sample 1: reward=0.02 | closed_think=False | answer: [stuck in <think>]
 Sample 2: reward=0.60 | closed_think=False | answer: [stuck in <think>]
 Sample 3: reward=0.10 | closed_think=True | answer: Para determinarmos se deveríamos oferecer algum tipo de benefício adicional para tentar reverter essa decisão negativa d
 Sample 4: reward=0.50 | closed_think=True | answer: Olha só![NomeDoCliente]! 😪 [Estado: SÃO PAULO]
PROBLEMA DETECTADO: Produto não recebido.
VALOR DO PEDIDO: R$ 138
 Sample 5: reward=0.70 | closed_think=True | answer: Os clientes com baixos índices de rotação (low churn-risk) geralmente enfrentam menos problemas significativos comparado

Mean=0.38, Min=0.02, Max=0.70
Variance > 0: True

## Cell 8: Dataset Preparation


```python
from datasets import Dataset

def prepare_grpo_datasets(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO, seed=42):
    """
    Stratified split of by_type prompts into train and eval datasets.

    For each task bucket: hold out `eval_ratio` fraction first, then sample
    train targets from the remainder. Guarantees at least 1 eval sample per type.

    Returns:
        train_dataset: HF Dataset — passed to GRPOTrainer
        eval_dataset: HF Dataset — consumed by EvalRewardCallback
    """
    rng = random.Random(seed)

    # ── Step 1: per-task eval hold-out ────────────────────────────────────────
    train_pools = {}
    eval_records = []
    for task, pool in by_type.items():
        shuffled = pool.copy()
        rng.shuffle(shuffled)
        n_eval = max(1, int(len(shuffled) * eval_ratio))
        eval_records.extend(shuffled[:n_eval])
        train_pools[task] = shuffled[n_eval:]

    # ── Step 2: stratified train sampling from remaining pool ─────────────────
    targets = {
        "extraction": int(n_prompts * 0.4),
        "sql_qa": int(n_prompts * 0.4),
        "insights": int(n_prompts * 0.1),
        "push": int(n_prompts * 0.1),
    }
    train_records = []
    for task, target_n in targets.items():
        pool = train_pools[task]
        n = min(target_n, len(pool))
        train_records.extend(rng.sample(pool, n))
    rng.shuffle(train_records)

    print(f"Dataset split (eval_ratio={eval_ratio}):")
    print(f" train : {len(train_records)} prompts")
    print(f" eval  : {len(eval_records)} prompts")
    print(f" train dist: {', '.join(f'{k}={min(v, len(train_pools[k]))}' for k, v in targets.items())}")

    train_ds = Dataset.from_list([{"prompt": msgs} for msgs in train_records])
    eval_ds = Dataset.from_list([{"prompt": msgs} for msgs in eval_records])
    return train_ds, eval_ds


train_dataset, eval_dataset = prepare_grpo_datasets()
dataset = train_dataset  # backward compat — smoke/probe cells use `dataset`
print(f"\n✓ Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}")
```

Dataset split (eval_ratio=0.15):
 train : 300 prompts
 eval  : 246 prompts
 train dist: extraction=120, sql_qa=120, insights=30, push=30

✓ Datasets ready: train=300, eval=246

## Cell 9: Smoke Test — Single Training Step

**Gate:** Runs 1 step without OOM. Reports step time for estimation.


```python
from trl import GRPOConfig, GRPOTrainer

# Switch to training mode
FastLanguageModel.for_training(model)

smoke_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR / "smoke"),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=1,
    num_train_epochs=1,
    temperature=0.8,  # was 0.1 from model defaults
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,  # just 1 for smoke test (faster)
    learning_rate=LEARNING_RATE,
    fp16=False,
    bf16=True,
    logging_steps=1,
    save_steps=999,  # don't save during smoke
    report_to="none",
    max_prompt_length=MAX_SEQ_LENGTH // 2,
    seed=42,
    remove_unused_columns=False,
)

smoke_trainer = GRPOTrainer(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=smoke_config,
    train_dataset=dataset,
    processing_class=tokenizer,  # v2 fix: `tokenizer=` was silently dropped in v1
)

t0 = time.time()
smoke_trainer.train()
step_time = time.time() - t0

# Estimate full run: 75 steps with grad_accum=2
print("\n✓ Smoke test passed!")
print(f" Step time (grad_accum=1): {step_time:.0f}s")
print(f" Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s")
print(f" Estimated full run (75 steps): {step_time * GRAD_ACCUM * 75 / 3600:.1f}h")

# Cleanup smoke test
del smoke_trainer
import gc; gc.collect(); torch.cuda.empty_cache()
```

Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.
We will change the batch size of 4 to the `num_generations` of 8

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 300 | Num Epochs = 1 | Total steps = 1
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 33,030,144 of 3,792,371,200 (0.87% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 4096, 'repetition_penalty': 1.2, 'renormalize_logits': True}. If this is not desired, please set these values explicitly.

✓ Smoke test passed!
 Step time (grad_accum=1): 318s
 Estimated step time (grad_accum=2): 636s
 Estimated full run (75 steps): 13.2h
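
Back-of-envelope generation throughput implied by the smoke step (illustrative arithmetic; assumes all 8 completions hit the 2048-token cap, which the final run's `train/completion_length` of 2048 suggests they do):

```python
# From the smoke output and config above — rough aggregate throughput.
step_s, completions, cap = 318, 8, 2048
print(f"{completions * cap / step_s:.0f} tok/s aggregate across the batch")  # ≈ 52 tok/s
```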

## Cell 10: Probe Run (3 steps)

**Gate:** Loss > 0, rewards have variance, step time is consistent.


```python
FastLanguageModel.for_training(model)

probe_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR / "probe"),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=3,
    temperature=TEMPERATURE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=True,
    logging_steps=1,
    save_steps=999,
    report_to="none",
    max_prompt_length=MAX_SEQ_LENGTH // 2,
    seed=42,
    remove_unused_columns=False,
)

probe_trainer = GRPOTrainer(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=probe_config,
    train_dataset=dataset,
    processing_class=tokenizer,  # v2 fix: `tokenizer=` was silently dropped in v1
)

t0 = time.time()
result = probe_trainer.train()
elapsed = time.time() - t0

print(f"\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)")
print(f" Train loss: {result.training_loss:.4f}")
print(f" Estimated full run (75 steps): {elapsed/3 * 75 / 3600:.1f}h")

del probe_trainer
gc.collect(); torch.cuda.empty_cache()
```

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 300 | Num Epochs = 1 | Total steps = 3
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 2 x 1) = 8
 "-____-"     Trainable parameters = 33,030,144 of 3,792,371,200 (0.87% trained)

Unsloth: Will smartly offload gradients to save VRAM!

✓ Probe complete in 665s (222s/step)
 Train loss: 0.0062
 Estimated full run (75 steps): 4.6h

## Cell 11: Full Training Run

**ADR changes applied here:**
- `save_steps=15` (was 25) — checkpoint every few hours on an L4 Spot VM
- `save_total_limit=3` — auto-prune old checkpoints
- `logging_steps=1` (was 5) — every step visible in console + W&B
- `report_to="wandb"` (was `"none"`) — full run tracked in W&B project `tucano2-commerce`
- `EvalRewardCallback` — custom callback running reward on the held-out eval set every 10 steps
- **Early stopping**: halt if `mean_eval_reward` fails to improve by ≥ 0.01 for `EARLY_STOPPING_PATIENCE` (10) consecutive evals

**Resume:** Automatically resumes from the latest checkpoint if interrupted.

```python
# DEPRECATED_Cell: Fixed Safety Validation

import torch
from unsloth import FastLanguageModel

# Get the inner model
# was: policy_inner = getattr(trainer.model, 'module', trainer.model)
policy_inner = model

# ── CHECK 1: PASSED (no ref model) ──────────────────────────
print("CHECK 1: ✅ No reference model — biggest risk eliminated\n")

# ── CHECK 2: Weight drift — test LoRA A/B matrices (bf16) ───
print("=" * 60)
print("CHECK 2: LoRA adapter weight drift after merge/unmerge")

test_layer = None
for name, module in policy_inner.named_modules():
    if hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
        test_layer = module
        layer_name = name
        break

if test_layer:
    print(f"Testing: {layer_name}")

    # Capture LoRA A and B matrices (these ARE float, not quantized)
    lora_a_before = list(test_layer.lora_A.values())[0].weight.clone().detach()
    lora_b_before = list(test_layer.lora_B.values())[0].weight.clone().detach()

    # Also capture the base weight bytes (NF4 — just check identity)
    base_before = test_layer.base_layer.weight.data.clone()

    print(f" LoRA A dtype: {lora_a_before.dtype}, shape: {lora_a_before.shape}")
    print(f" LoRA B dtype: {lora_b_before.dtype}, shape: {lora_b_before.shape}")
    print(f" Base weight dtype: {test_layer.base_layer.weight.dtype}")

    # Run 50 merge/unmerge cycles (simulate 50 GRPO steps)
    for i in range(50):
        FastLanguageModel.for_inference(policy_inner)
        FastLanguageModel.for_training(policy_inner)

    lora_a_after = list(test_layer.lora_A.values())[0].weight.clone().detach()
    lora_b_after = list(test_layer.lora_B.values())[0].weight.clone().detach()
    base_after = test_layer.base_layer.weight.data.clone()

    # LoRA weight drift
    a_diff = (lora_a_before - lora_a_after).abs().max().item()
    b_diff = (lora_b_before - lora_b_after).abs().max().item()
    a_rel = a_diff / (lora_a_before.abs().mean().item() + 1e-8)
    b_rel = b_diff / (lora_b_before.abs().mean().item() + 1e-8)

    print("\n After 50 cycles:")
    print(f" LoRA A max diff: {a_diff:.2e} (relative: {a_rel:.2e})")
    print(f" LoRA B max diff: {b_diff:.2e} (relative: {b_rel:.2e})")

    # Base weight byte-level identity
    base_identical = torch.equal(base_before, base_after)
    print(f" Base weights identical (byte-exact): {base_identical}")

    if a_rel < 1e-5 and b_rel < 1e-5 and base_identical:
        print(" ✅ PASS: No drift after 50 cycles")
    elif a_diff == 0 and b_diff == 0 and base_identical:
        print(" ✅ PASS: Bit-perfect — for_inference() does NOT merge weights")
    else:
        print(" ❌ FAIL: Weight drift detected")

# ── CHECK 3: Memory leak over 20 cycles ─────────────────────
print("\n" + "=" * 60)
print("CHECK 3: Memory leak test (20 cycles)")

torch.cuda.empty_cache()
import gc; gc.collect()

baseline_mem = torch.cuda.memory_allocated() / 1e9
print(f" Baseline: {baseline_mem:.3f} GB")

for i in range(20):
    FastLanguageModel.for_inference(policy_inner)
    # Simulate a short generation
    test_input = torch.tensor([[1, 2, 3]], device="cuda")
    with torch.no_grad():
        _ = policy_inner(test_input)
    FastLanguageModel.for_training(policy_inner)

    if (i + 1) % 5 == 0:
        gc.collect(); torch.cuda.empty_cache()
        current = torch.cuda.memory_allocated() / 1e9
        delta = current - baseline_mem
        print(f" Cycle {i+1:2d}: {current:.3f} GB (delta: {delta:+.3f} GB)")

final_mem = torch.cuda.memory_allocated() / 1e9
total_drift = final_mem - baseline_mem
print(f"\n Total memory drift: {total_drift:+.3f} GB")
if abs(total_drift) < 0.1:
    print(" ✅ PASS: No significant memory leak")
else:
    print(" ⚠️ WARN: Memory growing — potential leak")

# ── CHECK 4: Gradient flow after mode switch ─────────────────
print("\n" + "=" * 60)
print("CHECK 4: Gradient flow survives mode switching")

FastLanguageModel.for_inference(policy_inner)
FastLanguageModel.for_training(policy_inner)

# Check LoRA params still require grad
trainable = 0
frozen = 0
for name, p in policy_inner.named_parameters():
    if 'lora_' in name:
        if p.requires_grad:
            trainable += 1
        else:
            frozen += 1

print(f" LoRA params requiring grad: {trainable}")
print(f" LoRA params frozen (bad): {frozen}")
if frozen == 0 and trainable > 0:
    print(" ✅ PASS: All LoRA params trainable after mode switch")
else:
    print(" ❌ FAIL: Mode switch froze LoRA parameters")
```

```python
# ── W&B Auth ──────────────────────────────────────────────────────────────────
# Vertex AI Workbench does not have wandb pre-authenticated.
# WANDB_API_KEY must be set as an environment variable before running this cell.
# On Vertex AI: add it to the instance env vars, or set it in a prior cell, e.g.:
#   os.environ["WANDB_API_KEY"] = "<your-api-key>"  # never hard-code a real key in the notebook
import os
import wandb

_wandb_key = os.environ.get("WANDB_API_KEY", "").strip()
if not _wandb_key:
    raise EnvironmentError(
        "WANDB_API_KEY is not set.\n"
        "Set it before running this cell:\n"
        "  import os; os.environ[\"WANDB_API_KEY\"] = \"your-key\"\n"
        "Or add it to your Vertex AI Workbench instance environment variables."
    )
wandb.login(key=_wandb_key, relogin=True)
print("✓ W&B authenticated")
```

```python
import shutil
import torch
from transformers import TrainerCallback
from trl import GRPOConfig, GRPOTrainer

# ── W&B Init ──────────────────────────────────────────────────────────────────
wandb.init(
    project=WANDB_PROJECT,
    name=f"grpo-v2-l4-{time.strftime('%Y%m%d-%H%M')}",
    config={
        "model_id": MODEL_ID,
        "version": "v2",
        "save_steps": SAVE_STEPS,
        "eval_steps": EVAL_STEPS,
        "eval_max_samples": EVAL_MAX_SAMPLES,
        "eval_max_tokens": EVAL_MAX_TOKENS,
        "patience": EARLY_STOPPING_PATIENCE,
        "delta": EARLY_STOPPING_DELTA,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "max_steps": MAX_STEPS,
        "learning_rate": LEARNING_RATE,
        "num_generations": NUM_GENERATIONS,
        "scale_rewards": SCALE_REWARDS,
        "eval_split_ratio": EVAL_SPLIT_RATIO,
        "train_prompts": len(train_dataset),
        "eval_prompts": len(eval_dataset),
        "use_vllm": USE_VLLM,
    },
)
print(f"✓ W&B run: {wandb.run.url}")

# ── Resume logic ──────────────────────────────────────────────────────────────
FRESH = True  # Set True to clear old checkpoints and start over

resume_from = None
if FRESH and CHECKPOINT_DIR.exists():
    print("FRESH: deleting old checkpoints...")
    shutil.rmtree(CHECKPOINT_DIR)
elif CHECKPOINT_DIR.exists():
    checkpoints = sorted(
        [d for d in CHECKPOINT_DIR.iterdir()
         if d.is_dir() and d.name.startswith("checkpoint-")],
        key=lambda d: int(d.name.split("-")[-1]),
    )
    if checkpoints:
        resume_from = str(checkpoints[-1])
        print(f"Resuming from: {resume_from}")

# ── UnslothGRPOTrainer: activate inference kernels during generation ───────────
class UnslothGRPOTrainer(GRPOTrainer):
    """
    Wraps GRPOTrainer._generate() with Unsloth for_inference()/for_training()
    to activate Unsloth's optimized Triton kernels during the generation phase.

    Root cause of ~3.4 tok/s on L4:
    - GRPOTrainer calls unwrapped_model.generate() while the model is in train() mode.
    - Unsloth's fused inference Triton kernels are only active after for_inference().
    - Without this wrapper: ~3-4 tok/s. With it: expected ~8-15 tok/s on L4.

    Override target: _generate() — called from _generate_and_score_completions().
    Verified against TRL 0.24.0 source (asserted in Cell 3).
    """
    def _generate(self, prompts, images):
        FastLanguageModel.for_inference(self.model)
        try:
            result = super()._generate(prompts, images)
        finally:
            # Always restore — even if generation crashes
            FastLanguageModel.for_training(self.model)
        return result

# ── EvalRewardCallback — capped for L4 feasibility ───────────────────────────
class EvalRewardCallback(TrainerCallback):
    """
    Runs commerce_reward_fn on a capped subset of the held-out eval set.

    v2 changes:
    - Capped to EVAL_MAX_SAMPLES — the full 45 samples × 591s ≈ 7.4h/eval.
    - max_new_tokens=EVAL_MAX_TOKENS — keeps each eval pass tractable on L4.
    - Logs eval/mean_reward to both console and W&B.
    - Patience-based early stopping via control.should_training_stop.
    """
    def __init__(self, eval_records, reward_fn, patience=3, delta=0.01):
        self.eval_records = eval_records
        self.reward_fn = reward_fn
        self.patience = patience
        self.delta = delta
        self.best_reward = -float("inf")
        self.no_improve_count = 0

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        # ^^^ Changed: tokenizer -> processing_class

        if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:
            return control

        # processing_class is the tokenizer in TRL 0.24.0+
        tokenizer = processing_class
        if tokenizer is None:
            print("[EvalRewardCallback] WARNING: tokenizer is None, skipping eval")
            return control

        mean_reward = self._run_eval(model, tokenizer, args)
        improved = mean_reward > self.best_reward + self.delta
        status = (
            "↑ improved" if improved
            else f"↔ no gain ({self.no_improve_count + 1}/{self.patience})"
        )

        # Log to W&B
        wandb.log(
            {
                "eval/mean_reward": mean_reward,
                "eval/best_reward": max(self.best_reward, mean_reward),
                "eval/no_improve_count": self.no_improve_count,
            },
            step=state.global_step,
        )

        print(
            f"\n[EvalReward] step={state.global_step} | "
            f"mean_eval_reward={mean_reward:.4f} | best={self.best_reward:.4f} | {status}"
        )

        if improved:
            self.best_reward = mean_reward
            self.no_improve_count = 0
        else:
            self.no_improve_count += 1
            if self.no_improve_count >= self.patience:
                print(
                    f"[EarlyStopping] No improvement ≥ {self.delta} for "
                    f"{self.patience} consecutive evals. Halting training."
                )
                wandb.log({"early_stop/step": state.global_step}, step=state.global_step)
                control.should_training_stop = True

        return control

    def _run_eval(self, model, tokenizer, args):
        """One sampled completion per eval prompt, scored by reward_fn."""
        FastLanguageModel.for_inference(model)
        rewards = []

        # Cap to subset
        subset = self.eval_records[:EVAL_MAX_SAMPLES]

        for record in subset:
            msgs = record["prompt"]
            text = tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True,
                max_length=args.max_prompt_length,
            ).to(model.device)
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=EVAL_MAX_TOKENS,  # use the cap
                    temperature=0.7,
                    do_sample=True,
                )
            resp = tokenizer.decode(
                out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            )
            rewards.append(self.reward_fn([resp], [text])[0])

        FastLanguageModel.for_training(model)
        return sum(rewards) / len(rewards) if rewards else 0.0

# ── Training Config ────────────────────────────────────────────────────────────
FastLanguageModel.for_training(model)

grpo_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=MAX_STEPS,
    temperature=TEMPERATURE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=True,
    logging_steps=1,                    # every step in console + W&B
    save_steps=SAVE_STEPS,              # checkpoint cadence — L4 Spot VM safety
    save_total_limit=SAVE_TOTAL_LIMIT,  # prune old checkpoints, keep 3
    save_only_model=True,               # eliminate save overhead
    eval_steps=EVAL_STEPS,              # drives EvalRewardCallback cadence
    report_to="wandb",
    max_prompt_length=MAX_SEQ_LENGTH // 2,
    seed=42,
    remove_unused_columns=False,
    # vLLM colocate — only used when USE_VLLM=True (Trainer class also switches below)
    **({"use_vllm": True, "vllm_mode": "colocate",
        "vllm_enable_sleep_mode": True} if USE_VLLM else {}),
)

eval_cb = EvalRewardCallback(
    eval_records=list(eval_dataset),
    reward_fn=commerce_reward_fn,
    patience=EARLY_STOPPING_PATIENCE,
    delta=EARLY_STOPPING_DELTA,
)

# ── Trainer: UnslothGRPOTrainer (vLLM handles its own generation path) ────────
TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer
trainer = TrainerClass(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=grpo_config,
    train_dataset=train_dataset,
    processing_class=tokenizer,  # v2 fix: was tokenizer=tokenizer (silently dropped)
    callbacks=[eval_cb],
)
print(
    f"Trainer: {TrainerClass.__name__} | "
    f"max_steps={MAX_STEPS} | save_every={SAVE_STEPS} | eval_every={EVAL_STEPS} | "
    f"eval_cap={EVAL_MAX_SAMPLES}×{EVAL_MAX_TOKENS}tok | resume={resume_from is not None}"
)

t_start = time.time()
result = trainer.train(resume_from_checkpoint=resume_from)
elapsed = time.time() - t_start

wandb.log({
    "train/final_loss": result.training_loss,
    "train/duration_hours": elapsed / 3600,
    "train/total_steps": result.global_step,
    "eval/best_reward_final": eval_cb.best_reward,
})
wandb.finish()

print(f"\n{'='*60}")
print("GRPO v2 Training Complete")
print(f" Loss: {result.training_loss:.4f}")
print(f" Steps: {result.global_step}")
print(f" Duration: {elapsed/3600:.1f}h")
print(f" Best eval R: {eval_cb.best_reward:.4f}")
print(f" Trainer: {TrainerClass.__name__}")
print(f"{'='*60}")
```
1331
+
1332
+ wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly.
1333
+ wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.
1334
+ wandb: WARNING [wandb.login()] Changing session credentials to explicit value for https://api.wandb.ai.
1335
+ wandb: Appending key for api.wandb.ai to your netrc file: /home/jupyter/.netrc
1336
+
1337
+
1338
+ ✓ W&B authenticated
1339
+
1340
+ Tracking run with wandb version 0.26.0
+ Run data is saved locally in /home/jupyter/tucano2/notebooks/wandb/run-20260422_212656-2m114rh7
+ Syncing run grpo-v2-l4-20260422-2126 to Weights & Biases
+ View project at https://wandb.ai/tferrazrafael-self/tucano2-commerce
+ View run at https://wandb.ai/tferrazrafael-self/tucano2-commerce/runs/2m114rh7
+ 
+ ✓ W&B run: https://wandb.ai/tferrazrafael-self/tucano2-commerce/runs/2m114rh7
+ FRESH: deleting old checkpoints...
+ Trainer: UnslothGRPOTrainer | max_steps=300 | save_every=5 | eval_every=10 | eval_cap=5×2048tok | resume=False
+ 
+ ==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
+    \\   /|    Num examples = 300 | Num Epochs = 1 | Total steps = 300
+ O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 2
+ \        /    Data Parallel GPUs = 1 | Total batch size (4 x 2 x 1) = 8
+  "-____-"     Trainable parameters = 33,030,144 of 3,792,371,200 (0.87% trained)
+ 
+ Unsloth: Will smartly offload gradients to save VRAM!
+ 
+ [EvalReward] step=10 | mean_eval_reward=0.0830 | best=-inf | ↑ improved
+ [EvalReward] step=20 | mean_eval_reward=0.0830 | best=0.0830 | ↔ no gain (1/10)
+ [EvalReward] step=30 | mean_eval_reward=0.1050 | best=0.0830 | ↑ improved
+ [EvalReward] step=40 | mean_eval_reward=0.1010 | best=0.1050 | ↔ no gain (1/10)
+ 
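+ The `(1/10)` counters above come from the callback's patience rule: an eval only counts as an improvement when it beats the running best by more than `delta`; otherwise a staleness counter ticks toward `patience`. A minimal sketch of that logic (class and attribute names hypothetical, not the actual `EvalRewardCallback`):
+ 
+ ```python
+ # Hypothetical sketch of the patience/delta rule behind the log lines above.
+ class PatienceTracker:
+     def __init__(self, patience: int = 10, delta: float = 0.0):
+         self.patience, self.delta = patience, delta
+         self.best = float("-inf")
+         self.stale = 0  # evals since the last improvement
+ 
+     def update(self, reward: float) -> bool:
+         """Record one eval; return True when training should stop early."""
+         if reward > self.best + self.delta:  # "↑ improved"
+             self.best, self.stale = reward, 0
+         else:  # "↔ no gain (stale/patience)"
+             self.stale += 1
+         return self.stale >= self.patience
+ ```
+ 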
+ wandb: WARNING Tried to log to step 40 that is less than the current step 239. Steps must be monotonically increasing, so this data will be ignored. See https://wandb.me/define-metric to log data out of order.
+ 
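+ The dropped-step warning above appears because the eval callback logs with an explicit `step=` that lags the trainer's own W&B step counter. One way to keep those points, sketched below under the assumption that the callback currently calls `wandb.log(..., step=...)`, is to register a dedicated step metric so eval rewards get their own x-axis:
+ 
+ ```python
+ import wandb
+ 
+ # Register a custom x-axis for eval metrics so they no longer compete with
+ # W&B's monotonically increasing global step.
+ wandb.define_metric("eval/step")
+ wandb.define_metric("eval/*", step_metric="eval/step")
+ 
+ # In the callback, log the trainer step as a field instead of via `step=`:
+ wandb.log({"eval/mean_reward": 0.101, "eval/step": 40})
+ ```
+ 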
+ ## Cell 12: Save Adapter
+ 
+ ```python
+ GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
+ model.save_pretrained(str(GRPO_ADAPTER_DIR))
+ tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))
+ 
+ # Save training summary
+ summary = {
+     "model_id": MODEL_ID,
+     "sft_adapter": str(SFT_ADAPTER_DIR),
+     "method": "GRPO",
+     "train_loss": result.training_loss,
+     "num_prompts": len(dataset),
+     "num_generations": NUM_GENERATIONS,
+     "scale_rewards": SCALE_REWARDS,
+     "learning_rate": LEARNING_RATE,
+     "epochs": NUM_EPOCHS,
+     "max_seq_length": MAX_SEQ_LENGTH,
+     "max_completion_length": MAX_COMPLETION_LENGTH,
+     "duration_seconds": round(elapsed),
+     "gpu": "L4",
+     "platform": "vertex-ai-workbench",
+ }
+ with open(GRPO_ADAPTER_DIR / "training_summary.json", "w") as f:
+     json.dump(summary, f, indent=2)
+ 
+ print(f"✓ Adapter saved to {GRPO_ADAPTER_DIR}")
+ print(f" Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir()]}")
+ ```
+ 
+ ✓ Adapter saved to /home/jupyter/tucano2/models/tucano2-commerce-grpo
+ Files: ['tokenizer_config.json', 'tokenizer.json', 'README.md', 'adapter_model.safetensors', 'chat_template.jinja', 'training_summary.json', 'checkpoints', 'special_tokens_map.json', 'adapter_config.json']
+ 
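+ To consume the adapter later, it can be reloaded straight from the saved directory. A minimal reload sketch, assuming the same Unsloth setup and notebook variables used above (4-bit loading is an assumption here; the quantization mode is not recorded in the summary):
+ 
+ ```python
+ from unsloth import FastLanguageModel
+ 
+ # Load base model + GRPO adapter from the saved PEFT directory
+ # (adapter_config.json points Unsloth at the original base model).
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name=str(GRPO_ADAPTER_DIR),
+     max_seq_length=MAX_SEQ_LENGTH,
+     load_in_4bit=True,  # assumption: match the training-time quantization
+ )
+ FastLanguageModel.for_inference(model)  # enable fast inference path
+ ```
+ 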
+ ## Cell 13: Validation
+ 
+ Generate 5 samples with the trained model and score them with the reward function.
+ 
+ ```python
+ FastLanguageModel.for_inference(model)
+ 
+ system_msg = {"role": "system", "content": SYSTEM_PT}
+ 
+ test_prompts = [
+     {"role": "user", "content": (
+         "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
+         "nota=2/5 | status=delivered\ntítulo: decepcionado\n"
+         "texto: Produto veio com defeito e o vendedor não respondeu.\n\n"
+         "Retorne um objeto JSON com exatamente estas chaves:\n"
+         "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
+         "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
+     )},
+     {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
+     {"role": "user", "content": "Analise a retenção de clientes afetados por product_quality."},
+     {"role": "user", "content": (
+         "Perfil do cliente:\n- Estado: MG\n- Valor do pedido: R$150\n"
+         "- Reclamação: produto com defeito\n- Nota: 1.0/5\n\n"
+         "Este cliente deve receber uma notificação de reengajamento?"
+     )},
+     {"role": "user", "content": "Compare a satisfação de clientes em SP vs RJ."},
+ ]
+ 
+ print("=== GRPO Validation ===")
+ print()
+ 
+ for i, prompt in enumerate(test_prompts):
+     messages = [system_msg, prompt]
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = tokenizer(text, return_tensors="pt").to(model.device)
+ 
+     # near-greedy sampling keeps validation output close to deterministic
+     outputs = model.generate(**inputs, max_new_tokens=2048, temperature=0.1, do_sample=True)
+     response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+ 
+     reward = commerce_reward_fn([response], [text])[0]
+     answer = strip_think(response)
+ 
+     # "tokens" below is a whitespace word count, a rough proxy for true token length
+     print(f"--- Sample {i+1} (reward={reward:.2f}, tokens={len(response.split())}) ---")
+     print(f"Prompt: {prompt['content'][:80]}...")
+     print(f"Answer: {answer[:300]}")
+     print()
+ ```
+ 
+ === GRPO Validation ===
+ 
+ --- Sample 1 (reward=0.12, tokens=442) ---
+ Prompt: Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.
+ 
+ n...
+ Answer: ```json
+ {
+   "sentiment": "negativo",
+   "sentiment_score": -0.8,
+   "churn_risk": -1,
+   "delivery_issue": true,
+   "product_issue": false,
+   "seller_issue": false,
+   "main_complaint": "falha no atendimento post-venda",
+   "complaint_category": "serviço pós-venda",
+   "repeat_intent": 0,
+   "would_recommen
+ 
+ --- Sample 2 (reward=0.70, tokens=1451) ---
+ Prompt: Quais são as categorias de reclamação mais frequentes e como afetam a nota média...
+ Answer: As categorias de reclamação variam bastante conforme o tipo de serviço ou produto sendo consumido, mas existem alguns tipos comuns que tendem a aparecer frequentemente. Aqui estão algumas das principais categorias de reclamações e suas possíveis influências na nota média:
+ 
+ ### Categorias Frequentes
+ 
+ --- Sample 3 (reward=0.70, tokens=1465) ---
+ Prompt: Analise a retenção de clientes afetados por product_quality....
+ Answer: Claro! Vamos analisar a possível influência da *quality* do produto (*product_quality*) sobre a taxa de retenção de clientes. Para isso, vamos seguir alguns passos lógicos:
+ 
+ ### Passo 1: Definição das Variáveis
+ - **Quality do Produto (`product_quality`)**: Uma métrica quantitativa ou qualitativa que
+ 
+ --- Sample 4 (reward=0.50, tokens=994) ---
+ Prompt: Perfil do cliente:
+ - Estado: MG
+ - Valor do pedido: R$150
+ - Reclamação: produto c...
+ Answer: <think>
+ O usuário está me perguntando se este cliente deveria receber uma notificação de reengajamento baseado nas informações fornecidas. Primeiro, vou analisar cada ponto individualmente para entender a situação completa.
+ 
+ 1. **Estado:** O estado mencionado pelo cliente é Minas Gerais (MG). Isso
+ 
+ --- Sample 5 (reward=0.70, tokens=1396) ---
+ Prompt: Compare a satisfação de clientes em SP vs RJ....
+ Answer: Para compararmos a satisfação dos clientes em São Paulo (SP) versus Rio de Janeiro (RJ), precisamos considerar alguns pontos:
+ 
+ 1. **Economia Local**: A renda média per capita varia significativamente entre estas duas regiões metropolitanas. No geral, São Paulo tende a ter uma renda maior comparada a
+ 
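+ Averaging the five per-sample rewards above gives a single headline number for this checkpoint; a quick sketch using the values printed by the cell:
+ 
+ ```python
+ # Per-sample rewards copied from the validation run above.
+ rewards = [0.12, 0.70, 0.70, 0.50, 0.70]
+ print(f"mean validation reward = {sum(rewards) / len(rewards):.3f}")  # 0.544
+ ```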