Add V4 Instruct-Only GRPO notebook implementing ADR-002
Single-model validation on Tucano2-qwen-0.5B-Instruct with all four
task types. Key changes from V3: no <think> overhead, G=16, 512-token
completions, continuous reward design, hard clip_ratio probe gate,
and all generation_config bug fixes (rep_penalty, top_k, use_cache).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
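The generation_config fixes named above amount to forcing GRPO-safe sampling settings over the checkpoint's shipped defaults. A minimal sketch of the before/after values (the "shipped" numbers are assumptions taken from the notebook's own comments, not verified against the checkpoint):

```python
# Hypothetical sketch of the generation_config overrides described in the
# commit message. Shipped defaults are assumptions from the notebook's comments.
shipped_defaults = {
    "temperature": 0.1,         # too greedy for GRPO exploration
    "repetition_penalty": 1.2,  # suppresses completion diversity
    "top_k": 50,
    "use_cache": False,         # cripples generation speed
}
grpo_safe = {
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "top_k": 0,                 # let temperature alone control diversity
    "top_p": 1.0,
    "use_cache": True,
    "do_sample": True,
}
# Merging applies the overrides: the later dict wins on conflicting keys.
effective = {**shipped_defaults, **grpo_safe}
print(effective["repetition_penalty"], effective["top_k"], effective["use_cache"])
```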
- notebooks/v4_instruct_grpo.ipynb +202 -0
notebooks/v4_instruct_grpo.ipynb
ADDED
@@ -0,0 +1,202 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "# Tucano2 Commerce – GRPO Training V4 (Instruct-Only, 0.5B)\n\n**Reference:** `docs/ADR-002-v4-instruct.md` – Single-Model Validation on 0.5B-Instruct\n\n**V4 key changes over V3:**\n\n| Aspect | V3 | V4 | Why |\n|--------|----|----|-----|\n| Model | 3.7B-Think | **0.5B-Instruct** | Think scored 14.41 NPM vs Instruct's 26.08; clip_ratio=0 on all V3 steps |\n| `<think>` overhead | 2000-3000 tok | **None** | Instruct template has no `<think>` injection |\n| Num generations | 4 | **16** | 0.5B is 8× lighter; more G = more reward variance |\n| Completion length | 4096 | **512** | No think overhead; extraction ~100, SQL ~200, insights ~300 tok |\n| Sequence length | 8192 | **2048** | Shorter completions need less context |\n| Reward design | Staged (format→partial→task) | **Continuous** | Simpler; Instruct doesn't need think bonus |\n| Probe gate | 3 steps, no hard gate | **10 steps, clip_ratio ≥ 3/10** | V3's missed signal → hard gate prevents wasted compute |\n| generation_config | Override temp, use_cache | **Override 6 fields** | rep_penalty=1.2 and top_k=50 also destroy GRPO |\n\n**Prerequisites:**\n- Upload `data/pairs/train.jsonl` to `./data/pairs/`\n- Hardware: L4 (24GB), PyTorch kernel, bf16 supported"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 1: Dependencies\n\nRestart kernel first (Kernel → Restart), then run:"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "# Cell 1 – Clean install\n# Run after kernel restart\n\n!pip install \"unsloth\"\n!pip install \"trl==0.24.0\" --no-deps\n!pip install \"rich\" \"wandb\""
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 2: GPU + Unsloth Verification\n\n**Gate:** CUDA available, bf16=True, VRAM > 20GB, TRL 0.24.0."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "import torch\n\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nprint(f\"GPU: {torch.cuda.get_device_name(0)}\")\nprint(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\nprint(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n\nfrom unsloth import FastLanguageModel\nprint(\"\\n✓ Unsloth loaded\")\n\nimport trl\nassert trl.__version__ == \"0.24.0\", f\"Expected TRL 0.24.0, got {trl.__version__}\"\nprint(f\"✓ TRL {trl.__version__}\")\n\nimport transformers\nprint(f\"✓ Transformers {transformers.__version__}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 3: Config Constants\n\nAll hyperparameters from ADR-002 §9. Every value is annotated with its rationale."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "import os\nimport json\nimport re\nimport time\nimport random\nimport gc\nfrom pathlib import Path\n\n# ── Disable Unsloth kernel recompilation ─────────────────────────────────\nos.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n\n# ── Model ────────────────────────────────────────────────────────────────\nMODEL_ID = \"Polygl0t/Tucano2-qwen-0.5B-Instruct\"\nMAX_SEQ_LENGTH = 2048  # model supports 4096, but 2048 is plenty for Instruct (no <think> overhead)\nADAPTER_DIR = Path(\"models/tucano2-0.5B-instruct-grpo-v4\")\nCHECKPOINT_DIR = ADAPTER_DIR / \"checkpoints\"\n\n# ── Data ─────────────────────────────────────────────────────────────────\nDATA_DIR = Path(\"data/pairs\")\nTRAIN_FILE = DATA_DIR / \"train.jsonl\"\nEVAL_SPLIT = 0.10  # 10% held out for eval\n\n# ── GRPO Hyperparameters ─────────────────────────────────────────────────\nNUM_GENERATIONS = 16  # 0.5B + short completions = VRAM allows G=16\nMAX_COMPLETION_LENGTH = 512  # Instruct: no <think> overhead. Extraction ~100, SQL ~200, insights ~300\nTEMPERATURE = 1.0  # Skywork-OR1: τ=1.0 for exploration\nLEARNING_RATE = 2e-6  # Dr. GRPO: 4× V2's 5e-7 (clip_ratio=0 → push harder)\nBETA = 0.0  # Dr. GRPO §3.2: β=0 optimal for rule-based rewards\nSCALE_REWARDS = False  # Dr. GRPO: remove std normalization bias\nBATCH_SIZE = 2  # per-device batch size\nGRAD_ACCUM = 1  # effective batch = 2 * 1 = 2 prompts * 16 gen = 32 completions\nMAX_STEPS = 200  # validation run\nSAVE_STEPS = 20\nEVAL_STEPS = 10\nEARLY_STOPPING_PATIENCE = 15\nEARLY_STOPPING_DELTA = 0.005\n\n# ── LoRA ─────────────────────────────────────────────────────────────────\nLORA_R = 16\nLORA_ALPHA = 32\n\n# ── Monitoring ───────────────────────────────────────────────────────────\nWANDB_PROJECT = \"tucano2-commerce\"\nEVAL_MAX_SAMPLES = 15  # eval callback samples\nEVAL_MAX_TOKENS = 512  # match training completion length\n\n# ── Task Classification (inherited from V2/V3) ───────────────────────────\nVALID_SENTIMENTS = {\"positive\", \"negative\", \"neutral\"}\nVALID_CATEGORIES = {\n    \"delivery_delay\", \"product_quality\", \"product_not_received\",\n    \"wrong_product\", \"seller_communication\", \"app_issue\",\n    \"price_value\", \"other\", \"none\",\n}\nVALID_CHURN = {\"low\", \"medium\", \"high\"}\nVALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\nEXTRACTION_FIELDS = [\n    \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n    \"product_issue\", \"seller_issue\", \"main_complaint\",\n    \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n]\n\n# ── Verified Special Token IDs (from tokenizer_config.json) ──────────────\n# These are constants – do NOT recompute via tokenizer.encode()\nTOKEN_ID_BOS = 1  # <|im_start|>\nTOKEN_ID_EOS = 2  # <|im_end|>\nTOKEN_ID_PAD = 49109  # <|pad|>\nTOKEN_ID_THINK = 49116  # <think>\nTOKEN_ID_THINK_END = 49117  # </think>\n\nprint(\"✓ Config loaded\")\nprint(f\"  Model: {MODEL_ID}\")\nprint(f\"  G={NUM_GENERATIONS}, max_comp={MAX_COMPLETION_LENGTH}, temp={TEMPERATURE}\")\nprint(f\"  LR={LEARNING_RATE}, β={BETA}, scale_rewards={SCALE_REWARDS}\")\nprint(f\"  LoRA r={LORA_R}, α={LORA_ALPHA}\")\nprint(f\"  Max steps: {MAX_STEPS}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 4: Load Model + Apply Critical Overrides\n\n**Gate:** Model loaded, `use_cache=True`, `repetition_penalty=1.0`, `temperature=1.0`.\n\n**Bugs fixed here (ADR-002 §1.2):**\n- Bug 2a: `use_cache: false` in config.json → override to True\n- Bug 2b: `repetition_penalty: 1.2` in generation_config → override to 1.0\n- Bug 2c: `temperature: 0.1` in generation_config → override to 1.0\n- Bug 2d: Tied embeddings check after LoRA patching"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name=MODEL_ID,\n    max_seq_length=MAX_SEQ_LENGTH,\n    load_in_4bit=True,\n    dtype=None,  # auto-detect\n)\n\n# ──────────────────────────────────────────────────────────────────────────\n# LoRA ADAPTER – ADR-002 §9: r=16, α=32\n# ──────────────────────────────────────────────────────────────────────────\nmodel = FastLanguageModel.get_peft_model(\n    model,\n    r=LORA_R,\n    lora_alpha=LORA_ALPHA,\n    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n                    \"gate_proj\", \"up_proj\", \"down_proj\"],\n    lora_dropout=0,\n    bias=\"none\",\n    use_gradient_checkpointing=\"unsloth\",\n    random_state=42,\n)\n\n# ──────────────────────────────────────────────────────────────────────────\n# CRITICAL OVERRIDES – generation_config ships with values that destroy GRPO\n# Source: Polygl0t/Tucano2-qwen-0.5B-Instruct/generation_config.json\n#   temperature: 0.1        → override to 1.0\n#   repetition_penalty: 1.2 → override to 1.0\n#   use_cache: false        → override to true\n# ──────────────────────────────────────────────────────────────────────────\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0  # CRITICAL: 1.2 suppresses diversity\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0  # disable top-k → let temperature control diversity\nmodel.generation_config.top_p = 1.0  # disable top-p\n\n# Pad token\nif tokenizer.pad_token is None:\n    tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"✓ Model loaded on {model.device}\")\nprint(f\"  use_cache: {model.config.use_cache}\")\nprint(f\"  temperature: {model.generation_config.temperature}\")\nprint(f\"  repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\"  top_k: {model.generation_config.top_k}\")\nprint(f\"  Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\n# ──────────────────────────────────────────────────────────────────────────\n# TIED EMBEDDINGS CHECK – ADR-002 Decision 4\n# Source: config.json has \"tie_word_embeddings\": true\n# After LoRA patching, verify lm_head and embed_tokens still share weights.\n# ──────────────────────────────────────────────────────────────────────────\n\ntry:\n    lm_ptr = model.lm_head.weight.data_ptr()\n    embed_ptr = model.model.embed_tokens.weight.data_ptr()\n    tied = lm_ptr == embed_ptr\n    print(f\"  Tied embeddings intact: {tied}\")\n    if not tied:\n        print(\"  ⚠️ WARNING: Tied embeddings broken after LoRA patching. May affect output head gradients.\")\nexcept AttributeError as e:\n    print(f\"  ⚠️ Could not check tied embeddings: {e}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 5: Token ID Verification\n\n**Gate:** All token IDs match. Single-token `<think>` (49116) and `</think>` (49117) confirmed."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "# Verify that the constants from Cell 3 match the actual tokenizer\n# Do NOT skip this cell – if IDs don't match, all reward functions break\n\ntok_tests = {\n    \"<|im_start|>\": TOKEN_ID_BOS,\n    \"<|im_end|>\": TOKEN_ID_EOS,\n    \"<|pad|>\": TOKEN_ID_PAD,\n    \"<think>\": TOKEN_ID_THINK,\n    \"</think>\": TOKEN_ID_THINK_END,\n}\n\nall_pass = True\nfor text, expected_id in tok_tests.items():\n    ids = tokenizer.encode(text, add_special_tokens=False)\n    actual_id = ids[0] if len(ids) == 1 else ids\n    match = (len(ids) == 1 and ids[0] == expected_id)\n    status = \"✓\" if match else \"✗\"\n    print(f\"  {status} '{text}' → expected {expected_id}, got {actual_id}\")\n    if not match:\n        all_pass = False\n\nassert all_pass, \"Token ID mismatch detected. Update constants in Cell 3 before proceeding.\"\nprint(\"\\n✓ All token IDs verified\")\n\nassert tokenizer.eos_token_id == TOKEN_ID_EOS, f\"eos_token_id mismatch: {tokenizer.eos_token_id}\"\nprint(f\"✓ eos_token_id = {tokenizer.eos_token_id}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 6: KV Cache Diagnostic\n\n**Gate:** Ratio < 3× → KV cache OK. Ratio > 5× → BROKEN, abort."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "# Copied from V2 Cell 5b – verify KV cache is working\n\nFastLanguageModel.for_inference(model)\n\n_kv_msgs = [{\"role\": \"user\", \"content\": \"Qual a categoria de reclamação mais frequente?\"}]\n_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n\n_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\nwith torch.no_grad():\n    for _step in range(50):\n        _t0 = time.time()\n        seq_len = _generated.shape[1]\n        if _past is None:\n            _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n        else:\n            _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n        _out = model(\n            input_ids=_generated[:, -1:] if _past is not None else _generated,\n            position_ids=_position_ids,\n            attention_mask=torch.ones(1, seq_len, device=model.device),\n            past_key_values=_past,\n            use_cache=True,\n            return_dict=True,\n        )\n        _past = _out.past_key_values\n        _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n        _generated = torch.cat([_generated, _next], dim=1)\n        _token_times.append(time.time() - _t0)\n\n_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\nprint(f\"First 5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\nprint(f\"Last 5 tok:  {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\nprint(f\"Ratio last/first: {_ratio:.1f}x\")\nassert _ratio < 5, f\"KV cache BROKEN (ratio {_ratio:.1f}×). Check model.config.use_cache.\"\nprint(\"✓ KV cache working correctly\")\n\ndel _past, _generated, _kv_inputs, _token_times, _out\ngc.collect()\ntorch.cuda.empty_cache()"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 7: Single Inference Test\n\n**Gate:** Response is coherent Portuguese. Check whether `<think>` appears and whether a JSON structure is present."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "FastLanguageModel.for_inference(model)\n\ntest_msgs = [\n    {\"role\": \"system\", \"content\": \"Você é um assistente de IA especializado em e-commerce brasileiro.\"},\n    {\"role\": \"user\", \"content\": \"Analise esta avaliação: 'Produto chegou quebrado, péssima embalagem. Nunca mais compro aqui.' Retorne um objeto JSON com os campos: sentiment, sentiment_score, delivery_issue, complaint_category.\"},\n]\ntext = tokenizer.apply_chat_template(test_msgs, tokenize=False, add_generation_prompt=True)\ninputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n\nt0 = time.time()\noutputs = model.generate(\n    **inputs,\n    max_new_tokens=256,\n    temperature=0.1,  # low temp for deterministic eval\n    do_sample=True,\n    repetition_penalty=1.0,\n)\nelapsed = time.time() - t0\n\nresponse = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\nprint(f\"Generation time: {elapsed:.1f}s\")\nprint(f\"Response length: {len(response)} chars\")\nprint(f\"Contains <think>: {'<think>' in response}\")\nprint(f\"Contains JSON {{ }}: {'{' in response and '}' in response}\")\nprint(f\"\\n{'='*60}\")\nprint(response[:500])"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 8: Reward Functions\n\nComplete reward functions from ADR-002 §6. All functions call `strip_think()` defensively,\neven for the Instruct model (which shouldn't generate `<think>` but might do so spontaneously)."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "def strip_think(text: str) -> str:\n    \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n    return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n\n\ndef has_think_block(text: str) -> bool:\n    return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n\n\ndef _classify_task_type(prompt_text: str) -> str:\n    p = prompt_text.lower()\n    if \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n        return \"extraction\"\n    elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n        return \"push\"\n    elif \"perfil do cliente\" in p or \"retenção\" in p or \"análise\" in p or \"insight\" in p:\n        return \"insights\"\n    else:\n        return \"sql_qa\"\n\n\ndef _extract_json(text: str) -> dict | None:\n    \"\"\"Extract first JSON object from text. Returns parsed dict or None.\"\"\"\n    stripped = text.strip()\n    stripped = re.sub(r\"^```(?:json)?\\s*\", \"\", stripped)\n    stripped = re.sub(r\"\\s*```$\", \"\", stripped)\n    stripped = stripped.strip()\n    try:\n        return json.loads(stripped)\n    except (json.JSONDecodeError, TypeError):\n        pass\n    match = re.search(r\"\\{[^{}]*(?:\\{[^{}]*\\}[^{}]*)*\\}\", text, re.DOTALL)\n    if match:\n        try:\n            return json.loads(match.group())\n        except (json.JSONDecodeError, TypeError):\n            pass\n    return None\n\n\ndef reward_extraction(completion: str) -> float:\n    \"\"\"Continuous reward for extraction tasks (max 1.0).\"\"\"\n    answer = strip_think(completion)\n    data = _extract_json(answer)\n\n    if data is None:\n        if \"{\" in answer and \"}\" in answer:\n            return 0.05\n        return 0.0\n\n    if not isinstance(data, dict):\n        return 0.1  # valid JSON but not an object\n\n    score = 0.3  # valid JSON object\n\n    # Schema completeness (0.3 total)\n    present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n    score += 0.3 * (present / len(EXTRACTION_FIELDS))\n\n    # Value validity (0.4 total, split across checks)\n    checks_passed = 0\n    checks_total = 0\n\n    for field, validator in [\n        (\"sentiment\", lambda v: v in VALID_SENTIMENTS),\n        (\"complaint_category\", lambda v: v in VALID_CATEGORIES),\n        (\"churn_risk\", lambda v: v in VALID_CHURN),\n        (\"repeat_intent\", lambda v: v in VALID_REPEAT),\n        (\"sentiment_score\", lambda v: isinstance(v, (int, float)) and 1 <= v <= 5),\n    ]:\n        checks_total += 1\n        if field in data and validator(data[field]):\n            checks_passed += 1\n\n    for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n        checks_total += 1\n        if bool_field in data and isinstance(data[bool_field], bool):\n            checks_passed += 1\n\n    if checks_total > 0:\n        score += 0.4 * (checks_passed / checks_total)\n\n    return min(score, 1.0)\n\n\ndef reward_sql_qa(completion: str) -> float:\n    \"\"\"Continuous reward for SQL Q&A (max 1.0).\"\"\"\n    answer = strip_think(completion)\n    if not answer.strip():\n        return 0.0\n\n    score = 0.0\n\n    # Numerical content (more numbers = more specific answer)\n    numbers = re.findall(r\"\\d+(?:[.,]\\d+)?\", answer)\n    score += min(0.4, 0.1 * len(numbers))\n\n    # Length: 50-500 chars optimal\n    length = len(answer)\n    if 50 <= length <= 500:\n        score += 0.3\n    elif length > 0:\n        score += 0.3 * max(0, 1 - abs(length - 275) / 275)\n\n    # Portuguese business vocabulary\n    pt_business = [\"pedidos\", \"clientes\", \"média\", \"total\", \"taxa\", \"vendas\",\n                   \"produtos\", \"período\", \"categoria\", \"região\", \"faturamento\"]\n    pt_matches = sum(1 for w in pt_business if w in answer.lower())\n    score += min(0.3, 0.06 * pt_matches)\n\n    return min(score, 1.0)\n\n\ndef reward_insights(completion: str) -> float:\n    \"\"\"Continuous reward for insights (max 1.0).\"\"\"\n    answer = strip_think(completion)\n    if not answer.strip():\n        return 0.0\n\n    score = 0.0\n\n    # Actionable language\n    action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n                    \"priorizar\", \"investir\", \"otimizar\", \"estratégi\", \"ação\"]\n    matches = sum(1 for w in action_words if w in answer.lower())\n    score += min(0.4, 0.08 * matches)\n\n    # Length: 100-800 chars optimal\n    length = len(answer)\n    if 100 <= length <= 800:\n        score += 0.3\n    elif length > 0:\n        score += 0.3 * max(0, 1 - abs(length - 450) / 450)\n\n    # Structure: bullet points, numbered lists, headers\n    structure_marks = len(re.findall(r\"^[-•*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n    score += min(0.2, 0.04 * structure_marks)\n\n    # Portuguese coherence marker\n    if any(w in answer.lower() for w in [\"cliente\", \"produto\", \"serviço\", \"empresa\"]):\n        score += 0.1\n\n    return min(score, 1.0)\n\n\ndef reward_push(completion: str) -> float:\n    \"\"\"Continuous reward for push notifications (max 1.0).\"\"\"\n    answer = strip_think(completion).strip()\n    if not answer:\n        return 0.0\n\n    # Length: ≤120 chars gets full credit\n    length = len(answer)\n    if length <= 120:\n        length_score = 0.5\n    else:\n        length_score = 0.5 * max(0, 1 - (length - 120) / 120)\n\n    # Portuguese content\n    pt_markers = re.findall(r\"[ãçéêóúâõ]|você|para|como|seu|sua|oferta|desconto|produto\",\n                            answer, re.IGNORECASE)\n    lang_score = min(0.3, 0.03 * len(pt_markers))\n\n    # Non-generic (penalize very generic phrases)\n    generic = [\"olá\", \"obrigado pela compra\", \"agradecemos\"]\n    is_generic = any(g in answer.lower() for g in generic)\n    creativity_score = 0.0 if is_generic else 0.2\n\n    return min(length_score + lang_score + creativity_score, 1.0)\n\n\ndef commerce_reward_fn(completions, prompts, **kwargs) -> list[float]:\n    \"\"\"Master reward function: dispatches by task type.\"\"\"\n    rewards = []\n    for completion, prompt in zip(completions, prompts):\n        if isinstance(completion, list):\n            comp_text = completion[-1][\"content\"] if completion else \"\"\n        else:\n            comp_text = str(completion)\n\n        if isinstance(prompt, list):\n            prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n        else:\n            prompt_text = str(prompt)\n\n        task = _classify_task_type(prompt_text)\n\n        if task == \"extraction\":\n            rewards.append(reward_extraction(comp_text))\n        elif task == \"sql_qa\":\n            rewards.append(reward_sql_qa(comp_text))\n        elif task == \"insights\":\n            rewards.append(reward_insights(comp_text))\n        elif task == \"push\":\n            rewards.append(reward_push(comp_text))\n        else:\n            r = 0.2 if comp_text.strip() else 0.0\n            rewards.append(r)\n\n    return rewards\n\n\nprint(\"✓ Reward functions defined\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 9: Reward Calibration\n\n**Gate:** Mean reward < 0.90 (if already ~1.0, reward is too easy). Variance > 0. Document whether `<think>` appeared."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "# Load data, classify by task type, run calibration on 8 diverse samples\n\nby_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\nwith open(TRAIN_FILE) as f:\n    for line in f:\n        row = json.loads(line)\n        convs = row[\"conversations\"]\n        prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n        if not prompt_msgs:\n            continue\n        user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n        task = _classify_task_type(user_text)\n        by_type[task].append(prompt_msgs)\n\nprint(f\"Prompts by type: {', '.join(f'{k}={len(v)}' for k, v in by_type.items())}\")\n\n# Pick 2 samples per task type = 8 total\nrng = random.Random(42)\ncal_samples = []\nfor task_type in by_type:\n    pool = by_type[task_type]\n    if len(pool) >= 2:\n        cal_samples.extend(rng.sample(pool, 2))\n    elif pool:\n        cal_samples.extend(pool)\n\nFastLanguageModel.for_inference(model)\nprint(f\"\\nReward calibration ({len(cal_samples)} samples):\")\nprint(\"-\" * 60)\n\ncal_rewards = []\nfor i, msgs in enumerate(cal_samples):\n    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n    inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n    outputs = model.generate(\n        **inputs,\n        max_new_tokens=MAX_COMPLETION_LENGTH,\n        temperature=0.7,\n        do_sample=True,\n        repetition_penalty=1.0,\n    )\n    response = tokenizer.decode(outputs[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n    r = commerce_reward_fn([response], [text])[0]\n    cal_rewards.append(r)\n    task = _classify_task_type(\" \".join(m.get(\"content\", \"\") for m in msgs if m[\"role\"] == \"user\"))\n    has_think = \"<think>\" in response\n    answer_preview = strip_think(response)[:100]\n    print(f\"  Sample {i+1} [{task:12s}]: reward={r:.2f} | has_think={has_think} | {answer_preview}\")\n\nprint(f\"\\nMean={sum(cal_rewards)/len(cal_rewards):.2f}, Min={min(cal_rewards):.2f}, Max={max(cal_rewards):.2f}\")\nprint(f\"Reward variance > 0: {len(set(f'{r:.4f}' for r in cal_rewards)) > 1}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 10: Dataset Preparation\n\n**Gate:** Train has ~1,650 prompts, eval has ~180. All 4 task types present in both."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "from datasets import Dataset\n\ndef prepare_datasets(train_file, eval_ratio=EVAL_SPLIT, seed=42):\n    rng = random.Random(seed)\n\n    all_records = []\n    with open(train_file) as f:\n        for line in f:\n            row = json.loads(line)\n            convs = row[\"conversations\"]\n            prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n            if prompt_msgs:\n                all_records.append(prompt_msgs)\n\n    rng.shuffle(all_records)\n    n_eval = max(1, int(len(all_records) * eval_ratio))\n    eval_records = all_records[:n_eval]\n    train_records = all_records[n_eval:]\n\n    # Log task distribution\n    for label, records in [(\"train\", train_records), (\"eval\", eval_records)]:\n        dist = {}\n        for msgs in records:\n            user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n            task = _classify_task_type(user_text)\n            dist[task] = dist.get(task, 0) + 1\n        print(f\"  {label}: {len(records)} prompts → {dist}\")\n\n    train_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in train_records])\n    eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n    return train_ds, eval_ds\n\ntrain_dataset, eval_dataset = prepare_datasets(TRAIN_FILE)\nprint(f\"\\n✓ Datasets: train={len(train_dataset)}, eval={len(eval_dataset)}\")"
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": "---\n\n## Cell 11: Smoke Test (1 Step)\n\n**Gate:** No OOM. Peak VRAM < 20GB. Step time < 180s."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": "from trl import GRPOConfig, GRPOTrainer\n\nFastLanguageModel.for_training(model)\n\nsmoke_config = GRPOConfig(\n    output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n    num_generations=NUM_GENERATIONS,\n    scale_rewards=SCALE_REWARDS,\n    max_completion_length=MAX_COMPLETION_LENGTH,\n    max_steps=1,\n    temperature=TEMPERATURE,\n    beta=BETA,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=1,\n    learning_rate=LEARNING_RATE,\n    fp16=False,\n    bf16=True,\n    logging_steps=1,\n    save_steps=999,\n    report_to=\"none\",\n    max_prompt_length=MAX_SEQ_LENGTH // 2,\n    seed=42,\n    remove_unused_columns=False,\n)\n\n\nclass UnslothGRPOTrainer(GRPOTrainer):\n    \"\"\"Wraps generation with Unsloth for_inference()/for_training().\"\"\"\n    def _generate(self, prompts, images):\n        FastLanguageModel.for_inference(self.model)\n        try:\n            result = super()._generate(prompts, images)\n        finally:\n            FastLanguageModel.for_training(self.model)\n        return result\n\n\nsmoke_trainer = UnslothGRPOTrainer(\n    model=model,\n    reward_funcs=commerce_reward_fn,\n    args=smoke_config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n)\n\nt0 = time.time()\nsmoke_trainer.train()\nstep_time = time.time() - t0\n\npeak_vram = torch.cuda.max_memory_allocated() / 1e9\nprint(\"\\n✓ Smoke test passed!\")\nprint(f\"  Step time: {step_time:.0f}s\")\nprint(f\"  Peak VRAM: {peak_vram:.1f}GB / {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB\")\nprint(f\"  Estimated full run ({MAX_STEPS} steps): {step_time * MAX_STEPS / 3600:.1f}h\")\n\ndel smoke_trainer\ngc.collect(); torch.cuda.empty_cache()"
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"cell_type": "markdown",
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"source": "---\n\n## Cell 12: Probe Run (10 Steps) β THE CRITICAL GATE\n\n**Gate:** `nonzero_clips >= 3`. If this fails, go to ADR-002 Section 8 (Fallback Plan).\n\nThis is the gate V3 missed. If clip_ratio = 0 on all steps, the policy is not learning.\nDo NOT proceed to full training without passing this gate."
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"cell_type": "code",
|
| 147 |
+
"execution_count": null,
|
| 148 |
+
"metadata": {},
|
| 149 |
+
"outputs": [],
|
| 150 |
+
"source": "FastLanguageModel.for_training(model)\n\nprobe_config = GRPOConfig(\n output_dir=str(CHECKPOINT_DIR / \"probe\"),\n num_generations=NUM_GENERATIONS,\n scale_rewards=SCALE_REWARDS,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=10,\n temperature=TEMPERATURE,\n beta=BETA,\n num_train_epochs=1,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LEARNING_RATE,\n warmup_ratio=0.1,\n lr_scheduler_type=\"cosine\",\n fp16=False,\n bf16=True,\n logging_steps=1,\n save_steps=999,\n report_to=\"none\",\n max_prompt_length=MAX_SEQ_LENGTH // 2,\n seed=42,\n remove_unused_columns=False,\n)\n\nprobe_trainer = UnslothGRPOTrainer(\n model=model,\n reward_funcs=commerce_reward_fn,\n args=probe_config,\n train_dataset=train_dataset,\n processing_class=tokenizer,\n)\n\nt0 = time.time()\nresult = probe_trainer.train()\nelapsed = time.time() - t0\n\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n# CRITICAL GATE: clip_ratio > 0 on at least 3 of 10 steps\n# If this fails, STOP. See Fallback Plan (Section 8 of ADR-002).\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nclip_ratios = []\nfor entry in probe_trainer.state.log_history:\n if \"train/clip_ratio\" in entry:\n clip_ratios.append(entry[\"train/clip_ratio\"])\n\nnonzero_clips = sum(1 for cr in clip_ratios if cr > 0.0)\nprint(f\"\\n{'='*60}\")\nprint(f\"PROBE RESULTS ({elapsed:.0f}s, {elapsed/10:.0f}s/step)\")\nprint(f\" clip_ratios: {[f'{cr:.4f}' for cr in clip_ratios]}\")\nprint(f\" Non-zero clip steps: {nonzero_clips}/{len(clip_ratios)}\")\nprint(f\" Train loss: {result.training_loss:.4f}\")\nprint(f\"{'='*60}\")\n\nif nonzero_clips >= 3:\n print(\"β PROBE GATE PASSED β proceed to full training\")\nelif nonzero_clips > 0:\n print(\"β οΈ MARGINAL β clip_ratio > 0 on some steps but < 3. 
Consider increasing LR or G.\")\nelse:\n print(\"β PROBE GATE FAILED β clip_ratio = 0 on ALL steps.\")\n print(\" DO NOT proceed to full training.\")\n print(\" See ADR-002 Section 8 (Fallback Plan).\")\n\ndel probe_trainer\ngc.collect(); torch.cuda.empty_cache()"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 13: W&B Init + Full Training\n\n**Only run this cell if the probe gate in Cell 12 passed.**"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import wandb\nfrom transformers import TrainerCallback\n\nwandb.login()\nwandb.init(\n project=WANDB_PROJECT,\n name=f\"grpo-v4-instruct-0.5B-{time.strftime('%Y%m%d-%H%M')}\",\n config={\n \"model_id\": MODEL_ID,\n \"version\": \"v4\",\n \"num_generations\": NUM_GENERATIONS,\n \"max_completion_length\": MAX_COMPLETION_LENGTH,\n \"temperature\": TEMPERATURE,\n \"learning_rate\": LEARNING_RATE,\n \"beta\": BETA,\n \"scale_rewards\": SCALE_REWARDS,\n \"batch_size\": BATCH_SIZE,\n \"grad_accum\": GRAD_ACCUM,\n \"max_steps\": MAX_STEPS,\n \"lora_r\": LORA_R,\n \"lora_alpha\": LORA_ALPHA,\n \"train_prompts\": len(train_dataset),\n \"eval_prompts\": len(eval_dataset),\n \"repetition_penalty_override\": 1.0,\n },\n)\nprint(f\"β W&B run: {wandb.run.url}\")\n\n\nclass EvalRewardCallback(TrainerCallback):\n def __init__(self, eval_records, reward_fn, patience, delta):\n self.eval_records = eval_records\n self.reward_fn = reward_fn\n self.patience = patience\n self.delta = delta\n self.best_reward = -float(\"inf\")\n self.best_step = 0\n self.no_improve_count = 0\n\n def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n return control\n\n tokenizer_local = processing_class\n if tokenizer_local is None:\n print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n return control\n\n mean_reward = self._run_eval(model, tokenizer_local, args)\n improved = mean_reward > self.best_reward + self.delta\n\n wandb.log({\n \"eval/mean_reward\": mean_reward,\n \"eval/best_reward\": max(self.best_reward, mean_reward),\n \"eval/no_improve_count\": self.no_improve_count,\n }, step=state.global_step)\n\n status = \"β improved\" if improved else f\"β no gain ({self.no_improve_count + 1}/{self.patience})\"\n print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n\n if improved:\n self.best_reward = 
mean_reward\n self.best_step = state.global_step\n self.no_improve_count = 0\n else:\n self.no_improve_count += 1\n if self.no_improve_count >= self.patience:\n print(f\"[EarlyStopping] No improvement for {self.patience} evals. Halting.\")\n control.should_training_stop = True\n return control\n\n def _run_eval(self, model, tokenizer_local, args):\n FastLanguageModel.for_inference(model)\n rewards = []\n subset = self.eval_records[:EVAL_MAX_SAMPLES]\n for record in subset:\n msgs = record[\"prompt\"]\n text = tokenizer_local.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n inputs = tokenizer_local(text, return_tensors=\"pt\", truncation=True, max_length=args.max_prompt_length).to(model.device)\n with torch.no_grad():\n out = model.generate(\n **inputs,\n max_new_tokens=EVAL_MAX_TOKENS,\n temperature=0.1, # deterministic eval\n do_sample=True,\n repetition_penalty=1.0,\n )\n resp = tokenizer_local.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n rewards.append(self.reward_fn([resp], [text])[0])\n FastLanguageModel.for_training(model)\n return sum(rewards) / len(rewards) if rewards else 0.0\n\n\n# ββ Training ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nFastLanguageModel.for_training(model)\n\ngrpo_config = GRPOConfig(\n output_dir=str(CHECKPOINT_DIR),\n num_generations=NUM_GENERATIONS,\n scale_rewards=SCALE_REWARDS,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS,\n temperature=TEMPERATURE,\n beta=BETA,\n num_train_epochs=1,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LEARNING_RATE,\n warmup_ratio=0.1,\n lr_scheduler_type=\"cosine\",\n fp16=False,\n bf16=True,\n logging_steps=1,\n save_steps=SAVE_STEPS,\n save_total_limit=5,\n save_only_model=True,\n report_to=\"wandb\",\n max_prompt_length=MAX_SEQ_LENGTH // 2,\n seed=42,\n remove_unused_columns=False,\n disable_tqdm=True,\n logging_first_step=True,\n)\n\neval_cb = 
EvalRewardCallback(\n eval_records=list(eval_dataset),\n reward_fn=commerce_reward_fn,\n patience=EARLY_STOPPING_PATIENCE,\n delta=EARLY_STOPPING_DELTA,\n)\n\ntrainer = UnslothGRPOTrainer(\n model=model,\n reward_funcs=commerce_reward_fn,\n args=grpo_config,\n train_dataset=train_dataset,\n processing_class=tokenizer,\n callbacks=[eval_cb],\n)\n\nt_start = time.time()\nresult = trainer.train()\nelapsed = time.time() - t_start\n\nwandb.log({\n \"train/final_loss\": result.training_loss,\n \"train/duration_hours\": elapsed / 3600,\n \"train/total_steps\": result.global_step,\n \"eval/best_reward_final\": eval_cb.best_reward,\n \"eval/best_step\": eval_cb.best_step,\n})\nwandb.finish()\n\nprint(f\"\\n{'='*60}\")\nprint(f\"V4 Training Complete\")\nprint(f\" Loss: {result.training_loss:.4f}\")\nprint(f\" Steps: {result.global_step}\")\nprint(f\" Duration: {elapsed/3600:.1f}h\")\nprint(f\" Best eval: {eval_cb.best_reward:.4f} (step {eval_cb.best_step})\")\nprint(f\"{'='*60}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 14: Validation (20 Held-Out Samples)\n\nRun validation on held-out samples, broken down by task type.\n\n**Success criteria (ADR-002 Β§7):**\n- Extraction mean reward β₯ 0.30\n- Push mean reward β₯ 0.40\n- SQL Q&A mean reward β₯ 0.20\n- Insights mean reward β₯ 0.20"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "FastLanguageModel.for_inference(model)\n\nval_samples = list(eval_dataset)[:20]\nval_results = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n\nfor i, record in enumerate(val_samples):\n msgs = record[\"prompt\"]\n user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_text)\n\n text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n with torch.no_grad():\n out = model.generate(\n **inputs,\n max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.1,\n do_sample=True,\n repetition_penalty=1.0,\n )\n resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n r = commerce_reward_fn([resp], [text])[0]\n val_results[task].append(r)\n print(f\" [{task:12s}] reward={r:.2f} | {strip_think(resp)[:80]}\")\n\nprint(f\"\\n{'='*60}\")\nprint(\"Validation Results by Task:\")\nfor task, rewards in val_results.items():\n if rewards:\n mean_r = sum(rewards) / len(rewards)\n print(f\" {task:12s}: mean={mean_r:.3f} (n={len(rewards)})\")\nprint(f\"{'='*60}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 15: Save Adapter"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\nmodel.save_pretrained(str(ADAPTER_DIR))\ntokenizer.save_pretrained(str(ADAPTER_DIR))\nprint(f\"β Adapter saved to {ADAPTER_DIR}\")"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}