v4: ROOT CAUSE FIX — use standard PEFT not Unsloth get_peft_model (fused LoRA kernels have dtype bug #4891). Revert to load_in_4bit=True, dtype=None matching V3.
Browse files
notebooks/v4_instruct_grpo.ipynb
CHANGED
|
@@ -51,7 +51,7 @@
|
|
| 51 |
"execution_count": null,
|
| 52 |
"metadata": {},
|
| 53 |
"outputs": [],
|
| 54 |
-
"source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_ID,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=
|
| 55 |
},
|
| 56 |
{
|
| 57 |
"cell_type": "markdown",
|
|
|
|
| 51 |
"execution_count": null,
|
| 52 |
"metadata": {},
|
| 53 |
"outputs": [],
|
| 54 |
+
"source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_ID,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n dtype=None, # auto-detect (matches V3 which ran 171 steps successfully)\n)\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# LoRA ADAPTER — ADR-002 §9: r=16, α=32\n# Using standard PEFT instead of Unsloth's get_peft_model() to avoid fused\n# LoRA QKV/O/MLP kernels that have fp16/bf16 dtype mismatch bug (unsloth #4891).\n# V3 avoided this because it loaded a pre-existing adapter (0 QKV/O/MLP patches).\n# V4 applies fresh LoRA → Unsloth patches all 28 layers → crash in matmul_lora.\n# Standard PEFT uses PyTorch native matmul which handles bf16 correctly.\n# ═══════════════════════════════════════════════════════════════════════════════\nfrom peft import LoraConfig, get_peft_model\n\nlora_config = LoraConfig(\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_dropout=0,\n bias=\"none\",\n task_type=\"CAUSAL_LM\",\n)\nmodel = get_peft_model(model, lora_config)\nmodel.print_trainable_parameters()\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# CRITICAL OVERRIDES — generation_config ships with values that destroy GRPO\n# Source: Polygl0t/Tucano2-qwen-0.5B-Instruct/generation_config.json\n# temperature: 0.1 → override to 1.0\n# repetition_penalty: 1.2 → override to 1.0\n# use_cache: false → override to true\n# ═══════════════════════════════════════════════════════════════════════════════\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0 # CRITICAL: 1.2 suppresses diversity\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0 # disable top-k — let temperature control diversity\nmodel.generation_config.top_p = 1.0 # disable top-p\nmodel.generation_config.max_length = None # remove conflict with max_new_tokens\n\n# Pad token\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"✓ Model loaded on {model.device}\")\nprint(f\" use_cache: {model.config.use_cache}\")\nprint(f\" temperature: {model.generation_config.temperature}\")\nprint(f\" repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\" top_k: {model.generation_config.top_k}\")\nprint(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# TIED EMBEDDINGS CHECK — ADR-002 Decision 4\n# Source: config.json has \"tie_word_embeddings\": true\n# After LoRA patching, verify lm_head and embed_tokens still share weights.\n# ═══════════════════════════════════════════════════════════════════════════════\n\ntry:\n lm_ptr = model.base_model.model.lm_head.weight.data_ptr()\n embed_ptr = model.base_model.model.model.embed_tokens.weight.data_ptr()\n tied = lm_ptr == embed_ptr\n print(f\" Tied embeddings intact: {tied}\")\n if not tied:\n print(\" ⚠️ WARNING: Tied embeddings broken after LoRA patching. May affect output head gradients.\")\nexcept AttributeError as e:\n print(f\" ⚠️ Could not check tied embeddings: {e}\")"
|
| 55 |
},
|
| 56 |
{
|
| 57 |
"cell_type": "markdown",
|