{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Tucano2 Commerce β€” GRPO Training V4 (Instruct-Only, 0.5B)\n\n**Reference:** `docs/ADR-002-v4-instruct.md` β€” Single-Model Validation on 0.5B-Instruct\n\n**V4 key changes over V3:**\n\n| Aspect | V3 | V4 | Why |\n|--------|----|----|-----|\n| Model | 3.7B-Think | **0.5B-Instruct** | Think scored 14.41 NPM vs Instruct's 26.08; clip_ratio=0 on all V3 steps |\n| `<think>` overhead | 2000-3000 tok | **None** | Instruct template has no `<think>` injection |\n| Num generations | 4 | **16** | 0.5B is 8Γ— lighter; more G = more reward variance |\n| Completion length | 4096 | **512** | No think overhead; extraction ~100, SQL ~200, insights ~300 tok |\n| Sequence length | 8192 | **2048** | Shorter completions need less context |\n| Reward design | Staged (formatβ†’partialβ†’task) | **Continuous** | Simpler; Instruct doesn't need think bonus |\n| Probe gate | 3 steps, no hard gate | **10 steps, clip_ratio β‰₯ 3/10** | V3's missed signal β€” hard gate prevents wasted compute |\n| generation_config | Override temp, use_cache | **Override 6 fields** | rep_penalty=1.2 and top_k=50 also destroy GRPO |\n\n**Prerequisites:**\n- Upload `data/pairs/train.jsonl` to `./data/pairs/`\n- Hardware: L4 (24GB), PyTorch kernel, bf16 supported"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 1: Dependencies\n\nRestart kernel first (Kernel β†’ Restart), then run:"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Cell 1 β€” Clean install\n# Run after kernel restart\n\n!pip install \"unsloth\"\n!pip install \"trl==0.24.0\" --no-deps\n!pip install \"rich\" \"wandb\""
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 2: GPU + Unsloth Verification\n\n**Gate:** CUDA available, bf16=True, VRAM > 20GB, TRL 0.24.0."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import torch\n\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nprint(f\"GPU: {torch.cuda.get_device_name(0)}\")\nprint(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\nprint(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n\nfrom unsloth import FastLanguageModel\nprint(f\"\\nβœ“ Unsloth loaded\")\n\nimport trl\nassert trl.__version__ == \"0.24.0\", f\"Expected TRL 0.24.0, got {trl.__version__}\"\nprint(f\"βœ“ TRL {trl.__version__}\")\n\nimport transformers\nprint(f\"βœ“ Transformers {transformers.__version__}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 3: Config Constants\n\nAll hyperparameters from ADR-002 Β§9. Every value is annotated with its rationale."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import os\nimport json\nimport re\nimport time\nimport random\nimport gc\nimport warnings\nfrom pathlib import Path\n\n# ── Suppress noisy deprecation warnings from Transformers 5.5.0 ──────────────\nwarnings.filterwarnings(\"ignore\", message=\".*AttentionMaskConverter.*\")\nwarnings.filterwarnings(\"ignore\", message=\".*Passing `generation_config` together with.*\")\nwarnings.filterwarnings(\"ignore\", message=\".*max_new_tokens.*max_length.*\")\nwarnings.filterwarnings(\"ignore\", category=FutureWarning)\n\n# ── Disable Unsloth kernel recompilation ─────────────────────────────────────\nos.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n\n# ── Model ────────────────────────────────────────────────────────────────────\nMODEL_ID       = \"Polygl0t/Tucano2-qwen-0.5B-Instruct\"\nMAX_SEQ_LENGTH = 2048   # model supports 4096, but 2048 is plenty for Instruct (no <think> overhead)\nMODELS_DIR     = Path(\"/home/jupyter/tucano2/models\")\nADAPTER_DIR    = MODELS_DIR / \"tucano2-0.5B-instruct-grpo-v4\"\nCHECKPOINT_DIR = ADAPTER_DIR / \"checkpoints\"\n\n# ── Data ─────────────────────────────────────────────────────────────────────\nDATA_DIR       = Path(\"/home/jupyter/tucano2/data\")\nTRAIN_FILE     = DATA_DIR / \"pairs\" / \"train.jsonl\"\nEVAL_SPLIT     = 0.10   # 10% held out for eval\n\n# ── GRPO Hyperparameters ─────────────────────────────────────────────────────\nNUM_GENERATIONS        = 16     # 0.5B + short completions = VRAM allows G=16\nMAX_COMPLETION_LENGTH  = 512    # Instruct: no <think> overhead. Extraction ~100, SQL ~200, insights ~300\nTEMPERATURE            = 1.0    # Skywork-OR1: Ο„=1.0 for exploration\nLEARNING_RATE          = 2e-6   # Dr. GRPO: 4Γ— V2's 5e-7 (clip_ratio=0 β†’ push harder)\nBETA                   = 0.0    # Dr. GRPO Β§3.2: Ξ²=0 optimal for rule-based rewards\nSCALE_REWARDS          = False  # Dr. 
GRPO: remove std normalization bias\nBATCH_SIZE             = 2      # per-device batch size\nGRAD_ACCUM             = 1      # effective batch = 2 * 1 = 2 prompts * 16 gen = 32 completions\nMAX_STEPS              = 200    # validation run\nSAVE_STEPS             = 20\nEVAL_STEPS             = 10\nEARLY_STOPPING_PATIENCE = 15\nEARLY_STOPPING_DELTA   = 0.005\n\n# ── LoRA ─────────────────────────────────────────────────────────────────────\nLORA_R     = 16\nLORA_ALPHA = 32\n\n# ── Monitoring ───────────────────────────────────────────────────────────────\nWANDB_PROJECT     = \"tucano2-commerce\"\nEVAL_MAX_SAMPLES  = 15    # eval callback samples\nEVAL_MAX_TOKENS   = 512   # match training completion length\n\n# ── Task Classification (inherited from V2/V3) ──────────────────────────────\nVALID_SENTIMENTS  = {\"positive\", \"negative\", \"neutral\"}\nVALID_CATEGORIES  = {\n    \"delivery_delay\", \"product_quality\", \"product_not_received\",\n    \"wrong_product\", \"seller_communication\", \"app_issue\",\n    \"price_value\", \"other\", \"none\",\n}\nVALID_CHURN  = {\"low\", \"medium\", \"high\"}\nVALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\nEXTRACTION_FIELDS = [\n    \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n    \"product_issue\", \"seller_issue\", \"main_complaint\",\n    \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n]\n\n# ── Verified Special Token IDs (from tokenizer_config.json) ─────────────────\n# These are constants β€” do NOT recompute via tokenizer.encode()\nTOKEN_ID_BOS       = 1      # <|im_start|>\nTOKEN_ID_EOS       = 2      # <|im_end|>\nTOKEN_ID_PAD       = 49109  # <|pad|>\nTOKEN_ID_THINK     = 49116  # <think>\nTOKEN_ID_THINK_END = 49117  # </think>\n\n# ══════════════════════════════════════════════════════════════════════════════\n# TASK-AWARE SYSTEM PROMPTS (inherited from V3)\n# Research basis:\n#   - OptimalThinkingBench (2508.13141): task-specific instructions improve accuracy\n#   - V3 calibration: extraction prompt \"Retorne APENAS um objeto JSON\" teaches format\n#   - Cell 7 evidence: without this, model adds explanation text after JSON\n# ══════════════════════════════════════════════════════════════════════════════\n\nSYSTEM_EXTRACTION = (\n    \"VocΓͺ Γ© um motor de extraΓ§Γ£o de dados de e-commerce brasileiro. \"\n    \"Retorne APENAS um objeto JSON vΓ‘lido, sem nenhum texto antes ou depois. \"\n    \"NΓƒO USE blocos de cΓ³digo markdown (```json). \"\n    \"O primeiro caractere da sua resposta deve ser { e o ΓΊltimo deve ser }. \"\n    \"Campos nΓ£o mencionados na avaliaΓ§Γ£o devem ser null β€” nunca invente valores. \"\n    \"Sem explicaΓ§Γ£o. Sem comentΓ‘rios.\"\n)\n\nSYSTEM_SQL = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n    \"Para consultas e anΓ‘lises de dados: apresente a resposta de forma direta \"\n    \"com nΓΊmeros e dados concretos. Seja conciso.\"\n)\n\nSYSTEM_INSIGHTS = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n    \"Para anΓ‘lises estratΓ©gicas: raciocine de forma estruturada e concisa, \"\n    \"focando nos pontos principais e recomendaΓ§Γ΅es acionΓ‘veis.\"\n)\n\nSYSTEM_PUSH = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. 
\"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n    \"Para notificaΓ§Γ΅es push: seja direto e criativo. \"\n    \"A notificaΓ§Γ£o deve ter no mΓ‘ximo 120 caracteres. \"\n    \"Responda diretamente.\"\n)\n\nSYSTEM_PT = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\"\n)\n\ndef get_system_prompt(task_type: str) -> str:\n    return {\n        \"extraction\": SYSTEM_EXTRACTION,\n        \"sql_qa\": SYSTEM_SQL,\n        \"insights\": SYSTEM_INSIGHTS,\n        \"push\": SYSTEM_PUSH,\n    }.get(task_type, SYSTEM_PT)\n\ndef inject_task_system_prompt(msgs, task_type):\n    \"\"\"Replace generic system prompt with task-specific one.\"\"\"\n    new_msgs = []\n    system_prompt = get_system_prompt(task_type)\n    has_system = False\n    for m in msgs:\n        if m[\"role\"] == \"system\":\n            new_msgs.append({\"role\": \"system\", \"content\": system_prompt})\n            has_system = True\n        else:\n            new_msgs.append(m)\n    if not has_system:\n        new_msgs.insert(0, {\"role\": \"system\", \"content\": system_prompt})\n    return new_msgs\n\nprint(\"βœ“ Task-aware system prompts defined\")\n\nprint(\"βœ“ Config loaded\")\nprint(f\"  Model: {MODEL_ID}\")\nprint(f\"  G={NUM_GENERATIONS}, max_comp={MAX_COMPLETION_LENGTH}, temp={TEMPERATURE}\")\nprint(f\"  LR={LEARNING_RATE}, Ξ²={BETA}, scale_rewards={SCALE_REWARDS}\")\nprint(f\"  LoRA r={LORA_R}, Ξ±={LORA_ALPHA}\")\nprint(f\"  Max steps: {MAX_STEPS}\")"
  },
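  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### Cell 3b (optional): Inspect the shipped generation_config\n\nA read-only sanity check (a sketch, not part of ADR-002): fetch the repo's `generation_config.json` and confirm it really ships the GRPO-hostile defaults that Cell 4 overrides. Field names are standard `transformers.GenerationConfig` attributes; `use_cache` may live in `config.json` instead, in which case it prints as None here."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Optional sanity check: read the shipped generation defaults before loading the model.\n# Assumes the repo is reachable with the same access as Cell 4's from_pretrained call.\nfrom transformers import GenerationConfig\n\n_shipped = GenerationConfig.from_pretrained(MODEL_ID)\nfor _field, _target in [\n    (\"temperature\", TEMPERATURE),\n    (\"repetition_penalty\", 1.0),\n    (\"top_k\", 0),\n    (\"top_p\", 1.0),\n    (\"use_cache\", True),\n]:\n    _val = getattr(_shipped, _field, None)\n    _flag = \"needs override\" if _val != _target else \"already OK\"\n    print(f\"  {_field}: shipped={_val} -> target={_target} ({_flag})\")"
  },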
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 4: Load Model + Apply Critical Overrides\n\n**Gate:** Model loaded, `use_cache=True`, `repetition_penalty=1.0`, `temperature=1.0`.\n\n**Bugs fixed here (ADR-002 Β§1.2):**\n- Bug 2a: `use_cache: false` in config.json β†’ override to True\n- Bug 2b: `repetition_penalty: 1.2` in generation_config β†’ override to 1.0\n- Bug 2c: `temperature: 0.1` in generation_config β†’ override to 1.0\n- Bug 2d: Tied embeddings check after LoRA patching"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name=MODEL_ID,\n    max_seq_length=MAX_SEQ_LENGTH,\n    load_in_4bit=True,\n    dtype=None,  # auto-detect (matches V3 which ran 171 steps successfully)\n)\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# LoRA ADAPTER β€” ADR-002 Β§9: r=16, Ξ±=32\n# Using standard PEFT instead of Unsloth's get_peft_model() to avoid fused\n# LoRA QKV/O/MLP kernels that have fp16/bf16 dtype mismatch bug (unsloth #4891).\n# V3 avoided this because it loaded a pre-existing adapter (0 QKV/O/MLP patches).\n# V4 applies fresh LoRA β†’ Unsloth patches all 28 layers β†’ crash in matmul_lora.\n# Standard PEFT uses PyTorch native matmul which handles bf16 correctly.\n# ═══════════════════════════════════════════════════════════════════════════════\nfrom peft import LoraConfig, get_peft_model\n\nlora_config = LoraConfig(\n    r=LORA_R,\n    lora_alpha=LORA_ALPHA,\n    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n                    \"gate_proj\", \"up_proj\", \"down_proj\"],\n    lora_dropout=0,\n    bias=\"none\",\n    task_type=\"CAUSAL_LM\",\n)\nmodel = get_peft_model(model, lora_config)\nmodel.print_trainable_parameters()\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# CRITICAL OVERRIDES β€” generation_config ships with values that destroy GRPO\n# Source: Polygl0t/Tucano2-qwen-0.5B-Instruct/generation_config.json\n#   temperature: 0.1       β†’ override to 1.0\n#   repetition_penalty: 1.2 β†’ override to 1.0\n#   use_cache: false        β†’ override to true\n# ═══════════════════════════════════════════════════════════════════════════════\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0   # CRITICAL: 1.2 suppresses diversity\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0    # disable top-k β€” let temperature control diversity\nmodel.generation_config.top_p = 1.0  # disable top-p\nmodel.generation_config.max_length = None  # remove conflict with max_new_tokens\n\n# Pad token\nif tokenizer.pad_token is None:\n    tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"βœ“ Model loaded on {model.device}\")\nprint(f\"  use_cache: {model.config.use_cache}\")\nprint(f\"  temperature: {model.generation_config.temperature}\")\nprint(f\"  repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\"  top_k: {model.generation_config.top_k}\")\nprint(f\"  Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# TIED EMBEDDINGS CHECK β€” ADR-002 Decision 4\n# Source: config.json has \"tie_word_embeddings\": true\n# After LoRA patching, verify lm_head and embed_tokens still share weights.\n# ═══════════════════════════════════════════════════════════════════════════════\n\ntry:\n    lm_ptr = model.base_model.model.lm_head.weight.data_ptr()\n    embed_ptr = model.base_model.model.model.embed_tokens.weight.data_ptr()\n    tied = lm_ptr == embed_ptr\n    print(f\"  Tied embeddings intact: {tied}\")\n    if not tied:\n        print(\"  ⚠️ WARNING: Tied embeddings broken after LoRA patching. May affect output head gradients.\")\nexcept AttributeError as e:\n    print(f\"  ⚠️ Could not check tied embeddings: {e}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 5: Token ID Verification\n\n**Gate:** All token IDs match. Single-token `<think>` (49116) and `</think>` (49117) confirmed."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Verify that the constants from Cell 3 match the actual tokenizer\n# Do NOT skip this cell β€” if IDs don't match, all reward functions break\n\ntok_tests = {\n    \"<|im_start|>\": TOKEN_ID_BOS,\n    \"<|im_end|>\":   TOKEN_ID_EOS,\n    \"<|pad|>\":      TOKEN_ID_PAD,\n    \"<think>\":      TOKEN_ID_THINK,\n    \"</think>\":     TOKEN_ID_THINK_END,\n}\n\nall_pass = True\nfor text, expected_id in tok_tests.items():\n    ids = tokenizer.encode(text, add_special_tokens=False)\n    actual_id = ids[0] if len(ids) == 1 else ids\n    match = (len(ids) == 1 and ids[0] == expected_id)\n    status = \"βœ“\" if match else \"βœ—\"\n    print(f\"  {status} '{text}' β†’ expected {expected_id}, got {actual_id}\")\n    if not match:\n        all_pass = False\n\nassert all_pass, \"Token ID mismatch detected. Update constants in Cell 3 before proceeding.\"\nprint(\"\\nβœ“ All token IDs verified\")\n\nassert tokenizer.eos_token_id == TOKEN_ID_EOS, f\"eos_token_id mismatch: {tokenizer.eos_token_id}\"\nprint(f\"βœ“ eos_token_id = {tokenizer.eos_token_id}\")"
  },
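  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### Cell 5b (optional): Confirm no `<think>` injection\n\nA small check of the V4 premise from the intro table β€” the Instruct chat template should inject no `<think>` token into the generation prompt. This only inspects the rendered template; it says nothing about what the model generates (Cell 7 covers that)."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Render the chat template for a dummy turn and assert <think> (49116) is absent.\n_probe_ids = tokenizer.apply_chat_template(\n    [{\"role\": \"user\", \"content\": \"teste\"}],\n    tokenize=True,\n    add_generation_prompt=True,\n)\nassert TOKEN_ID_THINK not in _probe_ids, \"Template injects <think> β€” Think template, not Instruct\"\nprint(f\"βœ“ Instruct template confirmed: no <think> in {len(_probe_ids)} prompt tokens\")"
  },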
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 6: KV Cache Diagnostic\n\n**Gate:** Ratio < 3Γ— β†’ KV cache OK. Ratio > 5Γ— β†’ BROKEN, abort."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Copied from V2 Cell 5b β€” verify KV cache is working\n\nFastLanguageModel.for_inference(model)\n\n_kv_msgs = [{\"role\": \"user\", \"content\": \"Qual a categoria de reclamaΓ§Γ£o mais frequente?\"}]\n_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n\n_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\nwith torch.no_grad():\n    for _step in range(50):\n        _t0 = time.time()\n        seq_len = _generated.shape[1]\n        if _past is None:\n            _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n        else:\n            _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n        _out = model(\n            input_ids=_generated[:, -1:] if _past else _generated,\n            position_ids=_position_ids,\n            attention_mask=torch.ones(1, seq_len, device=model.device),\n            past_key_values=_past,\n            use_cache=True,\n            return_dict=True,\n        )\n        _past = _out.past_key_values\n        _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n        _generated = torch.cat([_generated, _next], dim=1)\n        _token_times.append(time.time() - _t0)\n\n_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\nprint(f\"First 5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\nprint(f\"Last  5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\nprint(f\"Ratio last/first: {_ratio:.1f}x\")\nassert _ratio < 5, f\"KV cache BROKEN (ratio {_ratio:.1f}Γ—). Check model.config.use_cache.\"\nprint(\"βœ“ KV cache working correctly\")\n\ndel _past, _generated, _kv_inputs, _token_times, _out\ngc.collect()\ntorch.cuda.empty_cache()"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 7: Single Inference Test\n\n**Gate:** Response is coherent Portuguese. Check if `<think>` appears. Check if JSON structure present."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "FastLanguageModel.for_inference(model)\n\ntest_msgs = [\n    {\"role\": \"system\", \"content\": SYSTEM_EXTRACTION},\n    {\"role\": \"user\", \"content\": \"Analise esta avaliaΓ§Γ£o: 'Produto chegou quebrado, pΓ©ssima embalagem. Nunca mais compro aqui.' Retorne um objeto JSON com os campos: sentiment, sentiment_score, delivery_issue, complaint_category.\"},\n]\ntext = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)\ninputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n\nt0 = time.time()\noutputs = model.generate(\n    **inputs,\n    max_new_tokens=256,\n    temperature=0.1,       # low temp for deterministic eval\n    do_sample=True,\n    repetition_penalty=1.0,\n)\nelapsed = time.time() - t0\n\nresponse = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\nprint(f\"Generation time: {elapsed:.1f}s\")\nprint(f\"Response length: {len(response)} chars\")\nprint(f\"Contains <think>: {'<think>' in response}\")\nprint(f\"Contains JSON {{ }}: {'{' in response and '}' in response}\")\nprint(f\"\\n{'='*60}\")\nprint(response[:500])"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 8: Reward Functions\n\nComplete reward functions from ADR-002 Β§6. All functions call `strip_think()` defensively,\neven for the Instruct model (which shouldn't generate `<think>` but might spontaneously)."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "def strip_think(text: str) -> str:\n    \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n    return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n\n\ndef has_think_block(text: str) -> bool:\n    return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n\n\ndef _classify_task_type(prompt_text: str) -> str:\n    p = prompt_text.lower()\n    if \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n        return \"extraction\"\n    elif \"notificaΓ§Γ£o push\" in p or \"notificaΓ§Γ£o de reengajamento\" in p:\n        return \"push\"\n    elif \"perfil do cliente\" in p or \"retenΓ§Γ£o\" in p or \"anΓ‘lise\" in p or \"insight\" in p:\n        return \"insights\"\n    else:\n        return \"sql_qa\"\n\n\ndef _extract_json(text: str) -> dict | None:\n    \"\"\"Extract first JSON object from text. Returns parsed dict or None.\"\"\"\n    stripped = text.strip()\n    stripped = re.sub(r\"^```(?:json)?\\s*\", \"\", stripped)\n    stripped = re.sub(r\"\\s*```$\", \"\", stripped)\n    stripped = stripped.strip()\n    try:\n        return json.loads(stripped)\n    except (json.JSONDecodeError, TypeError):\n        pass\n    match = re.search(r\"\\{[^{}]*(?:\\{[^{}]*\\}[^{}]*)*\\}\", text, re.DOTALL)\n    if match:\n        try:\n            return json.loads(match.group())\n        except (json.JSONDecodeError, TypeError):\n            pass\n    return None\n\n\ndef reward_extraction(completion: str) -> float:\n    \"\"\"Continuous reward for extraction tasks (max 1.0).\"\"\"\n    answer = strip_think(completion)\n    data = _extract_json(answer)\n\n    if data is None:\n        if \"{\" in answer and \"}\" in answer:\n            return 0.05\n        return 0.0\n\n    if not isinstance(data, dict):\n        return 0.1  # valid JSON but not an object\n\n    score = 0.3  # valid JSON object\n\n    # Schema completeness (0.3 total)\n    present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n    score += 0.3 * (present / len(EXTRACTION_FIELDS))\n\n    # Value validity (0.4 total, split across checks)\n    checks_passed = 0\n    checks_total = 0\n\n    for field, validator in [\n        (\"sentiment\", lambda v: isinstance(v, str) and v in VALID_SENTIMENTS),\n        (\"complaint_category\", lambda v: isinstance(v, str) and v in VALID_CATEGORIES),\n        (\"churn_risk\", lambda v: isinstance(v, str) and v in VALID_CHURN),\n        (\"repeat_intent\", lambda v: isinstance(v, str) and v in VALID_REPEAT),\n        (\"sentiment_score\", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),\n    ]:\n        checks_total += 1\n        if field in data and validator(data[field]):\n            checks_passed += 1\n\n    for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n        checks_total += 1\n        if bool_field in data and isinstance(data[bool_field], bool):\n            checks_passed += 1\n\n    if checks_total > 0:\n        score += 0.4 * (checks_passed / checks_total)\n\n    return min(score, 1.0)\n\n\ndef reward_sql_qa(completion: str) -> float:\n    \"\"\"Continuous reward for SQL Q&A (max 1.0).\"\"\"\n    answer = strip_think(completion)\n    if not answer.strip():\n        return 0.0\n\n    score = 0.0\n\n    # Numerical content (more numbers = more specific answer)\n    numbers = re.findall(r\"\\d+(?:[.,]\\d+)?\", answer)\n    score += min(0.4, 0.1 * len(numbers))\n\n    # Length: 50-500 chars optimal\n    length = len(answer)\n    if 50 <= length <= 
500:\n        score += 0.3\n    elif length > 0:\n        score += 0.3 * max(0, 1 - abs(length - 275) / 275)\n\n    # Portuguese business vocabulary\n    pt_business = [\"pedidos\", \"clientes\", \"mΓ©dia\", \"total\", \"taxa\", \"vendas\",\n                   \"produtos\", \"perΓ­odo\", \"categoria\", \"regiΓ£o\", \"faturamento\"]\n    pt_matches = sum(1 for w in pt_business if w in answer.lower())\n    score += min(0.3, 0.06 * pt_matches)\n\n    return min(score, 1.0)\n\n\ndef reward_insights(completion: str) -> float:\n    \"\"\"Continuous reward for insights (max 1.0).\"\"\"\n    answer = strip_think(completion)\n    if not answer.strip():\n        return 0.0\n\n    score = 0.0\n\n    # Actionable language\n    action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n                    \"priorizar\", \"investir\", \"otimizar\", \"estratΓ©gi\", \"aΓ§Γ£o\"]\n    matches = sum(1 for w in action_words if w in answer.lower())\n    score += min(0.4, 0.08 * matches)\n\n    # Length: 100-800 chars optimal\n    length = len(answer)\n    if 100 <= length <= 800:\n        score += 0.3\n    elif length > 0:\n        score += 0.3 * max(0, 1 - abs(length - 450) / 450)\n\n    # Structure: bullet points, numbered lists, headers\n    structure_marks = len(re.findall(r\"^[-β€’*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n    score += min(0.2, 0.04 * structure_marks)\n\n    # Portuguese coherence marker\n    if any(w in answer.lower() for w in [\"cliente\", \"produto\", \"serviΓ§o\", \"empresa\"]):\n        score += 0.1\n\n    return min(score, 1.0)\n\n\ndef reward_push(completion: str) -> float:\n    \"\"\"Continuous reward for push notifications (max 1.0).\"\"\"\n    answer = strip_think(completion).strip()\n    if not answer:\n        return 0.0\n\n    # Length: ≀120 chars gets full credit\n    length = len(answer)\n    if length <= 120:\n        length_score = 0.5\n    else:\n        length_score = 0.5 * max(0, 1 - (length - 120) / 120)\n\n    # Portuguese content\n    pt_markers = re.findall(r\"[ãçéΓͺΓ³ΓΊΓ’Γ΅]|vocΓͺ|para|como|seu|sua|oferta|desconto|produto\",\n                            answer, re.IGNORECASE)\n    lang_score = min(0.3, 0.03 * len(pt_markers))\n\n    # Non-generic (penalize very generic phrases)\n    generic = [\"olΓ‘\", \"obrigado pela compra\", \"agradecemos\"]\n    is_generic = any(g in answer.lower() for g in generic)\n    creativity_score = 0.0 if is_generic else 0.2\n\n    return min(length_score + lang_score + creativity_score, 1.0)\n\n\ndef commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:\n    \"\"\"Master reward function: dispatches by task type.\"\"\"\n    rewards = []\n    for completion, prompt in zip(completions, prompts):\n        if isinstance(completion, list):\n            comp_text = completion[-1][\"content\"] if completion else \"\"\n        else:\n            comp_text = str(completion)\n\n        if isinstance(prompt, list):\n            prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n        else:\n            prompt_text = str(prompt)\n\n        task = _classify_task_type(prompt_text)\n\n        if task == \"extraction\":\n            rewards.append(reward_extraction(comp_text))\n        elif task == \"sql_qa\":\n            rewards.append(reward_sql_qa(comp_text))\n        elif task == \"insights\":\n            rewards.append(reward_insights(comp_text))\n        elif task == \"push\":\n            rewards.append(reward_push(comp_text))\n        else:\n            r = 0.2 if 
comp_text.strip() else 0.0\n            rewards.append(r)\n\n    return rewards\n\n\nprint(\"βœ“ Reward functions defined\")"
  },
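  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### Cell 8b (optional): Reward sanity checks\n\nBefore calibrating on real generations, a minimal smoke test of the reward functions on hand-written strings (the examples below are synthetic, not from `train.jsonl`). The assertions only check ordering β€” a well-formed answer must outscore a degenerate one β€” not absolute values."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Synthetic reward smoke test β€” ordering checks only, no model calls.\n_good_json = json.dumps({\n    \"sentiment\": \"negative\", \"sentiment_score\": 1, \"churn_risk\": \"high\",\n    \"delivery_issue\": True, \"product_issue\": True, \"seller_issue\": False,\n    \"main_complaint\": \"produto quebrado\", \"complaint_category\": \"product_quality\",\n    \"repeat_intent\": \"no\", \"would_recommend\": False,\n})\nassert reward_extraction(_good_json) > reward_extraction(\"nΓ£o sei\"), \"full JSON should beat plain text\"\nassert reward_extraction(_good_json) > reward_extraction('{\"foo\": 1}'), \"schema hits should beat an empty object\"\nassert reward_sql_qa(\"Foram 1.234 pedidos no perΓ­odo, mΓ©dia de 4,1 por cliente.\") > reward_sql_qa(\"\")\nassert reward_push(\"Sua oferta expira hoje: 20% de desconto no seu produto favorito!\") > reward_push(\"x\" * 400)\nassert reward_insights(\"- Recomendamos priorizar a retenΓ§Γ£o de clientes.\\n- Reduzir atrasos de entrega.\") > reward_insights(\"ok\")\nprint(\"βœ“ Reward ordering checks passed\")"
  },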
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 9: Reward Calibration\n\n**Gate:** Mean reward < 0.90 (if already ~1.0, reward is too easy). Variance > 0. Document whether `<think>` appeared."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Load data, classify by task type, run calibration on 8 diverse samples\n\nby_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\nwith open(TRAIN_FILE) as f:\n    for line in f:\n        row = json.loads(line)\n        convs = row[\"conversations\"]\n        prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n        if not prompt_msgs:\n            continue\n        user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n        task = _classify_task_type(user_text)\n        by_type[task].append(prompt_msgs)\n\nprint(f\"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}\")\n\n# Pick 2 samples per task type = 8 total\nrng = random.Random(42)\ncal_samples = []\nfor task_type in by_type:\n    pool = by_type[task_type]\n    if len(pool) >= 2:\n        cal_samples.extend(rng.sample(pool, 2))\n    elif pool:\n        cal_samples.extend(pool)\n\nFastLanguageModel.for_inference(model)\nprint(f\"\\nReward calibration ({len(cal_samples)} samples):\")\nprint(\"-\" * 60)\n\ncal_rewards = []\nfor i, msgs in enumerate(cal_samples):\n    # Inject task-aware system prompt\n    user_text_cal = \" \".join(m.get(\"content\", \"\") for m in msgs if m[\"role\"] == \"user\")\n    task_cal = _classify_task_type(user_text_cal)\n    msgs = inject_task_system_prompt(msgs, task_cal)\n    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n    inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n    outputs = model.generate(\n        **inputs,\n        max_new_tokens=MAX_COMPLETION_LENGTH,\n        temperature=0.7,\n        do_sample=True,\n        repetition_penalty=1.0,\n    )\n    response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n    r = commerce_reward_fn([response], [text])[0]\n    cal_rewards.append(r)\n    task = _classify_task_type(\" \".join(m.get(\"content\", \"\") for m in msgs if m[\"role\"] == \"user\"))\n    has_think = \"<think>\" in response\n    answer_preview = strip_think(response)[:100]\n    print(f\"  Sample {i+1} [{task:12s}]: reward={r:.2f} | has_think={has_think} | {answer_preview}\")\n\nprint(f\"\\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}\")\nprint(f\"Reward variance > 0: {len(set(f'{r:.4f}' for r in cal_rewards)) > 1}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 10: Dataset Preparation\n\n**Gate:** Train has ~1,650 prompts, eval has ~180. All 4 task types present in both."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from datasets import Dataset\n\ndef prepare_datasets(train_file, eval_ratio=EVAL_SPLIT, seed=42):\n    rng = random.Random(seed)\n\n    all_records = []\n    with open(train_file) as f:\n        for line in f:\n            row = json.loads(line)\n            convs = row[\"conversations\"]\n            prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n            if prompt_msgs:\n                all_records.append(prompt_msgs)\n\n    rng.shuffle(all_records)\n    n_eval = max(1, int(len(all_records) * eval_ratio))\n    eval_records = all_records[:n_eval]\n    train_records = all_records[n_eval:]\n\n    # ── Inject task-aware system prompts (V3 prompt engineering) ─────────────\n    for i, msgs in enumerate(train_records):\n        user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n        task = _classify_task_type(user_text)\n        train_records[i] = inject_task_system_prompt(msgs, task)\n    for i, msgs in enumerate(eval_records):\n        user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n        task = _classify_task_type(user_text)\n        eval_records[i] = inject_task_system_prompt(msgs, task)\n    print(\"  βœ“ Task-aware system prompts injected\")\n\n    # Log task distribution\n    for label, records in [(\"train\", train_records), (\"eval\", eval_records)]:\n        dist = {}\n        for msgs in records:\n            user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n            task = _classify_task_type(user_text)\n            dist[task] = dist.get(task, 0) + 1\n        print(f\"  {label}: {len(records)} prompts β€” {dist}\")\n\n    train_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in train_records])\n    eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n    return train_ds, eval_ds\n\ntrain_dataset, eval_dataset = prepare_datasets(TRAIN_FILE)\nprint(f\"\\nβœ“ Datasets: train={len(train_dataset)}, eval={len(eval_dataset)}\")"
  },
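  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### Cell 10b (optional): Inspect one prepared record\n\nA quick structural peek (sketch only): GRPOTrainer with conversational data expects each record's `prompt` to be a list of `{role, content}` messages. This prints the first record so you can eyeball that the task-aware system prompt landed."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Peek at the first training record β€” structure check only, no generation.\n_sample = train_dataset[0][\"prompt\"]\nprint(f\"messages: {len(_sample)}\")\nfor _m in _sample:\n    print(f\"  [{_m['role']:6s}] {_m['content'][:110]}\")"
  },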
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 11: Smoke Test (1 Step)\n\n**Gate:** No OOM. Peak VRAM < 20GB. Step time < 180s."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from trl import GRPOConfig, GRPOTrainer\n\nFastLanguageModel.for_training(model)\n\nsmoke_config = GRPOConfig(\n    output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n    num_generations=NUM_GENERATIONS,\n    scale_rewards=SCALE_REWARDS,\n    max_completion_length=MAX_COMPLETION_LENGTH,\n    max_steps=1,\n    temperature=TEMPERATURE,\n    beta=BETA,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=1,\n    learning_rate=LEARNING_RATE,\n    fp16=False,\n    bf16=True,\n    logging_steps=1,\n    save_steps=999,\n    report_to=\"none\",\n    max_prompt_length=MAX_SEQ_LENGTH // 2,\n    seed=42,\n    remove_unused_columns=False,\n)\n\n\nclass UnslothGRPOTrainer(GRPOTrainer):\n    \"\"\"Wraps generation with Unsloth for_inference()/for_training().\"\"\"\n    def _generate(self, prompts, images):\n        FastLanguageModel.for_inference(self.model)\n        try:\n            result = super()._generate(prompts, images)\n        finally:\n            FastLanguageModel.for_training(self.model)\n        return result\n\n\nsmoke_trainer = UnslothGRPOTrainer(\n    model=model,\n    reward_funcs=commerce_reward_fn,\n    args=smoke_config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n)\n\nt0 = time.time()\nsmoke_trainer.train()\nstep_time = time.time() - t0\n\npeak_vram = torch.cuda.max_memory_allocated() / 1e9\nprint(f\"\\nβœ“ Smoke test passed!\")\nprint(f\"  Step time: {step_time:.0f}s\")\nprint(f\"  Peak VRAM: {peak_vram:.1f}GB / {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB\")\nprint(f\"  Estimated full run ({MAX_STEPS} steps): {step_time * MAX_STEPS / 3600:.1f}h\")\n\ndel smoke_trainer\ngc.collect(); torch.cuda.empty_cache()"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 12: Probe Run (10 Steps) β€” THE CRITICAL GATE\n\n**Gate:** `nonzero_clips >= 3`. If this fails, go to ADR-002 Section 8 (Fallback Plan).\n\nThis is the gate V3 missed. If clip_ratio = 0 on all steps, the policy is not learning.\nDo NOT proceed to full training without passing this gate."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "FastLanguageModel.for_training(model)\n\nprobe_config = GRPOConfig(\n    output_dir=str(CHECKPOINT_DIR / \"probe\"),\n    num_generations=NUM_GENERATIONS,\n    scale_rewards=SCALE_REWARDS,\n    max_completion_length=MAX_COMPLETION_LENGTH,\n    max_steps=10,\n    temperature=TEMPERATURE,\n    beta=BETA,\n    num_train_epochs=1,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=GRAD_ACCUM,\n    learning_rate=LEARNING_RATE,\n    warmup_ratio=0.1,\n    lr_scheduler_type=\"cosine\",\n    fp16=False,\n    bf16=True,\n    logging_steps=1,\n    save_steps=999,\n    report_to=\"none\",\n    max_prompt_length=MAX_SEQ_LENGTH // 2,\n    seed=42,\n    remove_unused_columns=False,\n)\n\nprobe_trainer = UnslothGRPOTrainer(\n    model=model,\n    reward_funcs=commerce_reward_fn,\n    args=probe_config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n)\n\nt0 = time.time()\nresult = probe_trainer.train()\nelapsed = time.time() - t0\n\n# ══════════════════════════════════════════════════════════════════════════════\n# CRITICAL GATE: clip_ratio > 0 on at least 3 of 10 steps\n# If this fails, STOP. See Fallback Plan (Section 8 of ADR-002).\n# ══════════════════════════════════════════════════════════════════════════════\nclip_ratios = []\nfor entry in probe_trainer.state.log_history:\n    if \"train/clip_ratio\" in entry:\n        clip_ratios.append(entry[\"train/clip_ratio\"])\n\nnonzero_clips = sum(1 for cr in clip_ratios if cr > 0.0)\nprint(f\"\\n{'='*60}\")\nprint(f\"PROBE RESULTS ({elapsed:.0f}s, {elapsed/10:.0f}s/step)\")\nprint(f\"  clip_ratios: {[f'{cr:.4f}' for cr in clip_ratios]}\")\nprint(f\"  Non-zero clip steps: {nonzero_clips}/{len(clip_ratios)}\")\nprint(f\"  Train loss: {result.training_loss:.4f}\")\nprint(f\"{'='*60}\")\n\nif nonzero_clips >= 3:\n    print(\"βœ“ PROBE GATE PASSED β€” proceed to full training\")\nelif nonzero_clips > 0:\n    print(\"⚠️ MARGINAL β€” clip_ratio > 0 on some steps but < 3. Consider increasing LR or G.\")\nelse:\n    print(\"βœ— PROBE GATE FAILED β€” clip_ratio = 0 on ALL steps.\")\n    print(\"  DO NOT proceed to full training.\")\n    print(\"  See ADR-002 Section 8 (Fallback Plan).\")\n\ndel probe_trainer\ngc.collect(); torch.cuda.empty_cache()"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 13: W&B Init + Full Training\n\n**Only run this cell if the probe gate in Cell 12 passed.**"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import wandb\nfrom transformers import TrainerCallback\n\nwandb.login()\nwandb.init(\n    project=WANDB_PROJECT,\n    name=f\"grpo-v4-instruct-0.5B-{time.strftime('%Y%m%d-%H%M')}\",\n    config={\n        \"model_id\": MODEL_ID,\n        \"version\": \"v4\",\n        \"num_generations\": NUM_GENERATIONS,\n        \"max_completion_length\": MAX_COMPLETION_LENGTH,\n        \"temperature\": TEMPERATURE,\n        \"learning_rate\": LEARNING_RATE,\n        \"beta\": BETA,\n        \"scale_rewards\": SCALE_REWARDS,\n        \"batch_size\": BATCH_SIZE,\n        \"grad_accum\": GRAD_ACCUM,\n        \"max_steps\": MAX_STEPS,\n        \"lora_r\": LORA_R,\n        \"lora_alpha\": LORA_ALPHA,\n        \"train_prompts\": len(train_dataset),\n        \"eval_prompts\": len(eval_dataset),\n        \"repetition_penalty_override\": 1.0,\n    },\n)\nprint(f\"βœ“ W&B run: {wandb.run.url}\")\n\n\nclass EvalRewardCallback(TrainerCallback):\n    def __init__(self, eval_records, reward_fn, patience, delta):\n        self.eval_records = eval_records\n        self.reward_fn = reward_fn\n        self.patience = patience\n        self.delta = delta\n        self.best_reward = -float(\"inf\")\n        self.best_step = 0\n        self.no_improve_count = 0\n\n    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n        if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n            return control\n\n        tokenizer_local = processing_class\n        if tokenizer_local is None:\n            print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n            return control\n\n        mean_reward = self._run_eval(model, tokenizer_local, args)\n        improved = mean_reward > self.best_reward + self.delta\n\n        wandb.log({\n            \"eval/mean_reward\": mean_reward,\n            \"eval/best_reward\": max(self.best_reward, mean_reward),\n            \"eval/no_improve_count\": self.no_improve_count,\n        }, step=state.global_step)\n\n        status = \"↑ improved\" if improved else f\"↔ no gain ({self.no_improve_count + 1}/{self.patience})\"\n        print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n\n        if improved:\n            self.best_reward = mean_reward\n            self.best_step = state.global_step\n            self.no_improve_count = 0\n        else:\n            self.no_improve_count += 1\n            if self.no_improve_count >= self.patience:\n                print(f\"[EarlyStopping] No improvement for {self.patience} evals. 
Halting.\")\n                control.should_training_stop = True\n        return control\n\n    def _run_eval(self, model, tokenizer_local, args):\n        FastLanguageModel.for_inference(model)\n        rewards = []\n        subset = self.eval_records[:EVAL_MAX_SAMPLES]\n        for record in subset:\n            msgs = record[\"prompt\"]\n            text = tokenizer_local.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n            inputs = tokenizer_local(text, return_tensors=\"pt\", truncation=True, max_length=args.max_prompt_length).to(model.device)\n            with torch.no_grad():\n                out = model.generate(\n                    **inputs,\n                    max_new_tokens=EVAL_MAX_TOKENS,\n                    temperature=0.1,   # deterministic eval\n                    do_sample=True,\n                    repetition_penalty=1.0,\n                )\n            resp = tokenizer_local.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n            rewards.append(self.reward_fn([resp], [text])[0])\n        FastLanguageModel.for_training(model)\n        return sum(rewards) / len(rewards) if rewards else 0.0\n\n\n# ── Training ────────────────────────────────────────────────────────────────\nFastLanguageModel.for_training(model)\n\ngrpo_config = GRPOConfig(\n    output_dir=str(CHECKPOINT_DIR),\n    num_generations=NUM_GENERATIONS,\n    scale_rewards=SCALE_REWARDS,\n    max_completion_length=MAX_COMPLETION_LENGTH,\n    max_steps=MAX_STEPS,\n    temperature=TEMPERATURE,\n    beta=BETA,\n    num_train_epochs=1,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=GRAD_ACCUM,\n    learning_rate=LEARNING_RATE,\n    warmup_ratio=0.1,\n    lr_scheduler_type=\"cosine\",\n    fp16=False,\n    bf16=True,\n    logging_steps=1,\n    save_steps=SAVE_STEPS,\n    save_total_limit=5,\n    save_only_model=True,\n    report_to=\"wandb\",\n    max_prompt_length=MAX_SEQ_LENGTH // 2,\n    seed=42,\n    remove_unused_columns=False,\n    disable_tqdm=True,\n    logging_first_step=True,\n)\n\neval_cb = EvalRewardCallback(\n    eval_records=list(eval_dataset),\n    reward_fn=commerce_reward_fn,\n    patience=EARLY_STOPPING_PATIENCE,\n    delta=EARLY_STOPPING_DELTA,\n)\n\ntrainer = UnslothGRPOTrainer(\n    model=model,\n    reward_funcs=commerce_reward_fn,\n    args=grpo_config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n    callbacks=[eval_cb],\n)\n\nt_start = time.time()\nresult = trainer.train()\nelapsed = time.time() - t_start\n\nwandb.log({\n    \"train/final_loss\": result.training_loss,\n    \"train/duration_hours\": elapsed / 3600,\n    \"train/total_steps\": result.global_step,\n    \"eval/best_reward_final\": eval_cb.best_reward,\n    \"eval/best_step\": eval_cb.best_step,\n})\nwandb.finish()\n\nprint(f\"\\n{'='*60}\")\nprint(f\"V4 Training Complete\")\nprint(f\"  Loss:        {result.training_loss:.4f}\")\nprint(f\"  Steps:       {result.global_step}\")\nprint(f\"  Duration:    {elapsed/3600:.1f}h\")\nprint(f\"  Best eval:   {eval_cb.best_reward:.4f} (step {eval_cb.best_step})\")\nprint(f\"{'='*60}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 14: Validation (20 Held-Out Samples)\n\nRun validation on held-out samples, broken down by task type.\n\n**Success criteria (ADR-002 Β§7):**\n- Extraction mean reward β‰₯ 0.30\n- Push mean reward β‰₯ 0.40\n- SQL Q&A mean reward β‰₯ 0.20\n- Insights mean reward β‰₯ 0.20"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "FastLanguageModel.for_inference(model)\n\nval_samples = list(eval_dataset)[:20]\nval_results = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n\nfor i, record in enumerate(val_samples):\n    msgs = record[\"prompt\"]\n    user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n    task = _classify_task_type(user_text)\n\n    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n    inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n    with torch.no_grad():\n        out = model.generate(\n            **inputs,\n            max_new_tokens=MAX_COMPLETION_LENGTH,\n            temperature=0.1,\n            do_sample=True,\n            repetition_penalty=1.0,\n        )\n    resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n    r = commerce_reward_fn([resp], [text])[0]\n    val_results[task].append(r)\n    print(f\"  [{task:12s}] reward={r:.2f} | {strip_think(resp)[:80]}\")\n\nprint(f\"\\n{'='*60}\")\nprint(\"Validation Results by Task:\")\nfor task, rewards in val_results.items():\n    if rewards:\n        mean_r = sum(rewards) / len(rewards)\n        print(f\"  {task:12s}: mean={mean_r:.3f} (n={len(rewards)})\")\nprint(f\"{'='*60}\")"
  },
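  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "### Cell 14b (optional): Check the ADR-002 Β§7 success criteria\n\nThe thresholds below are copied from the Cell 14 markdown; this just applies them to `val_results` so the pass/fail verdict is explicit rather than read off by eye."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Apply the per-task success criteria from ADR-002 Β§7 (see Cell 14 markdown above).\n_criteria = {\"extraction\": 0.30, \"push\": 0.40, \"sql_qa\": 0.20, \"insights\": 0.20}\n_all_ok = True\nfor _task, _threshold in _criteria.items():\n    _rs = val_results.get(_task, [])\n    if not _rs:\n        print(f\"  ? {_task:12s}: no samples in this 20-sample slice\")\n        continue\n    _mean = sum(_rs) / len(_rs)\n    _ok = _mean >= _threshold\n    _all_ok = _all_ok and _ok\n    print(f\"  {'βœ“' if _ok else 'βœ—'} {_task:12s}: mean={_mean:.3f} (threshold {_threshold:.2f})\")\nprint(\"\\nβœ“ All criteria met\" if _all_ok else \"\\nβœ— Some criteria missed β€” revisit before promoting V4\")"
  },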
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 15: Save Adapter"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\nmodel.save_pretrained(str(ADAPTER_DIR))\ntokenizer.save_pretrained(str(ADAPTER_DIR))\nprint(f\"βœ“ Adapter saved to {ADAPTER_DIR}\")"
  }
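,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 16 (optional): Reload Adapter in a Fresh Session\n\nA minimal sketch of how a later session could reload the saved adapter for inference. It assumes the Cell 3 constants (`MODEL_ID`, `MAX_SEQ_LENGTH`, `ADAPTER_DIR`) are re-declared; `PeftModel.from_pretrained` mirrors how the adapter was created in Cell 4, avoiding Unsloth's fused-LoRA path."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Sketch: reload base model + V4 adapter in a new session (not needed in this run).\nfrom unsloth import FastLanguageModel\nfrom peft import PeftModel\n\n_base, _tok = FastLanguageModel.from_pretrained(\n    model_name=MODEL_ID,\n    max_seq_length=MAX_SEQ_LENGTH,\n    load_in_4bit=True,\n)\n_model = PeftModel.from_pretrained(_base, str(ADAPTER_DIR))\nFastLanguageModel.for_inference(_model)\n\n# Re-apply the Cell 4 generation overrides β€” the shipped defaults\n# (temperature=0.1, repetition_penalty=1.2, use_cache=false) return on a fresh load.\n_model.config.use_cache = True\n_model.generation_config.repetition_penalty = 1.0\nprint(f\"βœ“ Adapter reloaded from {ADAPTER_DIR}\")"
  }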
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}