# Tucano2 Commerce — GRPO Training v3 (Vertex AI Workbench / L4)

**v3 changes over v2 — grounded in published research:**

| Change | v2 Value | v3 Value | Paper Reference |
|--------|----------|----------|----------------|
| Temperature | 0.8 | **1.0** | Skywork-OR1 (2505.22312) §4: τ=1.0 gives 5-8% better results, delays entropy collapse |
| Completion length | 2048 | **4096** | Dr. GRPO (2503.20783) §3.1: length bias inflates wrong answers → ceiling hit blocks learning |
| Num generations | 8 | **4** | VRAM tradeoff: 4×4096 ≈ 8×2048. MC-GRPO (2601.22582): G=4 works with noise mitigation |
| Learning rate | 5e-7 | **2e-6** | Dr. GRPO Appendix G: LR=1e-6; Reasoning-SQL: LR=1e-6. v2 clip_ratio=0 → room to push 2-4× |
| β (KL penalty) | implicit | **0.0** | Dr. GRPO §3.2: β=0 optimal for rule-based rewards |
| Training data | 300 | **ALL (~1400)** | Skywork-OR1 §3.1: small prompt sets → model memorizes → entropy collapse |
| Reward functions | single composite | **staged (format→partial→task)** | Reasoning-SQL (2503.23157) §3.2: format rewards converge first, enable task learning |
| Zero-advantage groups | included | **filtered with noise injection** | Skywork-OR1 §3.1: zero-std groups destabilize training |
| Entropy monitoring | none | **EntropyMonitorCallback** | Skywork-OR1 §4: early detection prevents collapse |
| Early stopping patience | 10 | **15** | More runway for longer completions |
| Save total limit | 3 | **5** | Keep more checkpoints — v2 lost the best one |
| Eval temperature | 0.7 | **0.1** | Deterministic eval = less noisy signal |
| General reasoning mix | none | **30% (optional)** | Cocktail Effect (2410.01109): multi-task mix boosts domain performance 2-15% |

**Prerequisites:**
- Upload `data/pairs/train.jsonl` (2.1 MB) to `./data/pairs/`
- Upload `models/tucano2-commerce-sft/` (126 MB) to `./models/tucano2-commerce-sft/`
- **NEW:** Optional `data/pairs/general_reasoning.jsonl` for 30% general data mix

**Hardware:** L4 (24GB), PyTorch kernel, bf16 supported

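The `scale_rewards` change in the table is easiest to see numerically. Below is a minimal sketch (illustrative values, not TRL internals) of group-relative advantages with and without the per-group std division that Dr. GRPO removes: under vanilla GRPO a low-variance group gets its update amplified, while mean-centering alone preserves the raw reward gaps.

```python
import statistics

# One group of G=4 rewards for the same prompt (illustrative values).
group = [0.20, 0.20, 0.20, 0.70]
mean = statistics.mean(group)   # 0.325
std = statistics.stdev(group)   # 0.25

# Vanilla GRPO advantage: (r - mean) / std
print([(r - mean) / std for r in group])   # ≈ [-0.5, -0.5, -0.5, 1.5]

# Dr. GRPO / scale_rewards=False: mean-centering only
print([r - mean for r in group])           # ≈ [-0.125, -0.125, -0.125, 0.375]
```
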
---

## Cell 1: Dependencies

Restart your kernel first (Kernel → Restart), then run these cells in order, one at a time:

```python
# Cell 1a — Nuke everything ML-related
!pip uninstall -y torch torchvision torchaudio \
    unsloth unsloth-zoo \
    trl transformers peft accelerate \
    bitsandbytes vllm vllm-flash-attn \
    datasets tokenizers safetensors huggingface-hub \
    wandb xformers triton \
    cuda-bindings cuda-python \
    sentencepiece protobuf \
    2>/dev/null
```

```python
# Cell 1b — Kill any stragglers
!pip freeze | grep -iE "torch|unsloth|trl|vllm|bitsandbytes|transformers|peft|accelerate" | xargs pip uninstall -y 2>/dev/null
```

```python
# Cell 1c — Purge cache
!pip cache purge
```

**⚠️ Restart kernel again**, then:

```python
# Cell 1d — Clean install, correct order
!pip install "unsloth"
```

```python
# Cell 1e — Pin TRL (Unsloth may pull a different version)
!pip install "trl==0.24.0" --no-deps
```

```python
# Cell 1f — Extra deps
!pip install "rich" "wandb"
```

---

## Cell 2: Hello World — GPU + Unsloth Verification

```python
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"bf16 support: {torch.cuda.is_bf16_supported()}")

from unsloth import FastLanguageModel
print("\n✓ Unsloth loaded successfully")

import trl
print(f"✓ TRL version: {trl.__version__}")

import transformers
print(f"✓ Transformers version: {transformers.__version__}")
```

---

## Cell 3: Config + Constants

```python
import os
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"

import json
import re
import time
import random
import gc
from pathlib import Path

# ══════════════════════════════════════════════════════════════════════════════
# v3 CONFIG — Every change is annotated with paper reference
# ══════════════════════════════════════════════════════════════════════════════

MODEL_ID = "Polygl0t/Tucano2-qwen-3.7B-Think"
MAX_SEQ_LENGTH = 8192  # v3: increased from 4096 — model supports 32k, we need room for 4096 completion + prompt

# ── Paths ─────────────────────────────────────────────────────────────────────
DATA_DIR = Path("/home/jupyter/tucano2/data")
MODELS_DIR = Path("/home/jupyter/tucano2/models")
SFT_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-sft"
GRPO_ADAPTER_DIR = MODELS_DIR / "tucano2-commerce-grpo-v3"  # v3: separate dir from v2
CHECKPOINT_DIR = GRPO_ADAPTER_DIR / "checkpoints"

# ── Training data ─────────────────────────────────────────────────────────────
GRPO_PROMPTS = None      # v3: None = use ALL available prompts (was 300 subset in v2)
GENERAL_MIX_RATIO = 0.0  # v3: set to 0.3 if general_reasoning.jsonl exists (Cocktail Effect paper)

# ── Valid enums for reward scoring (unchanged from v2) ────────────────────────
VALID_SENTIMENTS = {"positive", "negative", "neutral"}
VALID_CATEGORIES = {
    "delivery_delay", "product_quality", "product_not_received",
    "wrong_product", "seller_communication", "app_issue",
    "price_value", "other", "none",
}
VALID_CHURN = {"low", "medium", "high"}
VALID_REPEAT = {"yes", "no", "maybe"}
EXTRACTION_FIELDS = [
    "sentiment", "sentiment_score", "churn_risk", "delivery_issue",
    "product_issue", "seller_issue", "main_complaint",
    "complaint_category", "repeat_intent", "would_recommend",
]

SYSTEM_PT = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro."
)

# ══════════════════════════════════════════════════════════════════════════════
# TRAINING HYPERPARAMETERS — v3 fixes (all changes annotated)
# ══════════════════════════════════════════════════════════════════════════════

# ── Core GRPO params ──────────────────────────────────────────────────────────
BATCH_SIZE = 4
GRAD_ACCUM = 1  # v3: reduced from 2. Effective batch = 4×1 = 4 (was 8)
# With G=4: steps = prompts × 4 / 4 = prompts per epoch
NUM_GENERATIONS = 4  # v3: reduced from 8 — VRAM tradeoff for longer completions
# MC-GRPO (2601.22582): G=4 works if noise is mitigated
SCALE_REWARDS = False  # Dr. GRPO (2503.20783): remove std normalization bias

# ── v3 CRITICAL FIXES ────────────────────────────────────────────────────────

# FIX 1: Temperature — prevent entropy collapse
# v2 had 0.8. All published GRPO papers use 1.0.
# Skywork-OR1 (2505.22312) ablation: τ=1.0 vs τ=0.6 → 5-8% better test performance
TEMPERATURE = 1.0

# FIX 2: Completion length — remove the ceiling
# v2: every single completion hit the 2048 ceiling. Model couldn't finish reasoning.
# Dr. GRPO (2503.20783) §3.1: GRPO length bias inflates wrong answers → ceiling kills the gradient
MAX_COMPLETION_LENGTH = 4096

# FIX 3: Learning rate — more aggressive
# v2: clip_ratio=0 on all steps → updates were too small to matter
# Dr. GRPO Appendix G: LR=1e-6 (constant). Reasoning-SQL: LR=1e-6 with cosine.
# We go 2× since v2 showed zero clipping (model can absorb stronger push)
LEARNING_RATE = 2e-6

# FIX 4: β = 0 (no KL penalty)
# Dr. GRPO (2503.20783) §3.2: KL penalty is unnecessary for rule-based rewards
# v2 used implicit KL through default β — we explicitly disable it
BETA = 0.0

# ── Training schedule ─────────────────────────────────────────────────────────
NUM_EPOCHS = 1
MAX_STEPS = 500  # v3: increased for expanded data; early stopping will halt if needed
# With ~1400 prompts × 4 gen / (4 batch × 1 accum) = 1400 steps/epoch
# MAX_STEPS=500 < 1 epoch — early stopping or manual extension

# ── Checkpoint + Eval + Early-Stop ────────────────────────────────────────────
EVAL_SPLIT_RATIO = 0.15
EVAL_STEPS = 10
EARLY_STOPPING_PATIENCE = 15  # v3: increased from 10 — gives 150 steps of runway
EARLY_STOPPING_DELTA = 0.005  # v3: reduced from 0.01 — more sensitive to small gains
SAVE_STEPS = 10       # v3: more frequent (was 15) — never lose best checkpoint again
SAVE_TOTAL_LIMIT = 5  # v3: keep more checkpoints (was 3 — lost best in v2)
WANDB_PROJECT = "tucano2-commerce"

# ── Eval callback ─────────────────────────────────────────────────────────────
EVAL_MAX_SAMPLES = 5
EVAL_MAX_TOKENS = 4096  # v3: match training max_completion_length (was 2048)
EVAL_TEMPERATURE = 0.1  # v3: deterministic eval for less noisy signal (was 0.7)

# ── Backend ───────────────────────────────────────────────────────────────────
USE_VLLM = False

# ── v3: Zero-advantage noise injection ────────────────────────────────────────
# Skywork-OR1 (2505.22312) §3.1: zero-std groups destabilize GRPO training
# When all G completions get identical rewards, the advantage is undefined.
# Noise injection breaks ties without corrupting the signal.
ZERO_ADV_NOISE_STD = 0.005  # Small gaussian noise added to zero-variance groups

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ── Version assertion ─────────────────────────────────────────────────────────
import trl as _trl
assert _trl.__version__ == "0.24.0", (
    f"UnslothGRPOTrainer was written for TRL 0.24.0, found {_trl.__version__}.\n"
    "Verify that GRPOTrainer._generate() still exists before proceeding."
)

print("✓ v3 Config loaded")
print(f"  SFT adapter: {SFT_ADAPTER_DIR} (exists: {SFT_ADAPTER_DIR.exists()})")
print(f"  Train data: {DATA_DIR / 'pairs' / 'train.jsonl'} (exists: {(DATA_DIR / 'pairs' / 'train.jsonl').exists()})")
print(f"  Training: batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM}, eff_batch={BATCH_SIZE*GRAD_ACCUM}")
print(f"  GRPO: G={NUM_GENERATIONS}, temp={TEMPERATURE}, LR={LEARNING_RATE}, β={BETA}")
print(f"  Completion: max={MAX_COMPLETION_LENGTH} (v2 was 2048)")
print(f"  ADR: save_steps={SAVE_STEPS}, eval_steps={EVAL_STEPS}, patience={EARLY_STOPPING_PATIENCE}")
print(f"✓ TRL {_trl.__version__} verified")

# ══════════════════════════════════════════════════════════════════════════════
# v3 VRAM BUDGET (L4 24GB)
# ══════════════════════════════════════════════════════════════════════════════
# Model (NF4):           ~3.5 GB
# KV Cache (8192 seq):   ~3.0 GB
# Activations:           ~4.0 GB
# Optimizer states:      ~3.0 GB
# Generations (4×4096):  ~8.0 GB
# ─────────────────────────────────
# Estimated total:      ~21.5 GB
# Headroom:              ~2.5 GB
#
# If OOM: reduce MAX_COMPLETION_LENGTH to 3072 first, then 2560.
# Do NOT reduce NUM_GENERATIONS below 4 — GRPO needs variance.
# ══════════════════════════════════════════════════════════════════════════════
```

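The assertion above pins TRL but cannot see future internals. A quick hedged check that the private method the Cell 11 subclass overrides still exists on this build (signature inspection only; it does not prove behavioral compatibility):

```python
import inspect
from trl import GRPOTrainer

# UnslothGRPOTrainer (Cell 11) overrides this private method, so fail fast
# here if a TRL upgrade renamed or removed it.
assert hasattr(GRPOTrainer, "_generate"), "GRPOTrainer._generate() is gone; rewrite the Cell 11 wrapper"
print(f"✓ GRPOTrainer._generate{inspect.signature(GRPOTrainer._generate)}")
```
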
---

## Cell 4: Load SFT Adapter

```python
print("Loading SFT adapter...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=str(SFT_ADAPTER_DIR),
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load chat template from base model (SFT adapter doesn't save it)
from transformers import AutoTokenizer
base_tok = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.chat_template = base_tok.chat_template
del base_tok

# v2: Force KV cache — Unsloth patching may reset this
model.config.use_cache = True
model.generation_config.use_cache = True

print(f"✓ Model loaded on {model.device}")
print(f"  use_cache: {model.config.use_cache}")
print(f"  Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M")
print(f"  Chat template: {tokenizer.chat_template[:50]}...")
```

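A one-liner to eyeball that the copied chat template actually renders a generation prompt (the exact special tokens depend on the base model's template, so treat the expected output as illustrative):

```python
demo = tokenizer.apply_chat_template(
    [{"role": "user", "content": "olá"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(repr(demo))  # should end with the template's assistant-turn opener
```
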
---

## Cell 5: Single Inference Test

**Gate:** Does the model close `</think>` and produce an answer within 4096 tokens?

```python
FastLanguageModel.for_inference(model)

test_msgs = [
    {"role": "system", "content": SYSTEM_PT},
    {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
]
text = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

t0 = time.time()
outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
elapsed = time.time() - t0

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]

print(f"Generation time: {elapsed:.1f}s ({gen_tokens} tokens, {gen_tokens/elapsed:.1f} tok/s)")
print(f"Response length: {len(response)} chars, {gen_tokens} tokens")
print(f"Hit ceiling: {gen_tokens >= MAX_COMPLETION_LENGTH}")  # v3: should NOT hit ceiling with 4096
print(f"closed_think: {'</think>' in response}")
print(f"\n{'='*60}")
print(response[:800])
```

---

## Cell 5b: KV Cache Diagnostic

```python
import time
FastLanguageModel.for_inference(model)

_kv_msgs = [{"role": "system", "content": SYSTEM_PT},
            {"role": "user", "content": "Qual a categoria de reclamação mais frequente?"}]
_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)
_kv_inputs = tokenizer(_kv_text, return_tensors="pt").to(model.device)

_token_times, _past, _generated = [], None, _kv_inputs["input_ids"]
with torch.no_grad():
    for _step in range(50):
        _t0 = time.time()
        seq_len = _generated.shape[1]
        if _past is None:
            _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)
        else:
            _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)
        _out = model(
            input_ids=_generated[:, -1:] if _past is not None else _generated,
            position_ids=_position_ids,
            attention_mask=torch.ones(1, seq_len, device=model.device),
            past_key_values=_past,
            use_cache=True,
            return_dict=True,
        )
        _past = _out.past_key_values
        _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        _generated = torch.cat([_generated, _next], dim=1)
        _token_times.append(time.time() - _t0)

_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)
print(f"First 5 tok     : {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}")
print(f"Last 5 tok      : {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}")
print(f"Ratio last/first: {_ratio:.1f}x")
if _ratio < 3:
    print("✓ KV cache is working correctly")
elif _ratio < 6:
    print("⚠ KV cache may be degraded — check model.config.use_cache")
else:
    print("✗ KV cache BROKEN — GRPO generation will be catastrophically slow.")

del _past, _generated, _kv_inputs, _token_times, _out
gc.collect()
if torch.cuda.is_available(): torch.cuda.empty_cache()
```

---

## Cell 6: Reward Functions v3

**v3 changes:**
- Staged reward design: format → partial content → full task (Reasoning-SQL, 2503.23157)
- Zero-advantage noise injection (Skywork-OR1, 2505.22312)
- Extraction reward redesigned for completion-length-friendly scoring

```python
def strip_think(text: str) -> str:
    """Remove <think>...</think> block, return the answer portion."""
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()


def has_think_block(text: str) -> bool:
    """Check if text contains a non-empty <think> block."""
    return bool(re.search(r"<think>.+</think>", text, flags=re.DOTALL))


def _classify_task_type(prompt_text: str) -> str:
    """Classify prompt into task type by keywords."""
    p = prompt_text.lower()
    if "retorne um objeto json" in p or "extraia dados" in p:
        return "extraction"
    elif "notificação push" in p or "notificação de reengajamento" in p:
        return "push"
    elif "perfil do cliente" in p:
        return "insights"
    else:
        return "sql_qa"


def _json_similarity(text: str) -> float:
    """Rough heuristic: how JSON-like is this text? 0.0 to 1.0."""
    text = text.strip()
    if not text:
        return 0.0
    score = 0.0
    if text.startswith("{") and text.endswith("}"):
        score += 0.5
    if '"' in text:
        score += 0.2
    if ":" in text:
        score += 0.2
    if "," in text:
        score += 0.1
    return min(score, 1.0)


def _string_similarity(a: str, b: str) -> float:
    """Simple Jaccard-like similarity for short strings. 0.0 to 1.0."""
    if not a or not b:
        return 0.0
    a_set = set(a.split())
    b_set = set(b.split())
    intersection = len(a_set & b_set)
    union = len(a_set | b_set)
    return intersection / union if union > 0 else 0.0


# ══════════════════════════════════════════════════════════════════════════════
# v3 STAGED REWARD DESIGN
# Reference: Reasoning-SQL (2503.23157) §3.2
#
# Each reward function scores THREE stages independently:
#   Stage 1 — FORMAT  (0.0–0.2): Is the output well-structured?
#   Stage 2 — PARTIAL (0.0–0.3): Are some content elements correct?
#   Stage 3 — TASK    (0.0–0.5): Is the full task completed correctly?
#
# Format rewards converge first (easy to learn), which stabilizes training
# and enables the model to then learn harder task-specific skills.
# ══════════════════════════════════════════════════════════════════════════════


def reward_extraction(completion: str) -> float:
    """Staged reward for structured extraction (max 1.0)."""
    answer = strip_think(completion)

    # ── Stage 1: FORMAT (max 0.2) ─────────────────────────────────────────────
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.1  # Used reasoning

    try:
        data = json.loads(answer)
        if isinstance(data, dict):
            r_format += 0.1  # Valid JSON object
    except (json.JSONDecodeError, TypeError):
        r_format += 0.05 * _json_similarity(answer)
        return min(r_format, 0.2)

    if not isinstance(data, dict):
        return min(r_format, 0.2)

    # ── Stage 2: PARTIAL CONTENT (max 0.3) ────────────────────────────────────
    r_partial = 0.0

    present = sum(1 for f in EXTRACTION_FIELDS if f in data)
    r_partial += 0.15 * (present / len(EXTRACTION_FIELDS))

    type_checks = 0
    type_total = 0
    for field in EXTRACTION_FIELDS:
        if field not in data:
            continue
        type_total += 1
        val = data[field]
        if field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
            if isinstance(val, bool):
                type_checks += 1
        elif field in ("sentiment_score",):
            if isinstance(val, (int, float)):
                type_checks += 1
        elif field in ("main_complaint", "sentiment", "complaint_category", "churn_risk", "repeat_intent"):
            if isinstance(val, str):
                type_checks += 1
    if type_total > 0:
        r_partial += 0.15 * (type_checks / type_total)

    # ── Stage 3: FULL TASK (max 0.5) ─────────────────────────────────────────
    r_task = 0.0
    cat_checks = 0
    cat_total = 0

    checks = [
        ("sentiment", lambda v: v in VALID_SENTIMENTS),
        ("complaint_category", lambda v: v in VALID_CATEGORIES),
        ("churn_risk", lambda v: v in VALID_CHURN),
        ("repeat_intent", lambda v: v in VALID_REPEAT),
        ("sentiment_score", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),
    ]
    for field, validator in checks:
        cat_total += 1
        if field in data and validator(data[field]):
            cat_checks += 1

    for bool_field in ("delivery_issue", "product_issue", "seller_issue", "would_recommend"):
        cat_total += 1
        if bool_field in data and isinstance(data[bool_field], bool):
            cat_checks += 1

    if cat_total > 0:
        r_task += 0.35 * (cat_checks / cat_total)

    if "main_complaint" in data and isinstance(data["main_complaint"], str):
        complaint = data["main_complaint"].strip()
        if len(complaint) > 10:
            r_task += 0.15

    return min(r_format + r_partial + r_task, 1.0)


def reward_sql_qa(completion: str) -> float:
    """Staged reward for SQL Q&A (max 1.0)."""
    answer = strip_think(completion)

    # ── Stage 1: FORMAT (max 0.2)
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.1
    if "```" in answer or re.search(r"SELECT|FROM", answer, re.IGNORECASE):
        r_format += 0.1

    # ── Stage 2: PARTIAL (max 0.3)
    r_partial = 0.0
    sql_keywords = r"SELECT|FROM|WHERE|GROUP BY|ORDER BY|COUNT|SUM|AVG|JOIN|HAVING"
    matches = len(re.findall(sql_keywords, answer, re.IGNORECASE))
    r_partial += min(0.15, 0.03 * matches)
    numbers = re.findall(r"\d+(?:[.,]\d+)?", answer)
    r_partial += min(0.15, 0.03 * len(numbers))

    # ── Stage 3: TASK (max 0.5)
    r_task = 0.0
    length = len(answer)
    if 50 <= length <= 600:
        r_task += 0.25
    elif length > 0:
        r_task += 0.25 * max(0, 1 - abs(length - 325) / 275)
    explanation_markers = ["para ", "porque", "resultado", "mostra", "indica", "análise"]
    expl_matches = sum(1 for w in explanation_markers if w in answer.lower())
    r_task += min(0.25, 0.05 * expl_matches)

    return min(r_format + r_partial + r_task, 1.0)


def reward_insights(completion: str) -> float:
    """Staged reward for insights (max 1.0)."""
    answer = strip_think(completion)

    # ── Stage 1: FORMAT (max 0.2)
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.1
    structure_marks = len(re.findall(r"^[-•*]\s|^\d+[.)]\s|^#{1,3}\s", answer, re.MULTILINE))
    r_format += min(0.1, 0.02 * structure_marks)

    # ── Stage 2: PARTIAL (max 0.3)
    r_partial = 0.0
    length = len(answer)
    if 100 <= length <= 1200:
        r_partial += 0.15
    elif length > 0:
        r_partial += 0.15 * max(0, 1 - abs(length - 650) / 550)
    pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua|cliente|produto", answer, re.IGNORECASE)
    r_partial += min(0.15, 0.01 * len(pt_markers))

    # ── Stage 3: TASK (max 0.5)
    r_task = 0.0
    action_words = ["recomend", "implement", "melhor", "reduzir", "aumentar",
                    "priorizar", "investir", "otimizar", "estratégi", "suger",
                    "consider", "ação", "plano"]
    matches = sum(1 for w in action_words if w in answer.lower())
    r_task += min(0.3, 0.06 * matches)
    data_refs = len(re.findall(r"\d+%|R\$\s*\d|média|percentual|comparad|taxa", answer, re.IGNORECASE))
    r_task += min(0.2, 0.04 * data_refs)

    return min(r_format + r_partial + r_task, 1.0)


def reward_push(completion: str) -> float:
    """Staged reward for push notifications (max 1.0)."""
    answer = strip_think(completion)
    if not answer:
        return 0.0

    # ── Stage 1: FORMAT (max 0.2)
    r_format = 0.0
    if has_think_block(completion):
        r_format += 0.05
    length = len(answer)
    if length <= 160:
        r_format += 0.15
    elif length <= 300:
        r_format += 0.1
    else:
        r_format += 0.05

    # ── Stage 2: PARTIAL (max 0.3)
    r_partial = 0.0
    pt_markers = re.findall(r"[ãçéêóúâõ]|você|para|como|seu|sua", answer, re.IGNORECASE)
    r_partial += min(0.15, 0.02 * len(pt_markers))
    if re.search(r"[!?]|[\U0001F600-\U0001F64F]|[\U0001F300-\U0001F5FF]", answer):
        r_partial += 0.05
    if len(answer.split()) >= 5:
        r_partial += 0.1

    # ── Stage 3: TASK (max 0.5)
    r_task = 0.0
    if length <= 120:
        r_task += 0.25
    else:
        r_task += 0.25 * max(0, 1 - (length - 120) / 120)
    generic_phrases = [
        "olá! como podemos ajudar", "obrigado pela sua compra",
        "seu pedido foi confirmado", "agradecemos sua preferência",
    ]
    max_similarity = max(_string_similarity(answer.lower(), g) for g in generic_phrases)
    r_task += 0.25 * (1 - max_similarity)

    return min(r_format + r_partial + r_task, 1.0)


def commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:
    """
    Master reward function v3: dispatches by task type + zero-advantage noise.
    """
    rewards = []
    for completion, prompt in zip(completions, prompts):
        if isinstance(completion, list):
            comp_text = completion[-1]["content"] if completion else ""
        else:
            comp_text = str(completion)

        if isinstance(prompt, list):
            prompt_text = " ".join(m.get("content", "") for m in prompt)
        else:
            prompt_text = str(prompt)

        task = _classify_task_type(prompt_text)

        if task == "extraction":
            rewards.append(reward_extraction(comp_text))
        elif task == "sql_qa":
            rewards.append(reward_sql_qa(comp_text))
        elif task == "insights":
            rewards.append(reward_insights(comp_text))
        elif task == "push":
            rewards.append(reward_push(comp_text))
        else:
            r = 0.15 if has_think_block(comp_text) else 0.0
            r += 0.2 if comp_text.strip() else 0.0
            rewards.append(r)

    # ── v3: Zero-advantage noise injection ────────────────────────────────────
    if ZERO_ADV_NOISE_STD > 0 and NUM_GENERATIONS > 1:
        for i in range(0, len(rewards), NUM_GENERATIONS):
            group = rewards[i:i+NUM_GENERATIONS]
            if len(group) < 2:
                continue
            if max(group) - min(group) < 0.001:
                for j in range(i, min(i+NUM_GENERATIONS, len(rewards))):
                    rewards[j] += random.gauss(0, ZERO_ADV_NOISE_STD)

    return rewards


print("✓ v3 Reward functions defined (staged: format → partial → task)")
```

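A quick sanity check you can run right after the cell above, with toy completions chosen so the staged scoring and the tie-breaking noise are both visible (strings are illustrative, not real model output):

```python
# Near-perfect extraction: think block + valid JSON with all ten fields.
good = ('<think>nota 2, vendedor ausente</think>'
        '{"sentiment": "negative", "sentiment_score": 2, "churn_risk": "high", '
        '"delivery_issue": false, "product_issue": true, "seller_issue": true, '
        '"main_complaint": "produto com defeito e sem resposta", '
        '"complaint_category": "product_quality", "repeat_intent": "no", '
        '"would_recommend": false}')
bad = "<think>hm</think>{sentiment: negative"  # malformed JSON → format credit only

print(f"extraction good: {reward_extraction(good):.2f}")  # expect ~1.0
print(f"extraction bad : {reward_extraction(bad):.2f}")   # expect ≤ 0.2

# Zero-advantage path: a group of G=4 identical completions is fully tied,
# so commerce_reward_fn should perturb the whole group with small gaussian noise.
prompts = ["Retorne um objeto JSON com os dados extraídos"] * 4
tied = commerce_reward_fn([bad] * 4, prompts)
print("tied group after noise:", tied)  # values should no longer be identical
```
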
---

## Cell 7: Reward Calibration

**Gate:** Verify reward variance > 0. Compare v3 scoring to v2 calibration (mean=0.38).

```python
train_path = DATA_DIR / "pairs" / "train.jsonl"

by_type = {"extraction": [], "sql_qa": [], "insights": [], "push": []}
with open(train_path) as f:
    for line in f:
        row = json.loads(line)
        convs = row["conversations"]
        prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
        if not prompt_msgs:
            continue
        user_text = " ".join(m["content"] for m in prompt_msgs if m["role"] == "user")
        task = _classify_task_type(user_text)
        by_type[task].append(prompt_msgs)

print(f"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}")

rng = random.Random(42)
cal_samples = []
for task_type in ["extraction", "extraction", "sql_qa", "sql_qa", "insights", "insights", "push", "push"]:
    cal_samples.append(rng.choice(by_type[task_type]))

FastLanguageModel.for_inference(model)
print(f"\nReward calibration v3 ({len(cal_samples)} samples):")
print("-" * 70)

cal_rewards = []
for i, msgs in enumerate(cal_samples):
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.7, do_sample=True)
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]

    r = commerce_reward_fn([response], [text])[0]
    cal_rewards.append(r)
    hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH
    has_answer = "</think>" in response
    answer_preview = strip_think(response)[:100] if has_answer else "[stuck in <think>]"
    task = _classify_task_type(text)
    print(f"  [{task:12s}] reward={r:.2f} | tokens={gen_tokens:4d} | ceiling={'⚠️ HIT' if hit_ceiling else 'ok':6s} | {answer_preview}")

print(f"\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}")
print(f"v2 calibration was: Mean=0.38, Min=0.02, Max=0.70")
print(f"Variance > 0: {len(set(cal_rewards)) > 1}")
```

---

## Cell 8: Dataset Preparation v3

```python
from datasets import Dataset

def prepare_grpo_datasets_v3(n_prompts=GRPO_PROMPTS, eval_ratio=EVAL_SPLIT_RATIO,
                             general_mix=GENERAL_MIX_RATIO, seed=42):
    rng = random.Random(seed)

    train_pools = {}
    eval_records = []
    for task, pool in by_type.items():
        shuffled = pool.copy()
        rng.shuffle(shuffled)
        n_eval = max(1, int(len(shuffled) * eval_ratio))
        eval_records.extend(shuffled[:n_eval])
        train_pools[task] = shuffled[n_eval:]

    if n_prompts is None:
        train_records = []
        for task, pool in train_pools.items():
            train_records.extend(pool)
        rng.shuffle(train_records)
    else:
        targets = {
            "extraction": int(n_prompts * 0.4),
            "sql_qa": int(n_prompts * 0.4),
            "insights": int(n_prompts * 0.1),
            "push": int(n_prompts * 0.1),
        }
        train_records = []
        for task, target_n in targets.items():
            pool = train_pools[task]
            n = min(target_n, len(pool))
            train_records.extend(rng.sample(pool, n))
        rng.shuffle(train_records)

    general_path = DATA_DIR / "pairs" / "general_reasoning.jsonl"
    if general_mix > 0 and general_path.exists():
        general_records = []
        with open(general_path) as f:
            for line in f:
                row = json.loads(line)
                convs = row["conversations"]
                prompt_msgs = [m for m in convs if m["role"] in ("system", "user")]
                if prompt_msgs:
                    general_records.append(prompt_msgs)
        n_general = int(len(train_records) * general_mix / (1 - general_mix))
        n_general = min(n_general, len(general_records))
        if n_general > 0:
            train_records.extend(rng.sample(general_records, n_general))
            rng.shuffle(train_records)
            print(f"  Cocktail Effect: added {n_general} general reasoning samples ({general_mix:.0%} mix)")
    elif general_mix > 0:
        print(f"  ⚠️ general_reasoning.jsonl not found — skipping mix")

    task_dist = {}
    for record in train_records:
        user_text = " ".join(m["content"] for m in record if m["role"] == "user")
        task = _classify_task_type(user_text)
        task_dist[task] = task_dist.get(task, 0) + 1

    n_domain = len(train_records)
    steps_per_epoch = n_domain * NUM_GENERATIONS // (BATCH_SIZE * GRAD_ACCUM)

    print(f"v3 Dataset split (eval_ratio={eval_ratio}):")
    print(f"  train : {n_domain} prompts")
    print(f"  eval  : {len(eval_records)} prompts")
    print(f"  distribution: {', '.join(f'{k}={v}' for k, v in sorted(task_dist.items()))}")
    print(f"  steps/epoch: {n_domain} × {NUM_GENERATIONS} / ({BATCH_SIZE} × {GRAD_ACCUM}) = {steps_per_epoch}")
    print(f"  MAX_STEPS={MAX_STEPS} → {'< 1 epoch' if MAX_STEPS < steps_per_epoch else f'{MAX_STEPS/steps_per_epoch:.1f} epochs'}")

    train_ds = Dataset.from_list([{"prompt": msgs} for msgs in train_records])
    eval_ds = Dataset.from_list([{"prompt": msgs} for msgs in eval_records])
    return train_ds, eval_ds


train_dataset, eval_dataset = prepare_grpo_datasets_v3()
dataset = train_dataset
print(f"\n✓ v3 Datasets ready: train={len(train_dataset)}, eval={len(eval_dataset)}")
```

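The general-mix sizing in `prepare_grpo_datasets_v3` is easy to misread: `n_general` is computed against the final mixture, not as a fraction of the domain count. A worked check with illustrative numbers:

```python
domain = 1200                                    # domain prompts left after the eval split
general_mix = 0.3
n_general = int(domain * general_mix / (1 - general_mix))  # int(514.28...) = 514
total = domain + n_general                       # 1714
print(n_general, round(n_general / total, 4))    # 514 0.2999 → general is ~30% of the final mix
```
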
---

## Cell 9: Smoke Test

**Gate:** Runs 1 step without OOM at new completion length (4096).

```python
from trl import GRPOConfig, GRPOTrainer

FastLanguageModel.for_training(model)

smoke_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR / "smoke"),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=1,
    num_train_epochs=1,
    temperature=TEMPERATURE,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LEARNING_RATE,
    fp16=False,
    bf16=True,
    logging_steps=1,
    save_steps=999,
    report_to="none",
    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
    seed=42,
    remove_unused_columns=False,
)

smoke_trainer = GRPOTrainer(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=smoke_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

t0 = time.time()
smoke_trainer.train()
step_time = time.time() - t0

print(f"\n✓ Smoke test passed!")
print(f"  Step time (grad_accum=1): {step_time:.0f}s")
print(f"  Estimated step time (grad_accum={GRAD_ACCUM}): {step_time * GRAD_ACCUM:.0f}s")
print(f"  VRAM peak: {torch.cuda.max_memory_allocated()/1e9:.1f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

vram_used = torch.cuda.max_memory_allocated() / 1e9
vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
if vram_used > vram_total * 0.95:
    print(f"\n⚠️ VRAM at {vram_used/vram_total:.0%} — dangerously close to OOM")
    print(f"   Option 1: Reduce MAX_COMPLETION_LENGTH to 3072")
    print(f"   Option 2: Reduce BATCH_SIZE to 2 (increase GRAD_ACCUM to 2)")

del smoke_trainer
gc.collect(); torch.cuda.empty_cache()
```

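If the smoke test trips the 95% VRAM warning, its two printed fallbacks can be applied as plain overrides before re-running (a sketch; G stays at 4 per the Cell 3 budget note, and the effective batch stays at 4):

```python
# Option 1 first: shorter completions free generation + KV memory.
MAX_COMPLETION_LENGTH = 3072

# Option 2 if still tight: halve the per-device batch, double accumulation.
BATCH_SIZE = 2
GRAD_ACCUM = 2

# NUM_GENERATIONS stays at 4: GRPO needs within-group reward variance.
print(f"eff_batch={BATCH_SIZE * GRAD_ACCUM}, G={NUM_GENERATIONS}, completion={MAX_COMPLETION_LENGTH}")
```
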
---

## Cell 10: Probe Run (3 steps)

```python
FastLanguageModel.for_training(model)

probe_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR / "probe"),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    max_steps=3,
    temperature=TEMPERATURE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=True,
    logging_steps=1,
    disable_tqdm=True,
    logging_first_step=True,
    save_steps=999,
    report_to="none",
    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
    seed=42,
    remove_unused_columns=False,
)

probe_trainer = GRPOTrainer(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=probe_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

t0 = time.time()
result = probe_trainer.train()
elapsed = time.time() - t0

print(f"\n✓ Probe complete in {elapsed:.0f}s ({elapsed/3:.0f}s/step)")
print(f"  Train loss: {result.training_loss:.6f}")
print(f"  Estimated full run ({MAX_STEPS} steps): {elapsed/3 * MAX_STEPS / 3600:.1f}h")

if abs(result.training_loss) < 1e-6:
    print("  ⚠️ Loss is near-zero — reward variance may be insufficient")
else:
    print("  ✓ Loss is non-zero — GRPO has gradient signal")

del probe_trainer
gc.collect(); torch.cuda.empty_cache()
```

---

## Cell 11: Full Training Run v3

```python
import wandb

_wandb_key = os.environ.get("WANDB_API_KEY", "").strip()
if not _wandb_key:
    raise EnvironmentError("WANDB_API_KEY is not set.")
wandb.login(key=_wandb_key, relogin=True)
print(f"✓ W&B authenticated")
```

```python
import shutil
import torch
from transformers import TrainerCallback
from trl import GRPOConfig, GRPOTrainer

wandb.init(
    project=WANDB_PROJECT,
    name=f"grpo-v3-l4-{time.strftime('%Y%m%d-%H%M')}",
    config={
        "model_id": MODEL_ID,
        "version": "v3",
        "temperature": TEMPERATURE,
        "max_completion_length": MAX_COMPLETION_LENGTH,
        "num_generations": NUM_GENERATIONS,
        "learning_rate": LEARNING_RATE,
        "beta": BETA,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "max_steps": MAX_STEPS,
        "scale_rewards": SCALE_REWARDS,
        "save_steps": SAVE_STEPS,
        "eval_steps": EVAL_STEPS,
        "eval_max_samples": EVAL_MAX_SAMPLES,
        "eval_max_tokens": EVAL_MAX_TOKENS,
        "eval_temperature": EVAL_TEMPERATURE,
        "patience": EARLY_STOPPING_PATIENCE,
        "delta": EARLY_STOPPING_DELTA,
        "train_prompts": len(train_dataset),
        "eval_prompts": len(eval_dataset),
        "zero_adv_noise_std": ZERO_ADV_NOISE_STD,
        "general_mix_ratio": GENERAL_MIX_RATIO,
        "_ref_temperature": "Skywork-OR1 (2505.22312)",
        "_ref_completion_length": "Dr. GRPO (2503.20783)",
        "_ref_staged_rewards": "Reasoning-SQL (2503.23157)",
        "_ref_zero_adv": "Skywork-OR1 (2505.22312)",
    },
)
print(f"✓ W&B run: {wandb.run.url}")

FRESH = True
resume_from = None
if FRESH and CHECKPOINT_DIR.exists():
    print("FRESH: deleting old checkpoints...")
    shutil.rmtree(CHECKPOINT_DIR)
elif CHECKPOINT_DIR.exists():
    checkpoints = sorted(
        [d for d in CHECKPOINT_DIR.iterdir()
         if d.is_dir() and d.name.startswith("checkpoint-")],
        key=lambda d: int(d.name.split("-")[-1]),
    )
    if checkpoints:
        resume_from = str(checkpoints[-1])
        print(f"Resuming from: {resume_from}")


class UnslothGRPOTrainer(GRPOTrainer):
    """Wraps generation with Unsloth for_inference()/for_training()."""
    def _generate(self, prompts, images):
        FastLanguageModel.for_inference(self.model)
        try:
            result = super()._generate(prompts, images)
        finally:
            FastLanguageModel.for_training(self.model)
        return result


class EvalRewardCallback(TrainerCallback):
    """v3: deterministic eval, per-task breakdown, patience=15."""
    def __init__(self, eval_records, reward_fn, patience=EARLY_STOPPING_PATIENCE,
                 delta=EARLY_STOPPING_DELTA):
        self.eval_records = eval_records
        self.reward_fn = reward_fn
        self.patience = patience
        self.delta = delta
        self.best_reward = -float("inf")
        self.no_improve_count = 0

    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:
            return control
        tokenizer = processing_class
        if tokenizer is None:
            print("[EvalRewardCallback] WARNING: tokenizer is None, skipping eval")
            return control

        mean_reward, task_rewards = self._run_eval(model, tokenizer, args)
        improved = mean_reward > self.best_reward + self.delta
        status = "↑ improved" if improved else f"↔ no gain ({self.no_improve_count + 1}/{self.patience})"

        log_dict = {
            "eval/mean_reward": mean_reward,
            "eval/best_reward": max(self.best_reward, mean_reward),
            "eval/no_improve_count": self.no_improve_count,
        }
        for task, rewards in task_rewards.items():
            if rewards:
                log_dict[f"eval/{task}_reward"] = sum(rewards) / len(rewards)
        wandb.log(log_dict, step=state.global_step)

        print(f"\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}")
        for task, rewards in task_rewards.items():
            if rewards:
                print(f"  {task}: {sum(rewards)/len(rewards):.3f} (n={len(rewards)})")

        if improved:
            self.best_reward = mean_reward
            self.no_improve_count = 0
        else:
            self.no_improve_count += 1
            if self.no_improve_count >= self.patience:
                print(f"[EarlyStopping] No improvement ≥ {self.delta} for {self.patience} consecutive evals. Halting.")
                wandb.log({"early_stop/step": state.global_step}, step=state.global_step)
                control.should_training_stop = True
        return control

    def _run_eval(self, model, tokenizer, args):
        FastLanguageModel.for_inference(model)
        rewards = []
        task_rewards = {}
        subset = self.eval_records[:EVAL_MAX_SAMPLES]
        for record in subset:
            msgs = record["prompt"]
            text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors="pt", truncation=True,
                               max_length=args.max_prompt_length).to(model.device)
            with torch.no_grad():
                out = model.generate(**inputs, max_new_tokens=EVAL_MAX_TOKENS,
                                     temperature=EVAL_TEMPERATURE, do_sample=True)
            resp = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            r = self.reward_fn([resp], [text])[0]
            rewards.append(r)
            user_text = " ".join(m.get("content", "") for m in msgs if m.get("role") == "user")
            task = _classify_task_type(user_text)
            task_rewards.setdefault(task, []).append(r)
        FastLanguageModel.for_training(model)
        mean = sum(rewards) / len(rewards) if rewards else 0.0
        return mean, task_rewards


class EntropyMonitorCallback(TrainerCallback):
    """v3 NEW: Monitor entropy collapse indicators (Skywork-OR1 §4)."""
    def __init__(self):
        self.consecutive_ceiling_hits = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        step = state.global_step
        monitor = {}
        comp_len = logs.get("completion_length", 0)
        if comp_len > 0:
            ratio = comp_len / MAX_COMPLETION_LENGTH
            monitor["monitor/completion_ratio"] = ratio
            if ratio > 0.95:
                self.consecutive_ceiling_hits += 1
                if self.consecutive_ceiling_hits >= 3:
                    print(f"⚠️ Step {step}: Completion ceiling hit {self.consecutive_ceiling_hits} consecutive times.")
            else:
                self.consecutive_ceiling_hits = 0
        reward_std = logs.get("reward_std", logs.get("rewards/commerce_reward_fn/std", 0))
        if reward_std is not None:
            monitor["monitor/reward_std"] = reward_std
            if reward_std < 0.01:
                print(f"⚠️ Step {step}: reward_std={reward_std:.4f} — near-zero variance")
        clip_high = logs.get("clip_ratio/high_mean", 0)
        clip_low = logs.get("clip_ratio/low_mean", 0)
        if clip_high is not None and clip_low is not None:
            total_clip = clip_high + abs(clip_low)
            monitor["monitor/total_clip_ratio"] = total_clip
            if total_clip > 0.01 and step > 10:
                print(f"✓ Step {step}: clip_ratio={total_clip:.3f} — policy is updating")
        if monitor and wandb.run:
            wandb.log(monitor, step=step)


FastLanguageModel.for_training(model)

grpo_config = GRPOConfig(
    output_dir=str(CHECKPOINT_DIR),
    num_generations=NUM_GENERATIONS,
    scale_rewards=SCALE_REWARDS,
    max_completion_length=MAX_COMPLETION_LENGTH,
    temperature=TEMPERATURE,
    beta=BETA,  # v3 FIX 4: explicitly disable the KL penalty (Dr. GRPO §3.2)
    max_steps=MAX_STEPS,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    fp16=False,
    bf16=True,
    logging_steps=1,
    logging_first_step=True,
    disable_tqdm=True,
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,
    save_only_model=True,
    eval_steps=EVAL_STEPS,
    report_to="wandb",
    max_prompt_length=MAX_SEQ_LENGTH - MAX_COMPLETION_LENGTH,
    seed=42,
    remove_unused_columns=False,
    **({"use_vllm": True, "vllm_mode": "colocate",
        "vllm_enable_sleep_mode": True} if USE_VLLM else {}),
)

eval_cb = EvalRewardCallback(eval_records=list(eval_dataset), reward_fn=commerce_reward_fn)
entropy_cb = EntropyMonitorCallback()

TrainerClass = GRPOTrainer if USE_VLLM else UnslothGRPOTrainer
trainer = TrainerClass(
    model=model,
    reward_funcs=commerce_reward_fn,
    args=grpo_config,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    callbacks=[eval_cb, entropy_cb],
)

print(f"{'='*70}")
print(f"GRPO v3 Training — Ready to Launch")
print(f"{'='*70}")
print(f"  Trainer: {TrainerClass.__name__}")
print(f"  Max steps: {MAX_STEPS}")
print(f"  Temperature: {TEMPERATURE} (v2 was 0.8)")
print(f"  Completion: {MAX_COMPLETION_LENGTH} tokens (v2 was 2048)")
print(f"  Generations: {NUM_GENERATIONS} per prompt (v2 was 8)")
print(f"  Learning rate: {LEARNING_RATE} (v2 was 5e-7)")
print(f"  Save every: {SAVE_STEPS} steps (keep {SAVE_TOTAL_LIMIT})")
print(f"  Eval every: {EVAL_STEPS} steps ({EVAL_MAX_SAMPLES} samples × {EVAL_MAX_TOKENS} tok)")
print(f"  Patience: {EARLY_STOPPING_PATIENCE} evals ({EARLY_STOPPING_PATIENCE * EVAL_STEPS} steps)")
print(f"  Resume: {resume_from is not None}")
print(f"{'='*70}")

t_start = time.time()
result = trainer.train(resume_from_checkpoint=resume_from)
elapsed = time.time() - t_start

wandb.log({
    "train/final_loss": result.training_loss,
    "train/duration_hours": elapsed / 3600,
    "train/total_steps": result.global_step,
    "eval/best_reward_final": eval_cb.best_reward,
})
wandb.finish()

print(f"\n{'='*70}")
print(f"GRPO v3 Training Complete")
print(f"  Loss: {result.training_loss:.6f}")
print(f"  Steps: {result.global_step}")
print(f"  Duration: {elapsed/3600:.1f}h")
print(f"  Best eval R: {eval_cb.best_reward:.4f}")
print(f"  Trainer: {TrainerClass.__name__}")
print(f"{'='*70}")
```

---

## Cell 12: Save Adapter

```python
GRPO_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(GRPO_ADAPTER_DIR))
tokenizer.save_pretrained(str(GRPO_ADAPTER_DIR))

summary = {
    "model_id": MODEL_ID,
    "sft_adapter": str(SFT_ADAPTER_DIR),
    "method": "GRPO",
    "version": "v3",
    "train_loss": result.training_loss,
    "best_eval_reward": eval_cb.best_reward,
    "num_prompts": len(train_dataset),
    "num_generations": NUM_GENERATIONS,
    "scale_rewards": SCALE_REWARDS,
    "temperature": TEMPERATURE,
    "learning_rate": LEARNING_RATE,
    "beta": BETA,
    "max_completion_length": MAX_COMPLETION_LENGTH,
    "max_steps": MAX_STEPS,
    "actual_steps": result.global_step,
    "epochs": NUM_EPOCHS,
    "max_seq_length": MAX_SEQ_LENGTH,
    "duration_seconds": round(elapsed),
    "gpu": "L4",
    "platform": "vertex-ai-workbench",
    "v3_fixes": [
        "temperature=1.0 (Skywork-OR1)",
        "max_completion_length=4096 (Dr. GRPO)",
        "learning_rate=2e-6 (4x v2)",
        "beta=0.0 (Dr. GRPO)",
        "staged rewards (Reasoning-SQL)",
        "zero-advantage noise (Skywork-OR1)",
        "entropy monitoring callback",
    ],
}
with open(GRPO_ADAPTER_DIR / "training_summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(f"✓ Adapter saved to {GRPO_ADAPTER_DIR}")
print(f"  Files: {[f.name for f in GRPO_ADAPTER_DIR.iterdir() if f.is_file()]}")
```

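To confirm the save round-trips, a minimal reload sketch (same Unsloth environment assumed; `load_in_4bit` and `max_seq_length` mirror Cell 4, and the chat template copied there was persisted by `tokenizer.save_pretrained` above):

```python
# Ideally run in a fresh kernel so nothing is inherited from the training session.
from unsloth import FastLanguageModel

reloaded, reloaded_tok = FastLanguageModel.from_pretrained(
    model_name=str(GRPO_ADAPTER_DIR),
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,
)
FastLanguageModel.for_inference(reloaded)
print("✓ GRPO v3 adapter reloads cleanly")
```
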
---
|
| 1246 |
-
|
| 1247 |
-
## Cell 13: Validation
|
| 1248 |
-
|
| 1249 |
-
```python
FastLanguageModel.for_inference(model)

system_msg = {"role": "system", "content": SYSTEM_PT}

test_prompts = [
    {"role": "user", "content": (
        "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
        "nota=2/5 | status=delivered\ntítulo: decepcionado\n"
        "texto: Produto veio com defeito e o vendedor não respondeu.\n\n"
        "Retorne um objeto JSON com exatamente estas chaves:\n"
        "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
        "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
    )},
    {"role": "user", "content": (
        "Analise esta avaliação de e-commerce brasileiro e extraia dados estruturados.\n\n"
        "nota=5/5 | status=delivered\ntítulo: adorei!\n"
        "texto: Entrega rápida e produto exatamente como descrito. Recomendo!\n\n"
        "Retorne um objeto JSON com exatamente estas chaves:\n"
        "sentiment, sentiment_score, churn_risk, delivery_issue, product_issue, "
        "seller_issue, main_complaint, complaint_category, repeat_intent, would_recommend"
    )},
    {"role": "user", "content": "Quais são as categorias de reclamação mais frequentes e como afetam a nota média?"},
    {"role": "user", "content": "Analise a retenção de clientes afetados por product_quality."},
    {"role": "user", "content": (
        "Perfil do cliente:\n- Estado: MG\n- Valor do pedido: R$150\n"
        "- Reclamação: produto com defeito\n- Nota: 1.0/5\n\n"
        "Este cliente deve receber uma notificação de reengajamento?"
    )},
    {"role": "user", "content": "Compare a satisfação de clientes em SP vs RJ."},
    {"role": "user", "content": (
        "Crie uma notificação push de reengajamento para um cliente em SP "
        "que reclamou de atraso na entrega. Nota: 2/5."
    )},
]

print("=" * 70)
print("GRPO v3 Validation")
print("=" * 70)

v3_rewards = []
for i, prompt in enumerate(test_prompts):
    messages = [system_msg, prompt]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # Near-greedy decoding (T=0.1) for a low-noise validation signal
    outputs = model.generate(**inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=0.1, do_sample=True)
    gen_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    reward = commerce_reward_fn([response], [text])[0]
    v3_rewards.append(reward)
    answer = strip_think(response)
    task = _classify_task_type(prompt["content"])
    hit_ceiling = gen_tokens >= MAX_COMPLETION_LENGTH  # completion truncated at the token cap

    print(f"\n--- Sample {i+1} [{task}] (reward={reward:.2f}, tokens={gen_tokens}, ceiling={'HIT' if hit_ceiling else 'ok'}) ---")
    print(f"Prompt: {prompt['content'][:80]}...")
    print(f"Answer: {answer[:400]}")

v3_mean = sum(v3_rewards) / len(v3_rewards)
print(f"\n{'='*70}")
print("v3 Validation Summary")
print(f"{'='*70}")
print(f"  Mean reward: {v3_mean:.3f}")
print(f"  Min: {min(v3_rewards):.3f}")
print(f"  Max: {max(v3_rewards):.3f}")
print()
print("  Comparison to baselines:")
print("    SFT calibration (Cell 7): mean=0.38")
print("    GRPO v2 validation:       mean=0.54")
print(f"    GRPO v3 validation:       mean={v3_mean:.3f}")
v3_vs_v2 = (v3_mean - 0.54) / 0.54 * 100
print(f"  v3 vs v2: {v3_vs_v2:+.1f}%")
```
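
The v3-vs-v2 delta above rests on a handful of prompts measured against a hardcoded v2 mean (0.54), so it is worth checking how wide the uncertainty is before calling it a win. A stdlib-only sketch of a percentile bootstrap over `v3_rewards` from the cell above:

```python
import random
import statistics

def bootstrap_ci(values, n_boot=10_000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for the mean of a small sample."""
    rng = random.Random(seed)
    means = sorted(
        statistics.fmean(rng.choices(values, k=len(values)))  # resample with replacement
        for _ in range(n_boot)
    )
    return means[int(n_boot * alpha / 2)], means[int(n_boot * (1 - alpha / 2)) - 1]

lo, hi = bootstrap_ci(v3_rewards)
print(f"v3 mean reward 95% CI: [{lo:.3f}, {hi:.3f}]  (v2 baseline: 0.54)")
# If 0.54 sits inside the interval, the v3 gain is within sampling noise.
```

With only seven prompts the interval will be wide; treat the percentage delta as directional, not conclusive.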