{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# Tucano2 Commerce β GRPO Training V4.2 (Gold Standard, 0.5B)\n\n**Reference:** `docs/v4_2-handoff.md` \n**Base:** V4.1 notebook with 8 targeted changes\n\n## V4.2 Changes from V4.1\n\n| # | Change | V4.1 | V4.2 | Why |\n|---|--------|------|------|-----|\n| 1 | Eval suite | 15 mixed samples | **65 stratified** (20 ext + 15 sql + 15 ins + 15 push) | Insights swing Β±0.22 was eval noise on nβ2 |\n| 2 | Reward audit | None | **Spearman Ο > 0.70 gate** (20 completions, human-scored) | Parser bug persisted 3 versions |\n| 3 | SQL reward | Heuristic vocabulary | **Validation-aware** (SQL syntax + query/explanation + numerics + domain) | SQL stagnant +3.8% β reward was ceiling |\n| 4 | Max steps | 600 | **1,500** (~2.5 epochs) | Only 40% data seen; eval still improving at step 500 |\n| 5 | GDPO normalization | Batch-level reward | **Per-component normalization** before aggregation | GDPO Β§3.1: preserves 4Γ more advantage groups |\n| 6 | Task weighting | Equal (0.25 each) | **Dynamic IWU** (upweight stagnating tasks) | MT-GRPO Β§3.2: prevents easy-task collapse |\n| 7 | Seeds | Single run (42) | **3 seeds** (42, 123, 456) with reported CIs | Minimum for credible ML result |\n| 8 | Best checkpoint | Save at end | **Explicit best_checkpoint/** on eval improvement | GRPOTrainer lacks load_best_model_at_end |\n\n**Prerequisites:**\n- Upload `data/pairs/train.jsonl` and `data/pairs/eval.jsonl` to `./data/pairs/`\n- Hardware: L4 (24GB), PyTorch kernel, bf16 supported\n- Estimated runtime: ~12h per seed (1,500 steps Γ ~30s/step)\n- Run 3 times, changing only `CURRENT_SEED` in Cell 3\n\n---\n\n*V4.2 is the last 0.5B run. Its purpose is not to find more improvement β it is to know exactly what was found and why, with enough statistical rigor to say so in writing.*"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 1: Dependencies\n\n**Gate:** No errors. Verify TRL 0.24.0 installed.\n\n**V4.2 additions:** `scipy` for Spearman Ο in reward audit."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# Cell 1 β Clean install\n# Run after kernel restart\n\n!pip install \"unsloth\"\n!pip install \"trl==0.24.0\" --no-deps\n!pip install \"rich\" \"wandb\"\n!pip install \"json-repair\" # V4.1: robust JSON parser for Portuguese LLM output\n!pip install \"scipy\" # V4.2: Spearman Ο for reward audit"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 2: GPU + Unsloth Verification\n\n**Gate:** CUDA available, bf16=True, VRAM > 20GB, TRL 0.24.0."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import torch\n\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nprint(f\"GPU: {torch.cuda.get_device_name(0)}\")\nprint(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\nprint(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n\nfrom unsloth import FastLanguageModel\nprint(f\"\\nβ Unsloth loaded\")\n\nimport trl\nassert trl.__version__ == \"0.24.0\", f\"Expected TRL 0.24.0, got {trl.__version__}\"\nprint(f\"β TRL {trl.__version__}\")\n\nimport transformers\nprint(f\"β Transformers {transformers.__version__}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 3: Config Constants\n\n**V4.2 changes:**\n- `MAX_STEPS`: 600 β **1,500** (multi-epoch, ~2.5Γ full dataset)\n- `EVAL_STEPS`: 20 β **50** (more frequent eval relative to epoch boundaries)\n- `SAVE_STEPS`: 50 β **100** (scaled for longer run)\n- `SEEDS`: Added multi-seed support. Change `CURRENT_SEED` per run.\n- `EVAL_TOTAL = 65`: Stratified eval set (20 ext + 15 sql + 15 ins + 15 push)\n\n**Everything else UNCHANGED from V4.1** (validated config)."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import os\nimport json\nimport re\nimport time\nimport random\nimport gc\nimport math\nimport warnings\nfrom pathlib import Path\n\n# ββ Suppress noisy deprecation warnings from Transformers 5.5.0 ββββββββββββββ\nwarnings.filterwarnings(\"ignore\", message=\".*AttentionMaskConverter.*\")\nwarnings.filterwarnings(\"ignore\", message=\".*Passing `generation_config` together with.*\")\nwarnings.filterwarnings(\"ignore\", message=\".*max_new_tokens.*max_length.*\")\nwarnings.filterwarnings(\"ignore\", category=FutureWarning)\n\n# ββ Disable Unsloth kernel recompilation βββββββββββββββββββββββββββββββββββββ\nos.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n\n# ββ V4.2: Multi-seed support ββββββββββββββββββββββββββββββββββββββββββββββββ\nSEEDS = [42, 123, 456]\nCURRENT_SEED = 42 # β CHANGE THIS PER RUN (42, 123, 456)\n\n# Set all random seeds\nrandom.seed(CURRENT_SEED)\ntorch.manual_seed(CURRENT_SEED)\nif torch.cuda.is_available():\n torch.cuda.manual_seed_all(CURRENT_SEED)\n\n# ββ Model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nMODEL_ID = \"Polygl0t/Tucano2-qwen-0.5B-Instruct\"\nMAX_SEQ_LENGTH = 2048 # model supports 4096, but 2048 is plenty for Instruct (no <think> overhead)\nMODELS_DIR = Path(\"/home/jupyter/tucano2/models\")\nADAPTER_DIR = MODELS_DIR / f\"tucano2-0.5B-instruct-grpo-v4.2-seed{CURRENT_SEED}\"\nCHECKPOINT_DIR = ADAPTER_DIR / \"checkpoints\"\n\n# ββ Data βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nDATA_DIR = Path(\"/home/jupyter/tucano2/data\")\nTRAIN_FILE = DATA_DIR / \"pairs\" / \"train.jsonl\"\nEVAL_FILE = DATA_DIR / \"pairs\" / \"eval.jsonl\" # V4.2: separate eval source\n\n# V4.2: Stratified eval set specification (Change 1)\nEVAL_SAMPLES_PER_TASK = {\n \"extraction\": 20,\n \"sql_qa\": 15,\n \"insights\": 15,\n \"push\": 15,\n}\nEVAL_TOTAL = sum(EVAL_SAMPLES_PER_TASK.values()) # 65\n\n# ββ GRPO Hyperparameters 
βββββββββββββββββββββββββββββββββββββββββββββββββββββ\n# V4.2 CHANGES: MAX_STEPS 600β1500, EVAL_STEPS 20β50, SAVE_STEPS 50β100\n# Everything else UNCHANGED from V4.1 (validated)\nNUM_GENERATIONS = 16 # 0.5B + short completions = VRAM allows G=16\nMAX_COMPLETION_LENGTH = 512 # Instruct: no <think> overhead\nTEMPERATURE = 1.0 # Skywork-OR1: Ο=1.0 for exploration\nLEARNING_RATE = 5e-6 # V4.1: validated at 5e-6\nBETA = 0.0 # Dr. GRPO Β§3.2: Ξ²=0 optimal for rule-based rewards\nSCALE_REWARDS = False # Dr. GRPO: remove std normalization bias\nBATCH_SIZE = 2 # per-device batch size\nGRAD_ACCUM = 1 # effective batch = 2 * 1 = 2 prompts * 16 gen = 32 completions\nMAX_STEPS = 1500 # V4.2: was 600. ~2.5 full epochs with shuffling\nSAVE_STEPS = 100 # V4.2: was 50. Scaled for longer run\nEVAL_STEPS = 50 # V4.2: was 20. More frequent per-epoch boundary\nEARLY_STOPPING_PATIENCE = 15 # 15 Γ 50 = 750 steps without improvement\nEARLY_STOPPING_DELTA = 0.005\nLR_SCHEDULER_TYPE = \"constant_with_warmup\" # V4.1: validated\nWARMUP_RATIO = 0.05 # V4.1: validated (5% of 1500 = 75 warmup steps)\n\n# ββ LoRA βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nLORA_R = 16\nLORA_ALPHA = 32\n\n# ββ Monitoring βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nWANDB_PROJECT = \"tucano2-commerce\"\nEVAL_MAX_TOKENS = 512 # match training completion length\n\n# ββ Task Classification (inherited from V2/V3) ββββββββββββββββββββββββββββββ\nVALID_SENTIMENTS = {\"positive\", \"negative\", \"neutral\"}\nVALID_CATEGORIES = {\n \"delivery_delay\", \"product_quality\", \"product_not_received\",\n \"wrong_product\", \"seller_communication\", \"app_issue\",\n \"price_value\", \"other\", \"none\",\n}\nVALID_CHURN = {\"low\", \"medium\", \"high\"}\nVALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\nEXTRACTION_FIELDS = [\n \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n \"product_issue\", \"seller_issue\", \"main_complaint\",\n \"complaint_category\", 
\"repeat_intent\", \"would_recommend\",\n]\n\n# ββ Verified Special Token IDs (from tokenizer_config.json) βββββββββββββββββ\nTOKEN_ID_BOS = 1 # <|im_start|>\nTOKEN_ID_EOS = 2 # <|im_end|>\nTOKEN_ID_PAD = 49109 # <|pad|>\nTOKEN_ID_THINK = 49116 # <think>\nTOKEN_ID_THINK_END = 49117 # </think>\n\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n# TASK-AWARE SYSTEM PROMPTS (inherited from V3/V4)\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n\nSYSTEM_EXTRACTION = (\n \"VocΓͺ Γ© um motor de extraΓ§Γ£o de dados de e-commerce brasileiro. \"\n \"Retorne APENAS um objeto JSON vΓ‘lido, sem nenhum texto antes ou depois. \"\n \"NΓO USE blocos de cΓ³digo markdown (```json). \"\n \"O primeiro caractere da sua resposta deve ser { e o ΓΊltimo deve ser }. \"\n \"Campos nΓ£o mencionados na avaliaΓ§Γ£o devem ser null β nunca invente valores. \"\n \"Sem explicaΓ§Γ£o. Sem comentΓ‘rios.\"\n)\n\nSYSTEM_SQL = (\n \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n \"Para consultas e anΓ‘lises de dados: apresente a resposta de forma direta \"\n \"com nΓΊmeros e dados concretos. Seja conciso.\"\n)\n\nSYSTEM_INSIGHTS = (\n \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n \"Para anΓ‘lises estratΓ©gicas: raciocine de forma estruturada e concisa, \"\n \"focando nos pontos principais e recomendaΓ§Γ΅es acionΓ‘veis.\"\n)\n\nSYSTEM_PUSH = (\n \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n \"Para notificaΓ§Γ΅es push: seja direto e criativo. \"\n \"A notificaΓ§Γ£o deve ter no mΓ‘ximo 120 caracteres. 
\"\n \"Responda diretamente.\"\n)\n\nSYSTEM_PT = (\n \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\"\n)\n\ndef get_system_prompt(task_type: str) -> str:\n return {\n \"extraction\": SYSTEM_EXTRACTION,\n \"sql_qa\": SYSTEM_SQL,\n \"insights\": SYSTEM_INSIGHTS,\n \"push\": SYSTEM_PUSH,\n }.get(task_type, SYSTEM_PT)\n\ndef inject_task_system_prompt(msgs, task_type):\n new_msgs = []\n system_prompt = get_system_prompt(task_type)\n has_system = False\n for m in msgs:\n if m[\"role\"] == \"system\":\n new_msgs.append({\"role\": \"system\", \"content\": system_prompt})\n has_system = True\n else:\n new_msgs.append(m)\n if not has_system:\n new_msgs.insert(0, {\"role\": \"system\", \"content\": system_prompt})\n return new_msgs\n\nprint(\"β Task-aware system prompts defined\")\n\nprint(\"β Config loaded\")\nprint(f\" Model: {MODEL_ID}\")\nprint(f\" Seed: {CURRENT_SEED} (run {SEEDS.index(CURRENT_SEED)+1}/{len(SEEDS)})\")\nprint(f\" G={NUM_GENERATIONS}, max_comp={MAX_COMPLETION_LENGTH}, temp={TEMPERATURE}\")\nprint(f\" LR={LEARNING_RATE}, Ξ²={BETA}, scale_rewards={SCALE_REWARDS}\")\nprint(f\" LR schedule: {LR_SCHEDULER_TYPE}, warmup={WARMUP_RATIO}\")\nprint(f\" LoRA r={LORA_R}, Ξ±={LORA_ALPHA}\")\nprint(f\" Max steps: {MAX_STEPS} (~{MAX_STEPS * BATCH_SIZE / 1480:.1f} epochs)\")\nprint(f\" Eval: {EVAL_TOTAL} stratified samples, every {EVAL_STEPS} steps\")\nprint(f\" Save every {SAVE_STEPS} steps\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 4: Load Model + Apply Critical Overrides\n\n**Gate:** Model loaded, `use_cache=True`, `repetition_penalty=1.0`, `temperature=1.0`.\n\nUnchanged from V4.1."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=MODEL_ID,\n max_seq_length=MAX_SEQ_LENGTH,\n load_in_4bit=True,\n dtype=None,\n)\n\nfrom peft import LoraConfig, get_peft_model\n\nlora_config = LoraConfig(\n r=LORA_R,\n lora_alpha=LORA_ALPHA,\n target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\"],\n lora_dropout=0,\n bias=\"none\",\n task_type=\"CAUSAL_LM\",\n)\nmodel = get_peft_model(model, lora_config)\nmodel.print_trainable_parameters()\n\n# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n# CRITICAL OVERRIDES\n# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0\nmodel.generation_config.top_p = 1.0\nmodel.generation_config.max_length = None\n\nif tokenizer.pad_token is None:\n tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"β Model loaded on {model.device}\")\nprint(f\" use_cache: {model.config.use_cache}\")\nprint(f\" temperature: {model.generation_config.temperature}\")\nprint(f\" repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\" top_k: {model.generation_config.top_k}\")\nprint(f\" Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\ntry:\n lm_ptr = model.base_model.model.lm_head.weight.data_ptr()\n embed_ptr = model.base_model.model.model.embed_tokens.weight.data_ptr()\n tied = lm_ptr == embed_ptr\n print(f\" Tied embeddings intact: {tied}\")\n if not tied:\n print(\" β οΈ WARNING: Tied embeddings broken after LoRA patching.\")\nexcept AttributeError as e:\n print(f\" β οΈ Could not check tied embeddings: {e}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 5: Token ID Verification\n\n**Gate:** All token IDs match. Unchanged from V4.1."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "tok_tests = {\n \"<|im_start|>\": TOKEN_ID_BOS,\n \"<|im_end|>\": TOKEN_ID_EOS,\n \"<|pad|>\": TOKEN_ID_PAD,\n \"<think>\": TOKEN_ID_THINK,\n \"</think>\": TOKEN_ID_THINK_END,\n}\n\nall_pass = True\nfor text, expected_id in tok_tests.items():\n ids = tokenizer.encode(text, add_special_tokens=False)\n actual_id = ids[0] if len(ids) == 1 else ids\n match = (len(ids) == 1 and ids[0] == expected_id)\n status = \"β\" if match else \"β\"\n print(f\" {status} '{text}' β expected {expected_id}, got {actual_id}\")\n if not match:\n all_pass = False\n\nassert all_pass, \"Token ID mismatch detected. Update constants in Cell 3 before proceeding.\"\nprint(\"\\nβ All token IDs verified\")\n\nassert tokenizer.eos_token_id == TOKEN_ID_EOS, f\"eos_token_id mismatch: {tokenizer.eos_token_id}\"\nprint(f\"β eos_token_id = {tokenizer.eos_token_id}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 6: KV Cache Diagnostic\n\n**Gate:** Ratio < 3Γ β KV cache OK. Unchanged from V4.1."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "FastLanguageModel.for_inference(model)\n\n_kv_msgs = [{\"role\": \"user\", \"content\": \"Qual a categoria de reclamaΓ§Γ£o mais frequente?\"}]\n_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n\n_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\nwith torch.no_grad():\n for _step in range(50):\n _t0 = time.time()\n seq_len = _generated.shape[1]\n if _past is None:\n _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n else:\n _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n _out = model(\n input_ids=_generated[:, -1:] if _past else _generated,\n position_ids=_position_ids,\n attention_mask=torch.ones(1, seq_len, device=model.device),\n past_key_values=_past,\n use_cache=True,\n return_dict=True,\n )\n _past = _out.past_key_values\n _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n _generated = torch.cat([_generated, _next], dim=1)\n _token_times.append(time.time() - _t0)\n\n_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\nprint(f\"First 5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\nprint(f\"Last 5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\nprint(f\"Ratio last/first: {_ratio:.1f}x\")\nassert _ratio < 5, f\"KV cache BROKEN (ratio {_ratio:.1f}Γ). Check model.config.use_cache.\"\nprint(\"β KV cache working correctly\")\n\ndel _past, _generated, _kv_inputs, _token_times, _out\ngc.collect()\ntorch.cuda.empty_cache()"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n\n## Cell 7: Reward Functions V2\n\n**V4.2 changes (Change 3 + Change 5):**\n\n### SQL Reward Overhaul (Change 3)\n- **Tier 1 (0.30):** SQL structure detected β requires β₯3 SQL keywords (SELECT, FROM, WHERE, etc.)\n- **Tier 2 (0.25):** Answer has BOTH query AND explanation (not just domain vocabulary)\n- **Tier 3 (0.25):** Numerical specificity (concrete data in answer)\n- **Tier 4 (0.20):** Portuguese business domain coherence\n\n### GDPO Per-Component Normalization (Change 5) β ACTIVE IN TRAINING\n- `commerce_reward_fn` applies per-task z-score normalization INSIDE the reward call\n- TRL 0.24.0 calls reward_fn with the full batch β we normalize per-component before returning\n- No trainer modification needed β normalized rewards flow through standard GRPO advantage computation\n- Preserves ~4Γ more distinct advantage groups (GDPO Β§3.1)\n\n### Dynamic Task Weights (Change 6) β ACTIVE IN TRAINING\n- `_task_weights` dict tracks per-task weights, updated by `update_task_weights()` in eval callback\n- Weights are applied as multiplicative scaling INSIDE `commerce_reward_fn` after GDPO normalization\n- Effect: stagnating tasks (e.g. SQL) get amplified reward signal β larger GRPO advantages β more gradient\n- MT-GRPO IWU Β§3.2: prevents easy-task collapse without requiring custom sampling\n\n### V4.2.1 Fixes (Cell 8 Audit)\n- **Push reward:** Steep length penalty (hard 0 above 200 chars) + formal email penalty (-0.20 for \"Prezado\"/\"Atenciosamente\")\n- **SQL reward Tier 4:** Expanded domain word list (+20 words: compradores, sentimentos, reclamaΓ§Γ΅es, taxa, distribuiΓ§Γ£o, etc.)\n- **Extraction reward:** `sentiment_score` validator requires `isinstance(v, int) and not isinstance(v, bool)` β rejects floats from PT decimal normalization\n- **Task classifier:** Reordered `_classify_task_type` β insights checked before push to prevent 'reengajamento' misclassification"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json_repair # V4.1: robust JSON parser for LLM output\n",
"\n",
"\n",
"def strip_think(text: str) -> str:\n",
" \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n",
" return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n",
"\n",
"\n",
"def has_think_block(text: str) -> bool:\n",
" return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n",
"\n",
"\n",
"def _classify_task_type(prompt_text: str) -> str:\n",
" \"\"\"V4.2.1: reordered β insights before push to prevent misclassification.\n",
" \n",
" \"notificaΓ§Γ£o de reengajamento\" in a customer profile context is insights,\n",
" not push. Check insights keywords first.\n",
" \"\"\"\n",
" p = prompt_text.lower()\n",
" # 1. Insights FIRST β customer profile questions mentioning reengagement are insights\n",
" if \"perfil do cliente\" in p or \"retenΓ§Γ£o\" in p or \"anΓ‘lise\" in p or \"insight\" in p:\n",
" return \"insights\"\n",
" # 2. Extraction\n",
" elif \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n",
" return \"extraction\"\n",
" # 3. Push β only after insights is ruled out\n",
" elif \"notificaΓ§Γ£o push\" in p or \"notificaΓ§Γ£o de reengajamento\" in p:\n",
" return \"push\"\n",
" else:\n",
" return \"sql_qa\"\n",
"\n",
"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"# V4.1: ROBUST JSON PARSER (unchanged)\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"\n",
"def _normalize_pt_decimals(s: str) -> str:\n",
" \"\"\"Convert PT-BR decimals (4,5) to JSON-valid (4.5), only outside quoted strings.\"\"\"\n",
" result, in_string, escape_next = [], False, False\n",
" i = 0\n",
" while i < len(s):\n",
" c = s[i]\n",
" if escape_next:\n",
" result.append(c); escape_next = False; i += 1; continue\n",
" if c == '\\\\' and in_string:\n",
" result.append(c); escape_next = True; i += 1; continue\n",
" if c == '\"':\n",
" in_string = not in_string; result.append(c); i += 1; continue\n",
" if not in_string:\n",
" m = re.match(r'(\\d+),(\\d+)', s[i:])\n",
" if m:\n",
" result.append(m.group(1) + '.' + m.group(2))\n",
" i += len(m.group(0)); continue\n",
" result.append(c); i += 1\n",
" return ''.join(result)\n",
"\n",
"\n",
"def _extract_json(text: str) -> dict | None:\n",
" \"\"\"Robust JSON extraction for Portuguese LLM output.\"\"\"\n",
" stripped = re.sub(r'^```(?:json)?\\s*|\\s*```$', '', text.strip(), flags=re.MULTILINE).strip()\n",
" for attempt in [stripped, _normalize_pt_decimals(stripped)]:\n",
" try:\n",
" result = json.loads(attempt)\n",
" if isinstance(result, dict):\n",
" return result\n",
" except (json.JSONDecodeError, TypeError):\n",
" pass\n",
" normalized = _normalize_pt_decimals(stripped)\n",
" try:\n",
" result = json_repair.repair_json(normalized, return_objects=True)\n",
" if isinstance(result, dict):\n",
" return result\n",
" except Exception:\n",
" pass\n",
" try:\n",
" result = json_repair.repair_json(stripped, return_objects=True)\n",
" if isinstance(result, dict):\n",
" return result\n",
" except Exception:\n",
" pass\n",
" return None\n",
"\n",
"\n",
"def reward_extraction(completion: str, prompt_text: str = \"\") -> float:\n",
" \"\"\"Continuous reward for extraction tasks (max 1.0).\"\"\"\n",
" answer = strip_think(completion)\n",
" data = _extract_json(answer)\n",
"\n",
" if data is None:\n",
" if \"{\" in answer and \"}\" in answer:\n",
" return 0.05\n",
" return 0.0\n",
"\n",
" if not isinstance(data, dict):\n",
" return 0.1\n",
"\n",
" score = 0.3 # valid JSON object\n",
"\n",
" # Schema completeness (0.3 total)\n",
" present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n",
" score += 0.3 * (present / len(EXTRACTION_FIELDS))\n",
"\n",
" # Value validity (0.4 total)\n",
" checks_passed = 0\n",
" checks_total = 0\n",
"\n",
" for field, validator in [\n",
" (\"sentiment\", lambda v: isinstance(v, str) and v in VALID_SENTIMENTS),\n",
" (\"complaint_category\", lambda v: isinstance(v, str) and v in VALID_CATEGORIES),\n",
" (\"churn_risk\", lambda v: isinstance(v, str) and v in VALID_CHURN),\n",
" (\"repeat_intent\", lambda v: isinstance(v, str) and v in VALID_REPEAT),\n",
" # V4.2.1: must be int, not float/bool. PT normalizer turns \"0,5\"β0.5 (float)\n",
" (\"sentiment_score\", lambda v: isinstance(v, int) and not isinstance(v, bool) and 1 <= v <= 5),\n",
" ]:\n",
" checks_total += 1\n",
" if field in data and validator(data[field]):\n",
" checks_passed += 1\n",
"\n",
" for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
" checks_total += 1\n",
" if bool_field in data and isinstance(data[bool_field], bool):\n",
" checks_passed += 1\n",
"\n",
" if checks_total > 0:\n",
" score += 0.4 * (checks_passed / checks_total)\n",
"\n",
" # nota=1-2 on a 5-star scale β negative review; nota=4-5 β positive.\n",
" # Penalize clear sentiment mismatches to break reward hacking.\n",
" import re as _re\n",
" nota_match = _re.search(r\"nota=(\\d)/5\", prompt_text)\n",
" if nota_match and \"sentiment\" in data:\n",
" nota = int(nota_match.group(1))\n",
" sentiment = data.get(\"sentiment\", \"\")\n",
" if nota <= 2 and sentiment == \"positive\":\n",
" score -= 0.20\n",
" elif nota >= 4 and sentiment == \"negative\":\n",
" score -= 0.20\n",
"\n",
" return max(0.0, min(score, 1.0))\n",
"\n",
"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"# V4.2: SQL REWARD V2 β Validation-aware (Change 3)\n",
"# Replaces heuristic vocabulary matching with structural analysis.\n",
"# Expected: distinguishes \"mentions SQL keywords\" from \"produces correct answer\"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"\n",
"def reward_sql_qa(completion: str) -> float:\n",
" \"\"\"V4.2: Validation-aware SQL Q&A reward (max 1.0).\n",
" \n",
" Tier 1 (0.30): SQL structure detected (β₯3 keywords or code block)\n",
" Tier 2 (0.25): Answer has both query and explanation\n",
" Tier 3 (0.25): Numerical specificity (concrete data)\n",
" Tier 4 (0.20): Portuguese business domain coherence\n",
" \"\"\"\n",
" answer = strip_think(completion)\n",
" if not answer.strip():\n",
" return 0.0\n",
"\n",
" score = 0.0\n",
"\n",
" # Tier 1 (0.30): SQL structure detected\n",
" sql_keywords = [\"SELECT\", \"FROM\", \"WHERE\", \"GROUP BY\", \"ORDER BY\",\n",
" \"JOIN\", \"HAVING\", \"COUNT\", \"AVG\", \"SUM\"]\n",
" sql_found = sum(1 for kw in sql_keywords if kw in answer.upper())\n",
" if sql_found >= 3:\n",
" score += 0.30\n",
" elif sql_found >= 1:\n",
" score += 0.15\n",
"\n",
" # Tier 2 (0.25): Answer has both query AND explanation\n",
" has_query = bool(re.search(r\"```sql|SELECT.{5,}FROM\", answer, re.IGNORECASE | re.DOTALL))\n",
" has_answer = any(w in answer.lower() for w in [\"resultado\", \"total\", \"mΓ©dia\", \"mostra\", \"portanto\"])\n",
" if has_query and has_answer:\n",
" score += 0.25\n",
" elif has_query or has_answer:\n",
" score += 0.12\n",
"\n",
" # Tier 3 (0.25): Numerical specificity\n",
" numbers = re.findall(r\"\\d+(?:[.,]\\d+)?(?:\\s*%)?\", answer)\n",
" score += min(0.25, 0.05 * len(numbers))\n",
"\n",
" # Tier 4 (0.20): Portuguese business domain coherence β EXPANDED (V4.2.1)\n",
" pt_domain = [\n",
" # Original 10\n",
" \"pedidos\", \"clientes\", \"vendedores\", \"produtos\", \"avaliaΓ§Γ£o\",\n",
" \"entrega\", \"reclamaΓ§Γ£o\", \"satisfaΓ§Γ£o\", \"categoria\", \"perΓodo\",\n",
" # V4.2.1: broader e-commerce vocabulary (Cell 8 audit: samples 6, 10)\n",
" \"compradores\", \"sentimentos\", \"reclamaΓ§Γ΅es\", \"taxa\", \"distribuiΓ§Γ£o\",\n",
" \"vendas\", \"faturamento\", \"estoque\", \"logΓstica\", \"marketplace\",\n",
" \"consumidores\", \"fornecedores\", \"devoluΓ§Γ΅es\", \"reembolso\", \"frete\",\n",
" \"pagamento\", \"cancelamento\", \"atraso\", \"qualidade\", \"nota\",\n",
" \"positivos\", \"negativos\", \"neutros\", \"tendΓͺncia\", \"desempenho\",\n",
" ]\n",
" score += min(0.20, 0.04 * sum(1 for w in pt_domain if w in answer.lower()))\n",
"\n",
" return min(score, 1.0)\n",
"\n",
"\n",
"def reward_insights(completion: str) -> float:\n",
" \"\"\"Continuous reward for insights (max 1.0). Unchanged from V4.1.\"\"\"\n",
" answer = strip_think(completion)\n",
" if not answer.strip():\n",
" return 0.0\n",
"\n",
" score = 0.0\n",
"\n",
" action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n",
" \"priorizar\", \"investir\", \"otimizar\", \"estratΓ©gi\", \"aΓ§Γ£o\"]\n",
" matches = sum(1 for w in action_words if w in answer.lower())\n",
" score += min(0.4, 0.08 * matches)\n",
"\n",
" length = len(answer)\n",
" if 100 <= length <= 800:\n",
" score += 0.3\n",
" elif length > 0:\n",
" score += 0.3 * max(0, 1 - abs(length - 450) / 450)\n",
"\n",
" structure_marks = len(re.findall(r\"^[-β’*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n",
" score += min(0.2, 0.04 * structure_marks)\n",
"\n",
" if any(w in answer.lower() for w in [\"cliente\", \"produto\", \"serviΓ§o\", \"empresa\"]):\n",
" score += 0.1\n",
"\n",
" return min(score, 1.0)\n",
"\n",
"\n",
"def reward_push(completion: str) -> float:\n",
" \"\"\"Continuous reward for push notifications (max 1.0).\n",
" \n",
" V4.2.1 fixes (Cell 8 audit):\n",
" - Steep length penalty: hard zero above 200 chars (was linear decay to 240)\n",
" - Formal email penalty: -0.20 for \"Prezado\"/\"Atenciosamente\"/etc.\n",
" \"\"\"\n",
" answer = strip_think(completion).strip()\n",
" if not answer:\n",
" return 0.0\n",
"\n",
" length = len(answer)\n",
"\n",
" # Length score (0.50 max) β steep decay above 120 chars\n",
" if length <= 120:\n",
" length_score = 0.50\n",
" elif length <= 160:\n",
" length_score = 0.50 - 0.40 * ((length - 120) / 40) # 0.50 β 0.10\n",
" elif length <= 200:\n",
" length_score = 0.10 - 0.10 * ((length - 160) / 40) # 0.10 β 0.00\n",
" else:\n",
" length_score = 0.0\n",
"\n",
" pt_markers = re.findall(r\"[ãçéΓͺΓ³ΓΊΓ’Γ΅]|vocΓͺ|para|como|seu|sua|oferta|desconto|produto\",\n",
" answer, re.IGNORECASE)\n",
" lang_score = min(0.3, 0.03 * len(pt_markers))\n",
"\n",
" generic = [\"olΓ‘\", \"obrigado pela compra\", \"agradecemos\"]\n",
" is_generic = any(g in answer.lower() for g in generic)\n",
" creativity_score = 0.0 if is_generic else 0.2\n",
"\n",
" # Formal email penalty β push notifications should NOT be formal emails\n",
" formal_markers = [\n",
" \"prezado\", \"prezada\", \"atenciosamente\", \"cordialmente\",\n",
" \"att,\", \"att.\", \"respeitosamente\", \"caro cliente\", \"cara cliente\",\n",
" ]\n",
" has_formal = any(fm in answer.lower() for fm in formal_markers)\n",
" formal_penalty = -0.20 if has_formal else 0.0\n",
"\n",
" return max(0.0, min(length_score + lang_score + creativity_score + formal_penalty, 1.0))\n",
"\n",
"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"# V4.2: GDPO PER-COMPONENT NORMALIZATION (Change 5)\n",
"# Normalize each reward component independently before aggregation.\n",
"# GDPO (2601.05242) Β§3.1: preserves ~4Γ more distinct advantage groups.\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"\n",
"def gdpo_normalize(component_rewards: dict) -> list:\n",
" \"\"\"Per-component normalization before aggregation (GDPO 2601.05242 Β§3.1).\n",
" \n",
" Args:\n",
" component_rewards: {task_name: [reward_per_sample, ...]} for each component\n",
" \n",
" Returns:\n",
" List of normalized summed rewards, one per sample.\n",
" \"\"\"\n",
" normalized = {}\n",
" for task, rewards in component_rewards.items():\n",
" rewards_t = torch.tensor(rewards, dtype=torch.float32)\n",
" std = rewards_t.std()\n",
" if std > 1e-8:\n",
" normalized[task] = ((rewards_t - rewards_t.mean()) / std).tolist()\n",
" else:\n",
" normalized[task] = [0.0] * len(rewards) # zero-variance group\n",
" # Sum normalized components per sample\n",
" n = len(next(iter(normalized.values())))\n",
" return [sum(normalized[t][i] for t in normalized) for i in range(n)]\n",
"\n",
"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"# V4.2: DYNAMIC TASK WEIGHTING β MT-GRPO IWU (Change 6)\n",
"# Track per-task reward improvement rates, upweight stagnating tasks.\n",
"# MT-GRPO (2602.05547) Β§3.2: prevents easy-task collapse.\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"\n",
"_task_weights = {\n",
" \"extraction\": 0.40, # matches training data distribution (40%)\n",
" \"sql_qa\": 0.40, # matches training data distribution (40%)\n",
" \"insights\": 0.10, # matches training data distribution (10%)\n",
" \"push\": 0.10, # matches training data distribution (10%)\n",
"}\n",
"_task_reward_history = {t: [] for t in _task_weights}\n",
"\n",
"def update_task_weights(step: int, per_task_rewards: dict, update_interval: int = 50):\n",
" \"\"\"MT-GRPO IWU: update task sampling weights based on improvement rate.\n",
" \n",
" Args:\n",
" step: Current training step\n",
" per_task_rewards: {task: mean_reward} from latest eval\n",
" update_interval: Only update every N steps\n",
" \"\"\"\n",
" global _task_weights\n",
" if step % update_interval != 0 or step == 0:\n",
" return\n",
" \n",
" for task, reward in per_task_rewards.items():\n",
" if task not in _task_reward_history:\n",
" continue\n",
" _task_reward_history[task].append(reward)\n",
" if len(_task_reward_history[task]) >= 2:\n",
" improvement = _task_reward_history[task][-1] - _task_reward_history[task][-2]\n",
" if improvement < 0.01: # stagnating\n",
" _task_weights[task] = min(0.50, _task_weights[task] * 1.3)\n",
" elif improvement > 0.05: # improving fast\n",
" _task_weights[task] = max(0.10, _task_weights[task] * 0.85)\n",
" \n",
" # Normalize to sum to 1\n",
" total = sum(_task_weights.values())\n",
" _task_weights = {t: w / total for t, w in _task_weights.items()}\n",
"\n",
"\n",
"def get_task_weighted_indices(dataset, n_samples: int) -> list:\n",
" \"\"\"Sample indices with probability proportional to task weights.\"\"\"\n",
" task_indices = {t: [] for t in _task_weights}\n",
" for i, record in enumerate(dataset):\n",
" user_txt = \" \".join(m[\"content\"] for m in record[\"prompt\"] if m[\"role\"] == \"user\")\n",
" task = _classify_task_type(user_txt)\n",
" if task in task_indices:\n",
" task_indices[task].append(i)\n",
" \n",
" sampled = []\n",
" for task, weight in _task_weights.items():\n",
" n = max(1, int(n_samples * weight))\n",
" pool = task_indices.get(task, [])\n",
" if pool:\n",
" sampled.extend(random.sample(pool, min(n, len(pool))))\n",
" random.shuffle(sampled)\n",
" return sampled[:n_samples]\n",
"\n",
"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"# MASTER REWARD FUNCTION β V4.2: returns per-component rewards for GDPO\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"\n",
"def commerce_reward_fn(completions, prompts, **kwargs) -> list:\n",
" \"\"\"Master reward function with GDPO normalization + dynamic task weighting.\n",
" \n",
" V4.2 integration with TRL 0.24.0:\n",
" TRL calls this once per step with the full batch (batch_size Γ G completions).\n",
" We exploit this to apply batch-level per-component normalization (GDPO Β§3.1)\n",
" and dynamic task weighting (MT-GRPO IWU Β§3.2) INSIDE the reward function,\n",
" so the trainer receives pre-normalized, weighted rewards without modification.\n",
" \n",
" Pipeline:\n",
" 1. Score each completion with its task-specific reward function (raw)\n",
" 2. Group raw rewards by task type\n",
" 3. GDPO: z-score normalize each task group independently\n",
" 4. IWU: multiply normalized rewards by current _task_weights\n",
" 5. Shift back to [0, 1] range (GRPO with scale_rewards=False expects non-negative)\n",
" 6. Return flat list in original sample order\n",
" \"\"\"\n",
" n = len(completions)\n",
" raw_rewards = [0.0] * n\n",
" task_labels = [\"\"] * n\n",
" \n",
" # ββ Step 1: Compute raw per-sample rewards ββββββββββββββββββββββββββββββ\n",
" for i, (completion, prompt) in enumerate(zip(completions, prompts)):\n",
" if isinstance(completion, list):\n",
" comp_text = completion[-1][\"content\"] if completion else \"\"\n",
" else:\n",
" comp_text = str(completion)\n",
"\n",
" if isinstance(prompt, list):\n",
" prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
" else:\n",
" prompt_text = str(prompt)\n",
"\n",
" task = _classify_task_type(prompt_text)\n",
" task_labels[i] = task\n",
"\n",
" if task == \"extraction\":\n",
" raw_rewards[i] = reward_extraction(comp_text, prompt_text)\n",
" elif task == \"sql_qa\":\n",
" raw_rewards[i] = reward_sql_qa(comp_text)\n",
" elif task == \"insights\":\n",
" raw_rewards[i] = reward_insights(comp_text)\n",
" elif task == \"push\":\n",
" raw_rewards[i] = reward_push(comp_text)\n",
" else:\n",
" raw_rewards[i] = 0.2 if comp_text.strip() else 0.0\n",
"\n",
" # ββ Step 2-4: GDPO per-component normalization + IWU weighting ββββββββββ\n",
" # Group indices by task\n",
" task_indices = {}\n",
" for i, task in enumerate(task_labels):\n",
" if task not in task_indices:\n",
" task_indices[task] = []\n",
" task_indices[task].append(i)\n",
" \n",
" final_rewards = [0.0] * n\n",
" \n",
" for task, indices in task_indices.items():\n",
" task_raw = [raw_rewards[i] for i in indices]\n",
" \n",
" # GDPO: z-score normalize within this task group\n",
" if len(task_raw) > 1:\n",
" t_mean = sum(task_raw) / len(task_raw)\n",
" t_var = sum((r - t_mean) ** 2 for r in task_raw) / (len(task_raw) - 1)\n",
" t_std = t_var ** 0.5\n",
" if t_std > 1e-8:\n",
" normed = [(r - t_mean) / t_std for r in task_raw]\n",
" else:\n",
" normed = [0.0] * len(task_raw)\n",
" else:\n",
" # Single sample in this task group β can't normalize, use raw\n",
" normed = [0.0]\n",
" \n",
" # IWU: scale by dynamic task weight\n",
" weight = _task_weights.get(task, 0.25)\n",
" weighted = [v * weight for v in normed]\n",
" \n",
" for idx_in_group, global_idx in enumerate(indices):\n",
" final_rewards[global_idx] = weighted[idx_in_group]\n",
" \n",
" # ββ Step 5: Shift to non-negative range βββββββββββββββββββββββββββββββββ\n",
" # GRPO with scale_rewards=False computes advantages as reward - mean(rewards).\n",
" # Normalized rewards are already zero-centered per-task, so the advantage\n",
" # computation will work correctly. But TRL may log negative rewards as warnings.\n",
" # Shift so minimum is 0 to keep logging clean, without changing advantage ordering.\n",
" min_r = min(final_rewards) if final_rewards else 0.0\n",
" if min_r < 0:\n",
" final_rewards = [r - min_r for r in final_rewards]\n",
" \n",
" return final_rewards\n",
"\n",
"\n",
"def commerce_reward_fn_raw(completions, prompts, **kwargs) -> list:\n",
" \"\"\"Raw reward function WITHOUT GDPO/IWU β used for eval metrics.\n",
" \n",
" Eval should report raw task-specific rewards for interpretability.\n",
" The GDPO+IWU normalization is only for shaping the training gradient signal.\n",
" \"\"\"\n",
" rewards = []\n",
" for completion, prompt in zip(completions, prompts):\n",
" if isinstance(completion, list):\n",
" comp_text = completion[-1][\"content\"] if completion else \"\"\n",
" else:\n",
" comp_text = str(completion)\n",
"\n",
" if isinstance(prompt, list):\n",
" prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
" else:\n",
" prompt_text = str(prompt)\n",
"\n",
" task = _classify_task_type(prompt_text)\n",
"\n",
" if task == \"extraction\":\n",
" rewards.append(reward_extraction(comp_text, prompt_text))\n",
" elif task == \"sql_qa\":\n",
" rewards.append(reward_sql_qa(comp_text))\n",
" elif task == \"insights\":\n",
" rewards.append(reward_insights(comp_text))\n",
" elif task == \"push\":\n",
" rewards.append(reward_push(comp_text))\n",
" else:\n",
" r = 0.2 if comp_text.strip() else 0.0\n",
" rewards.append(r)\n",
" return rewards\n",
"\n",
"\n",
"print(\"β Reward functions defined (V4.2: SQL v2 + GDPO active + IWU active)\")\n",
"print(f\" Task weights: {_task_weights}\")\n",
"print(f\" commerce_reward_fn: GDPO+IWU normalized (for training)\")\n",
"print(f\" commerce_reward_fn_raw: raw scores (for eval/audit)\")\n",
"print(f\" Task weights: {_task_weights}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 8: Reward Function Audit (Change 2)\n\n**V4.2 addition: 30-minute audit protocol.**\n\nGenerate 20 completions (5 per task) at temp=0.1, score them with the reward function,\nthen have the human score them 0-10. Compute Spearman Ο.\n\n**Gate:** Ο > 0.70. If below, reward function is miscalibrated β fix before training.\n\n**Why:** The V1-V4 parser bug would have been caught in 30 minutes with this protocol.\n\n### Instructions\n1. Run this cell β it generates 20 completions and scores them automatically\n2. For each completion: read the FULL output (no truncation), then enter your 0-10 score at the prompt\n3. After all 20 scores, the cell computes Spearman Ο automatically\n4. If Ο < 0.70, investigate discrepancies (marked β οΈ) before proceeding to training"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import spearmanr\n",
"\n",
"AUDIT_PROMPTS_PER_TASK = 5\n",
"\n",
"# ββ Collect audit prompts (5 per task) βββββββββββββββββββββββββββββββββββββββ\n",
"audit_by_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n",
"with open(TRAIN_FILE) as f:\n",
" for line in f:\n",
" row = json.loads(line)\n",
" convs = row[\"conversations\"]\n",
" prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
" if not prompt_msgs:\n",
" continue\n",
" user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n",
" task = _classify_task_type(user_text)\n",
" if len(audit_by_type[task]) < AUDIT_PROMPTS_PER_TASK:\n",
" audit_by_type[task].append(prompt_msgs)\n",
"\n",
"print(f\"Audit prompts collected: {', '.join(f'{k}={len(v)}' for k, v in audit_by_type.items())}\")\n",
"\n",
"# ββ Generate completions and score automatically βββββββββββββββββββββββββββββ\n",
"FastLanguageModel.for_inference(model)\n",
"\n",
"audit_auto_scores = []\n",
"audit_tasks = []\n",
"audit_completions = []\n",
"\n",
"audit_prompts_text = [] # store original user message for display\n",
"\n",
"for task_type in [\"extraction\", \"sql_qa\", \"insights\", \"push\"]:\n",
" for msgs in audit_by_type[task_type]:\n",
" # Extract original user message BEFORE injecting system prompt\n",
" user_content = \"\\n\".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n",
" audit_prompts_text.append(user_content)\n",
" \n",
" msgs = inject_task_system_prompt(msgs, task_type)\n",
" text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
" inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
" with torch.no_grad():\n",
" out = model.generate(\n",
" **inputs,\n",
" max_new_tokens=MAX_COMPLETION_LENGTH,\n",
" temperature=0.1, # near-deterministic for audit\n",
" do_sample=True,\n",
" repetition_penalty=1.0,\n",
" )\n",
" resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
" r = commerce_reward_fn_raw([resp], [text])[0] # Raw rewards for audit (not GDPO-normalized)\n",
" audit_auto_scores.append(r)\n",
" audit_tasks.append(task_type)\n",
" audit_completions.append(resp)\n",
"\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"# INTERACTIVE REWARD AUDIT\n",
"# Shows each completion in FULL (no truncation), prompts for a 0-10 score.\n",
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"\n",
"print(f\"\\n{'='*80}\")\n",
"print(\"REWARD FUNCTION AUDIT β 20 Completions (interactive scoring)\")\n",
"print(\"Score each completion 0-10: 0=garbage, 5=acceptable, 10=perfect\")\n",
"print(f\"{'='*80}\")\n",
"\n",
"audit_human_scores = []\n",
"\n",
"for i, (task, auto_r, comp, prompt_txt) in enumerate(zip(audit_tasks, audit_auto_scores, audit_completions, audit_prompts_text)):\n",
" answer = strip_think(comp) # full completion, no truncation\n",
" print(f\"\\n{'β'*80}\")\n",
" print(f\" Sample {i+1}/{len(audit_auto_scores)} [{task}] auto_reward={auto_r:.3f}\")\n",
" print(f\"{'β'*80}\")\n",
" print(f\"\\nINPUT REVIEW:\\n{prompt_txt}\\n\")\n",
" print(f\"MODEL OUTPUT:\\n{answer}\")\n",
" print()\n",
" while True:\n",
" try:\n",
" score = float(input(f\" Your score (0-10): \"))\n",
" if 0 <= score <= 10:\n",
" break\n",
" print(\" β οΈ Score must be between 0 and 10\")\n",
" except (ValueError, EOFError):\n",
" print(\" β οΈ Enter a number between 0 and 10\")\n",
" audit_human_scores.append(score)\n",
" print(f\" β Recorded: human={score:.0f}, auto={auto_r:.3f}\")\n",
"\n",
"# ββ Compute Spearman Ο βββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
"human_normalized = [s / 10.0 for s in audit_human_scores]\n",
"rho, p_value = spearmanr(human_normalized, audit_auto_scores)\n",
"\n",
"print(f\"\\n{'='*80}\")\n",
"print(f\"AUDIT RESULTS\")\n",
"print(f\"{'='*80}\")\n",
"print(f\" Spearman Ο = {rho:.3f} (p = {p_value:.4f})\")\n",
"print()\n",
"print(f\" {'#':>3s} {'Task':12s} {'Human':>6s} {'Auto':>6s} {'Ξ':>6s}\")\n",
"print(f\" {'β'*40}\")\n",
"for i, (task, h, a) in enumerate(zip(audit_tasks, human_normalized, audit_auto_scores)):\n",
" delta = abs(h - a)\n",
" flag = \" β οΈ\" if delta > 0.3 else \"\"\n",
" print(f\" {i+1:3d} {task:12s} {h:6.2f} {a:6.3f} {delta:6.3f}{flag}\")\n",
"\n",
"if rho > 0.70:\n",
" print(f\"\\n β
PASS: Ο={rho:.3f} > 0.70 β reward function is calibrated\")\n",
"else:\n",
" print(f\"\\n β FAIL: Ο={rho:.3f} < 0.70 β reward function is miscalibrated\")\n",
" print(\" β Investigate samples marked β οΈ before training. Check:\")\n",
" print(\" 1. Is the JSON parser handling all output formats?\")\n",
" print(\" 2. Are SQL reward tiers appropriate for this model's output style?\")\n",
" print(\" 3. Are insights/push length penalties calibrated?\")\n",
"\n",
"assert rho > 0.65, f\"Reward function miscalibrated (Ο={rho:.3f} < 0.65). Fix before training.\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 9: Build Stratified Eval Set (Change 1)\n\n**V4.2: 65 stratified samples** (20 extraction + 15 sql_qa + 15 insights + 15 push).\n\nSampled from `data/pairs/eval.jsonl`, saved as `data/pairs/eval_v2_stratified.jsonl`.\n**This file is fixed across all seeds.** Never resample.\n\n**Gate:** Exactly 65 samples with correct per-task counts."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "eval_v2_stratified_path = DATA_DIR / \"pairs\" / \"eval_v2_stratified.jsonl\"\n\n# ββ Check if already built (idempotent across seeds) ββββββββββββββββββββββββ\nif eval_v2_stratified_path.exists():\n existing = []\n with open(eval_v2_stratified_path) as f:\n for line in f:\n existing.append(json.loads(line))\n # Verify counts\n task_counts = {}\n for rec in existing:\n task_counts[rec[\"task_type\"]] = task_counts.get(rec[\"task_type\"], 0) + 1\n print(f\"β Stratified eval set already exists: {eval_v2_stratified_path}\")\n print(f\" Counts: {task_counts}\")\n print(f\" Total: {len(existing)}\")\n assert len(existing) == EVAL_TOTAL, f\"Expected {EVAL_TOTAL}, got {len(existing)}\"\nelse:\n # ββ Build from eval.jsonl (or train.jsonl fallback) ββββββββββββββββββββββ\n eval_source = EVAL_FILE if EVAL_FILE.exists() else TRAIN_FILE\n print(f\"Building stratified eval set from: {eval_source}\")\n \n # Collect all records by task\n eval_by_task = {t: [] for t in EVAL_SAMPLES_PER_TASK}\n with open(eval_source) as f:\n for line in f:\n row = json.loads(line)\n convs = row[\"conversations\"]\n prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n if not prompt_msgs:\n continue\n user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_text)\n if task in eval_by_task:\n eval_by_task[task].append({\n \"conversations\": convs,\n \"prompt_msgs\": prompt_msgs,\n \"task_type\": task,\n })\n \n print(f\" Available: {', '.join(f'{k}={len(v)}' for k, v in eval_by_task.items())}\")\n \n # Stratified sampling with FIXED seed (42 always, regardless of CURRENT_SEED)\n eval_rng = random.Random(42)\n stratified = []\n for task, target_n in EVAL_SAMPLES_PER_TASK.items():\n pool = eval_by_task[task]\n if len(pool) < target_n:\n print(f\" β οΈ {task}: only {len(pool)} available, wanted {target_n}. 
Using all.\")\n sampled = pool\n else:\n sampled = eval_rng.sample(pool, target_n)\n stratified.extend(sampled)\n \n # Save as JSONL\n eval_v2_stratified_path.parent.mkdir(parents=True, exist_ok=True)\n with open(eval_v2_stratified_path, \"w\") as f:\n for rec in stratified:\n f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n \n # Verify\n task_counts = {}\n for rec in stratified:\n task_counts[rec[\"task_type\"]] = task_counts.get(rec[\"task_type\"], 0) + 1\n \n print(f\"\\nβ Stratified eval set saved: {eval_v2_stratified_path}\")\n print(f\" Counts: {task_counts}\")\n print(f\" Total: {len(stratified)}\")\n\nprint(f\"\\nExpected: {EVAL_SAMPLES_PER_TASK} = {EVAL_TOTAL} total\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 10: Dataset Preparation + DynamicTaskSampler Init\n\n**V4.2 changes:**\n- Eval loaded from `eval_v2_stratified.jsonl` (65 fixed samples) instead of random split\n- Train still from `train.jsonl` with task-aware system prompt injection\n- DynamicTaskSampler initialized for IWU (Change 6)\n\n**Gate:** Train has ~1,480+ prompts. Eval has exactly 65 stratified samples. All 4 task types present in both."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "from datasets import Dataset\n\ndef prepare_datasets_v42(train_file, eval_stratified_file, seed=CURRENT_SEED):\n \"\"\"V4.2: Load train from JSONL + eval from stratified file.\"\"\"\n rng = random.Random(seed)\n\n # ββ Train set ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n train_records = []\n with open(train_file) as f:\n for line in f:\n row = json.loads(line)\n convs = row[\"conversations\"]\n prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n if prompt_msgs:\n train_records.append(prompt_msgs)\n rng.shuffle(train_records)\n \n # Inject task-aware system prompts\n for i, msgs in enumerate(train_records):\n user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_text)\n train_records[i] = inject_task_system_prompt(msgs, task)\n \n # ββ Eval set (V4.2: from stratified file, fixed across seeds) ββββββββββββ\n eval_records = []\n with open(eval_stratified_file) as f:\n for line in f:\n rec = json.loads(line)\n prompt_msgs = rec[\"prompt_msgs\"]\n user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_text)\n eval_records.append(inject_task_system_prompt(prompt_msgs, task))\n \n print(\" β Task-aware system prompts injected\")\n \n # Print distributions\n for label, records in [(\"train\", train_records), (\"eval\", eval_records)]:\n dist = {}\n for msgs in records:\n user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_text)\n dist[task] = dist.get(task, 0) + 1\n print(f\" {label}: {len(records)} prompts β {dist}\")\n \n train_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in train_records])\n eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n return train_ds, eval_ds\n\ntrain_dataset, eval_dataset = prepare_datasets_v42(TRAIN_FILE, eval_v2_stratified_path)\nprint(f\"\\nβ Datasets: 
train={len(train_dataset)}, eval={len(eval_dataset)}\")\nassert len(eval_dataset) == EVAL_TOTAL, f\"Expected {EVAL_TOTAL} eval samples, got {len(eval_dataset)}\"\nprint(f\"β Eval is stratified: {EVAL_TOTAL} samples (fixed across seeds)\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 11: Smoke Test (1 Step)\n\n**Gate:** No OOM. Peak VRAM < 20GB. Step time < 180s."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "from trl import GRPOConfig, GRPOTrainer\n\nFastLanguageModel.for_training(model)\n\nsmoke_config = GRPOConfig(\n output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n num_generations=NUM_GENERATIONS,\n scale_rewards=SCALE_REWARDS,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=1,\n temperature=TEMPERATURE,\n beta=BETA,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=1,\n learning_rate=LEARNING_RATE,\n lr_scheduler_type=LR_SCHEDULER_TYPE,\n warmup_ratio=WARMUP_RATIO,\n fp16=False,\n bf16=True,\n logging_steps=1,\n save_steps=999,\n report_to=\"none\",\n max_prompt_length=MAX_SEQ_LENGTH // 2,\n seed=CURRENT_SEED,\n remove_unused_columns=False,\n)\n\n\nclass UnslothGRPOTrainer(GRPOTrainer):\n def _generate(self, prompts, images):\n FastLanguageModel.for_inference(self.model)\n try:\n result = super()._generate(prompts, images)\n finally:\n FastLanguageModel.for_training(self.model)\n return result\n\n\nsmoke_trainer = UnslothGRPOTrainer(\n model=model,\n reward_funcs=commerce_reward_fn,\n args=smoke_config,\n train_dataset=train_dataset,\n processing_class=tokenizer,\n)\n\nt0 = time.time()\nsmoke_trainer.train()\nstep_time = time.time() - t0\n\npeak_vram = torch.cuda.max_memory_allocated() / 1e9\nprint(f\"\\nβ Smoke test passed!\")\nprint(f\" Step time: {step_time:.0f}s\")\nprint(f\" Peak VRAM: {peak_vram:.1f}GB / {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f}GB\")\nprint(f\" Estimated full run ({MAX_STEPS} steps): {step_time * MAX_STEPS / 3600:.1f}h\")\n\ndel smoke_trainer\ngc.collect(); torch.cuda.empty_cache()"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 12: Probe Run (10 Steps)\n\n**V4.2:** Uses `CURRENT_SEED` for reproducibility. No hard clip_ratio gate (expected=0 for LoRA)."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"FastLanguageModel.for_training(model)\n",
"\n",
"probe_config = GRPOConfig(\n",
" output_dir=str(CHECKPOINT_DIR / \"probe\"),\n",
" num_generations=NUM_GENERATIONS,\n",
" scale_rewards=SCALE_REWARDS,\n",
" max_completion_length=MAX_COMPLETION_LENGTH,\n",
" max_steps=10,\n",
" temperature=TEMPERATURE,\n",
" beta=BETA,\n",
" num_train_epochs=1,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" gradient_accumulation_steps=GRAD_ACCUM,\n",
" learning_rate=LEARNING_RATE,\n",
" lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
" warmup_ratio=WARMUP_RATIO,\n",
" fp16=False,\n",
" bf16=True,\n",
" logging_steps=1,\n",
" save_steps=999,\n",
" report_to=\"none\",\n",
" max_prompt_length=MAX_SEQ_LENGTH // 2,\n",
" seed=CURRENT_SEED,\n",
" remove_unused_columns=False,\n",
")\n",
"\n",
"probe_trainer = UnslothGRPOTrainer(\n",
" model=model,\n",
" reward_funcs=commerce_reward_fn,\n",
" args=probe_config,\n",
" train_dataset=train_dataset,\n",
" processing_class=tokenizer,\n",
")\n",
"\n",
"t0 = time.time()\n",
"result = probe_trainer.train()\n",
"elapsed = time.time() - t0\n",
"\n",
"# ββ Extract metrics from log history βββββββββββββββββββββββββββββββββββββββββ\n",
"# V4.2.1: TRL 0.24.0 logs under \"reward\" / \"rewards/commerce_reward_fn/mean\"\n",
"# and \"grad_norm\" (no \"train/\" prefix in log_history entries).\n",
"rewards = []\n",
"reward_stds = []\n",
"grad_norms = []\n",
"for entry in probe_trainer.state.log_history:\n",
" if \"rewards/commerce_reward_fn/mean\" in entry:\n",
" rewards.append(entry[\"rewards/commerce_reward_fn/mean\"])\n",
" if \"rewards/commerce_reward_fn/std\" in entry:\n",
" reward_stds.append(entry[\"rewards/commerce_reward_fn/std\"])\n",
" if \"grad_norm\" in entry:\n",
" grad_norms.append(entry[\"grad_norm\"])\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(f\"PROBE RESULTS ({elapsed:.0f}s, {elapsed/10:.0f}s/step)\")\n",
"print(f\" Rewards: {[f'{r:.3f}' for r in rewards]}\")\n",
"print(f\" Reward stds: {[f'{s:.3f}' for s in reward_stds]}\")\n",
"print(f\" Grad norms: {[f'{g:.4f}' for g in grad_norms]}\")\n",
"print(f\" Train loss: {result.training_loss:.4f}\")\n",
"print(f\"{'='*60}\")\n",
"\n",
"if rewards and max(rewards) > 0:\n",
" print(\"β Model generates scoreable output\")\n",
"else:\n",
" print(\"β WARNING: All rewards are 0. Check reward functions.\")\n",
"\n",
"if grad_norms and max(grad_norms) > 0:\n",
" print(\"β Gradients are flowing\")\n",
"else:\n",
" print(\"β WARNING: All grad_norms are 0. Check model/LoRA setup.\")\n",
"\n",
"if reward_stds and min(reward_stds) > 0:\n",
" print(\"β Batches have reward variance (GRPO has signal)\")\n",
"elif reward_stds:\n",
" n_zero = sum(1 for s in reward_stds if s < 1e-6)\n",
" print(f\"β οΈ WARNING: {n_zero}/{len(reward_stds)} steps had zero reward std. Consider increasing G.\")\n",
"else:\n",
" print(\"β οΈ WARNING: No reward_std logged. Check TRL version.\")\n",
"\n",
"print(\"\\nβ Proceed to full training (Cell 13)\")\n",
"\n",
"del probe_trainer\n",
"gc.collect(); torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 13: W&B Init + Full Training (1,500 Steps)\n\n**V4.2 changes:**\n- **MAX_STEPS=1,500** (multi-epoch, ~2.5Γ full dataset) (Change 4)\n- **EvalRewardCallback v2:** 65 stratified samples, per-task 95% CIs, GDPO normalization logging, dynamic task weight updates, **best checkpoint saving** (Changes 1, 5, 6, 8)\n- **`SAVE_STEPS=100`, `EVAL_STEPS=50`** (scaled for longer run)\n- **Seed in W&B config** for multi-seed tracking (Change 7)\n- **Best checkpoint saved explicitly** when eval improves (Change 8)"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import wandb\nfrom transformers import TrainerCallback\n\nwandb.login()\nwandb.init(\n project=WANDB_PROJECT,\n name=f\"grpo-v4.2-instruct-0.5B-seed{CURRENT_SEED}-{time.strftime('%Y%m%d-%H%M')}\",\n config={\n \"model_id\": MODEL_ID,\n \"version\": \"v4.2\",\n \"seed\": CURRENT_SEED,\n \"seeds_planned\": SEEDS,\n \"num_generations\": NUM_GENERATIONS,\n \"max_completion_length\": MAX_COMPLETION_LENGTH,\n \"temperature\": TEMPERATURE,\n \"learning_rate\": LEARNING_RATE,\n \"lr_scheduler_type\": LR_SCHEDULER_TYPE,\n \"warmup_ratio\": WARMUP_RATIO,\n \"beta\": BETA,\n \"scale_rewards\": SCALE_REWARDS,\n \"batch_size\": BATCH_SIZE,\n \"grad_accum\": GRAD_ACCUM,\n \"max_steps\": MAX_STEPS,\n \"lora_r\": LORA_R,\n \"lora_alpha\": LORA_ALPHA,\n \"train_prompts\": len(train_dataset),\n \"eval_prompts\": len(eval_dataset),\n \"eval_stratified\": True,\n \"eval_per_task\": EVAL_SAMPLES_PER_TASK,\n \"repetition_penalty_override\": 1.0,\n \"json_parser\": \"json-repair + PT-BR decimal normalizer\",\n \"sql_reward\": \"v2 (validation-aware, 4-tier)\",\n \"gdpo_normalization\": True,\n \"dynamic_task_weighting\": \"MT-GRPO IWU\",\n \"changes_from_v41\": \"stratified eval 65, reward audit, SQL v2, 1500 steps, GDPO, IWU, 3 seeds, best ckpt\",\n },\n)\nprint(f\"β W&B run: {wandb.run.url}\")\n\n\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n# V4.2: EvalRewardCallback v2\n# - Uses 65 stratified eval samples (Change 1)\n# - Reports per-task means with 95% CIs (Change 1)\n# - Runs GDPO normalization and logs component stats (Change 5)\n# - Updates dynamic task weights via IWU (Change 6)\n# - Saves best checkpoint explicitly (Change 8)\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n\nclass EvalRewardCallbackV2(TrainerCallback):\n def __init__(self, eval_records, reward_fn, patience, delta):\n self.eval_records = eval_records\n self.reward_fn = reward_fn\n self.patience = patience\n self.delta = delta\n 
self.best_reward = -float(\"inf\")\n self.best_step = 0\n self.no_improve_count = 0\n\n def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n return control\n\n tokenizer_local = processing_class\n if tokenizer_local is None:\n print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n return control\n\n mean_reward, per_task, per_task_all = self._run_eval(model, tokenizer_local, args)\n improved = mean_reward > self.best_reward + self.delta\n\n # ββ Per-task 95% CIs (Change 1) ββββββββββββββββββββββββββββββββββββββ\n log_data = {\n \"eval/mean_reward\": mean_reward,\n \"eval/best_reward\": max(self.best_reward, mean_reward),\n \"eval/no_improve_count\": self.no_improve_count,\n }\n \n ci_strs = []\n for task_name, task_rewards in per_task_all.items():\n if task_rewards:\n n = len(task_rewards)\n task_mean = sum(task_rewards) / n\n if n > 1:\n task_std = (sum((r - task_mean)**2 for r in task_rewards) / (n - 1)) ** 0.5\n ci_half = 1.96 * task_std / math.sqrt(n)\n else:\n ci_half = 0.0\n log_data[f\"eval/{task_name}\"] = task_mean\n log_data[f\"eval/{task_name}_ci\"] = ci_half\n log_data[f\"eval/{task_name}_n\"] = n\n ci_strs.append(f\"{task_name}={task_mean:.3f}Β±{ci_half:.3f} (n={n})\")\n \n # ββ GDPO per-component stats (Change 5) βββββββββββββββββββββββββββββ\n if per_task_all and all(len(v) > 0 for v in per_task_all.values()):\n try:\n gdpo_rewards = gdpo_normalize(per_task_all)\n log_data[\"eval/gdpo_mean\"] = sum(gdpo_rewards) / len(gdpo_rewards)\n log_data[\"eval/gdpo_std\"] = (sum((r - sum(gdpo_rewards)/len(gdpo_rewards))**2 for r in gdpo_rewards) / len(gdpo_rewards)) ** 0.5\n except Exception as e:\n print(f\" [GDPO] normalization error: {e}\")\n \n # ββ Dynamic task weight update (Change 6) βββββββββββββββββββββββββββ\n per_task_means = {}\n for task_name, task_rewards in per_task_all.items():\n if task_rewards:\n 
per_task_means[task_name] = sum(task_rewards) / len(task_rewards)\n \n update_task_weights(state.global_step, per_task_means, update_interval=EVAL_STEPS)\n \n for task_name, weight in _task_weights.items():\n log_data[f\"sampler/{task_name}_weight\"] = weight\n \n wandb.log(log_data, step=state.global_step)\n\n status = \"β improved\" if improved else f\"β no gain ({self.no_improve_count + 1}/{self.patience})\"\n print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n for cs in ci_strs:\n print(f\" {cs}\")\n print(f\" Task weights: {', '.join(f'{t}={w:.3f}' for t, w in _task_weights.items())}\")\n\n if improved:\n self.best_reward = mean_reward\n self.best_step = state.global_step\n self.no_improve_count = 0\n # ββ V4.2: Save best checkpoint explicitly (Change 8) βββββββββββββ\n best_path = ADAPTER_DIR / \"best_checkpoint\"\n best_path.mkdir(parents=True, exist_ok=True)\n model.save_pretrained(str(best_path))\n tokenizer_local.save_pretrained(str(best_path))\n print(f\" β Best checkpoint saved β {best_path} (reward={mean_reward:.4f})\")\n else:\n self.no_improve_count += 1\n if self.no_improve_count >= self.patience:\n print(f\"[EarlyStopping] No improvement for {self.patience} evals. 
Halting.\")\n control.should_training_stop = True\n return control\n\n def _run_eval(self, model, tokenizer_local, args):\n FastLanguageModel.for_inference(model)\n rewards = []\n per_task_summary = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n per_task_all = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n \n # V4.2: Use ALL stratified eval samples (65), not just 15\n for record in self.eval_records:\n msgs = record[\"prompt\"]\n text = tokenizer_local.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n user_txt = \" \".join(m.get(\"content\", \"\") for m in msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_txt)\n\n inputs = tokenizer_local(text, return_tensors=\"pt\", truncation=True, max_length=args.max_prompt_length).to(model.device)\n with torch.no_grad():\n out = model.generate(\n **inputs,\n max_new_tokens=EVAL_MAX_TOKENS,\n temperature=0.1,\n do_sample=True,\n repetition_penalty=1.0,\n )\n resp = tokenizer_local.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n r = self.reward_fn([resp], [text])[0]\n rewards.append(r)\n if task in per_task_all:\n per_task_all[task].append(r)\n per_task_summary[task].append(r)\n \n FastLanguageModel.for_training(model)\n mean_r = sum(rewards) / len(rewards) if rewards else 0.0\n return mean_r, per_task_summary, per_task_all\n\n\n# ββ Training ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nFastLanguageModel.for_training(model)\n\ngrpo_config = GRPOConfig(\n output_dir=str(CHECKPOINT_DIR),\n num_generations=NUM_GENERATIONS,\n scale_rewards=SCALE_REWARDS,\n max_completion_length=MAX_COMPLETION_LENGTH,\n max_steps=MAX_STEPS, # V4.2: 1,500\n temperature=TEMPERATURE,\n beta=BETA,\n num_train_epochs=1,\n per_device_train_batch_size=BATCH_SIZE,\n gradient_accumulation_steps=GRAD_ACCUM,\n learning_rate=LEARNING_RATE,\n lr_scheduler_type=LR_SCHEDULER_TYPE,\n warmup_ratio=WARMUP_RATIO,\n fp16=False,\n 
bf16=True,\n logging_steps=1,\n save_steps=SAVE_STEPS, # V4.2: 100\n save_total_limit=5,\n save_only_model=True,\n report_to=\"wandb\",\n max_prompt_length=MAX_SEQ_LENGTH // 2,\n seed=CURRENT_SEED, # V4.2: per-seed\n remove_unused_columns=False,\n disable_tqdm=True,\n logging_first_step=True,\n)\n\neval_cb = EvalRewardCallbackV2(\n eval_records=list(eval_dataset),\n reward_fn=commerce_reward_fn_raw, # V4.2: raw rewards for eval (no GDPO/IWU distortion)\n patience=EARLY_STOPPING_PATIENCE,\n delta=EARLY_STOPPING_DELTA,\n)\n\ntrainer = UnslothGRPOTrainer(\n model=model,\n reward_funcs=commerce_reward_fn,\n args=grpo_config,\n train_dataset=train_dataset,\n processing_class=tokenizer,\n callbacks=[eval_cb],\n)\n\nt_start = time.time()\nresult = trainer.train()\nelapsed = time.time() - t_start\n\nwandb.log({\n \"train/final_loss\": result.training_loss,\n \"train/duration_hours\": elapsed / 3600,\n \"train/total_steps\": result.global_step,\n \"eval/best_reward_final\": eval_cb.best_reward,\n \"eval/best_step\": eval_cb.best_step,\n \"final/task_weights\": _task_weights,\n})\nwandb.finish()\n\nprint(f\"\\n{'='*60}\")\nprint(f\"V4.2 Training Complete (seed={CURRENT_SEED})\")\nprint(f\" Loss: {result.training_loss:.4f}\")\nprint(f\" Steps: {result.global_step}\")\nprint(f\" Duration: {elapsed/3600:.1f}h\")\nprint(f\" Best eval: {eval_cb.best_reward:.4f} (step {eval_cb.best_step})\")\nprint(f\" Final task weights: {_task_weights}\")\nprint(f\"{'='*60}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 14: Post-Training Validation (65 Stratified Samples)\n\n**V4.2:** Full stratified eval with per-task 95% CIs.\n\nReports `mean Β± 1.96 Γ std/βn` for each task.\n\n**The four questions V4.2 must answer:**\n1. Does SQL reward improve with the new reward function?\n2. Is the insights regression noise or forgetting?\n3. Does multi-epoch training push eval above 0.70?\n4. Are results reproducible across seeds?"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "FastLanguageModel.for_inference(model)\n\nval_samples = list(eval_dataset) # All 65 stratified samples\nval_results = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n\nprint(f\"Post-training validation on {len(val_samples)} stratified samples (seed={CURRENT_SEED})\")\nprint(\"-\" * 80)\n\nfor i, record in enumerate(val_samples):\n msgs = record[\"prompt\"]\n user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n task = _classify_task_type(user_text)\n\n text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n with torch.no_grad():\n out = model.generate(\n **inputs,\n max_new_tokens=MAX_COMPLETION_LENGTH,\n temperature=0.1,\n do_sample=True,\n repetition_penalty=1.0,\n )\n resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n r = commerce_reward_fn_raw([resp], [text])[0] # Raw rewards for reporting\n val_results[task].append(r)\n if i < 10 or r < 0.2: # Print first 10 and all low-scoring\n print(f\" [{task:12s}] reward={r:.3f} | {strip_think(resp)[:80]}\")\n\n# ββ Results with 95% CIs ββββββββββββββββββββββββββββββββββββββββββββββββββββ\nprint(f\"\\n{'='*80}\")\nprint(f\"VALIDATION RESULTS β V4.2 Seed {CURRENT_SEED}\")\nprint(f\"{'='*80}\")\nprint(f\"{'Task':15s} {'Mean':>8s} {'Β± 95% CI':>10s} {'Min':>6s} {'Max':>6s} {'N':>4s}\")\nprint(\"-\" * 55)\n\noverall = []\nresults_by_seed = {} # Store for cross-seed comparison\n\nfor task in [\"extraction\", \"sql_qa\", \"insights\", \"push\"]:\n rewards = val_results[task]\n overall.extend(rewards)\n if rewards:\n n = len(rewards)\n mean_r = sum(rewards) / n\n if n > 1:\n std_r = (sum((r - mean_r)**2 for r in rewards) / (n - 1)) ** 0.5\n ci_half = 1.96 * std_r / math.sqrt(n)\n else:\n std_r = 0.0\n ci_half = 0.0\n print(f\"{task:15s} {mean_r:8.3f} {'Β±':>2s}{ci_half:7.3f} {min(rewards):6.3f} {max(rewards):6.3f} {n:4d}\")\n 
results_by_seed[task] = {\"mean\": mean_r, \"ci\": ci_half, \"n\": n, \"std\": std_r}\n\nif overall:\n n_total = len(overall)\n mean_total = sum(overall) / n_total\n std_total = (sum((r - mean_total)**2 for r in overall) / (n_total - 1)) ** 0.5\n ci_total = 1.96 * std_total / math.sqrt(n_total)\n print(\"-\" * 55)\n print(f\"{'OVERALL':15s} {mean_total:8.3f} {'Β±':>2s}{ci_total:7.3f} {min(overall):6.3f} {max(overall):6.3f} {n_total:4d}\")\n results_by_seed[\"overall\"] = {\"mean\": mean_total, \"ci\": ci_total, \"n\": n_total, \"std\": std_total}\n\n# ββ Save results for cross-seed comparison ββββββββββββββββββββββββββββββββββ\nresults_file = ADAPTER_DIR / f\"eval_results_seed{CURRENT_SEED}.json\"\nresults_file.parent.mkdir(parents=True, exist_ok=True)\nwith open(results_file, \"w\") as f:\n json.dump(results_by_seed, f, indent=2)\nprint(f\"\\nβ Results saved to {results_file}\")\n\n# ββ V4.2 Decision βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\nprint(f\"\\n--- V4.2 Questions ---\")\nsql_mean = results_by_seed.get(\"sql_qa\", {}).get(\"mean\", 0)\ninsights_mean = results_by_seed.get(\"insights\", {}).get(\"mean\", 0)\noverall_mean = results_by_seed.get(\"overall\", {}).get(\"mean\", 0)\n\nprint(f\"Q1 SQL reward: {sql_mean:.3f} ({'improved' if sql_mean > 0.60 else 'still stagnant' if sql_mean < 0.56 else 'modest gain'})\")\nprint(f\"Q2 Insights: {insights_mean:.3f} ({'stable' if insights_mean > 0.70 else 'regressed' if insights_mean < 0.60 else 'mixed'})\")\nprint(f\"Q3 Overall: {overall_mean:.3f} ({'above 0.70 target' if overall_mean > 0.70 else 'below target'})\")\nprint(f\"Q4 Seeds: Seed {CURRENT_SEED} done. Run seeds {[s for s in SEEDS if s != CURRENT_SEED]} next.\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 15: Save Adapter\n\n**V4.2:** Saves from `best_checkpoint/` (peak eval reward) instead of last training step."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# V4.2: Save the best checkpoint, not the last training step\nbest_checkpoint_path = ADAPTER_DIR / \"best_checkpoint\"\n\nif best_checkpoint_path.exists():\n print(f\"β Best checkpoint exists at {best_checkpoint_path}\")\n print(f\" Best eval reward: {eval_cb.best_reward:.4f} (step {eval_cb.best_step})\")\n # Copy best checkpoint to main adapter dir for easy loading\n import shutil\n final_path = ADAPTER_DIR / \"final\"\n if final_path.exists():\n shutil.rmtree(final_path)\n shutil.copytree(str(best_checkpoint_path), str(final_path))\n print(f\" β Copied to {final_path}\")\nelse:\n print(\"β οΈ No best_checkpoint found. Saving current model state.\")\n ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\n model.save_pretrained(str(ADAPTER_DIR / \"final\"))\n tokenizer.save_pretrained(str(ADAPTER_DIR / \"final\"))\n\nprint(f\"\\nβ Adapter saved for seed {CURRENT_SEED}\")\nprint(f\" Location: {ADAPTER_DIR / 'final'}\")\nprint(f\" Best eval: {eval_cb.best_reward:.4f} at step {eval_cb.best_step}\")"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "---\n\n## Cell 16: Results Table Generation (Change 7)\n\n**Run this after ALL 3 seeds are complete.**\n\nReads `eval_results_seedN.json` from each seed directory and produces the cross-seed\nresults table with mean Β± 95% CI.\n\n```\n| Task | Seed 42 | Seed 123 | Seed 456 | Mean Β± 95% CI |\n|---|---|---|---|---|\n| Extraction | ... | ... | ... | X.XX Β± 0.0X |\n| SQL Q&A | ... | ... | ... | X.XX Β± 0.0X |\n| Insights | ... | ... | ... | X.XX Β± 0.0X |\n| Push | ... | ... | ... | X.XX Β± 0.0X |\n| **Mean** | ... | ... | ... | **X.XX Β± 0.0X** |\n```"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n# V4.2: Cross-Seed Results Table (Change 7)\n# Run this AFTER completing all 3 seeds (42, 123, 456)\n# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n\nresults_by_seed = {}\nmissing_seeds = []\n\nfor seed in SEEDS:\n seed_dir = MODELS_DIR / f\"tucano2-0.5B-instruct-grpo-v4.2-seed{seed}\"\n results_file = seed_dir / f\"eval_results_seed{seed}.json\"\n if results_file.exists():\n with open(results_file) as f:\n results_by_seed[seed] = json.load(f)\n print(f\"β Loaded results for seed {seed}\")\n else:\n missing_seeds.append(seed)\n print(f\"β οΈ Missing results for seed {seed} (run the notebook with CURRENT_SEED={seed})\")\n\nif missing_seeds:\n print(f\"\\nβ οΈ Missing seeds: {missing_seeds}\")\n print(f\" Complete these runs before generating the final table.\")\n print(f\" Change CURRENT_SEED in Cell 3 and re-run Cells 3-15.\")\n\nif len(results_by_seed) >= 2:\n print(f\"\\n{'='*90}\")\n print(f\"V4.2 CROSS-SEED RESULTS TABLE\")\n print(f\"{'='*90}\")\n \n tasks = [\"extraction\", \"sql_qa\", \"insights\", \"push\", \"overall\"]\n \n # Header\n header = f\"{'Task':15s}\"\n for seed in SEEDS:\n if seed in results_by_seed:\n header += f\" {'Seed '+str(seed):>10s}\"\n header += f\" {'Mean Β± 95% CI':>18s}\"\n print(header)\n print(\"-\" * len(header))\n \n for task in tasks:\n row = f\"{task:15s}\"\n seed_means = []\n for seed in SEEDS:\n if seed in results_by_seed and task in results_by_seed[seed]:\n m = results_by_seed[seed][task][\"mean\"]\n seed_means.append(m)\n row += f\" {m:10.3f}\"\n elif seed in results_by_seed:\n row += f\" {'β':>10s}\"\n \n if len(seed_means) >= 2:\n cross_mean = sum(seed_means) / len(seed_means)\n cross_std = (sum((m - cross_mean)**2 for m in seed_means) / (len(seed_means) - 1)) ** 0.5\n # With 3 seeds, use t-distribution critical value (t=4.303 for 95% CI, df=2)\n # But for consistency with the handoff, use 
Β±std\n row += f\" {cross_mean:7.3f} Β± {cross_std:.3f}\"\n elif len(seed_means) == 1:\n row += f\" {seed_means[0]:7.3f} (1 seed)\"\n \n if task == \"overall\":\n row = f\"**{row.strip()}**\"\n print(row)\n \n print(f\"\\n{'='*90}\")\n \n # ββ Reproducibility assessment ββββββββββββββββββββββββββββββββββββββββββ\n if len(results_by_seed) == 3:\n overall_means = [results_by_seed[s][\"overall\"][\"mean\"] for s in SEEDS if s in results_by_seed]\n overall_std = (sum((m - sum(overall_means)/len(overall_means))**2 for m in overall_means) / (len(overall_means) - 1)) ** 0.5\n print(f\"\\nReproducibility: overall std = {overall_std:.4f}\")\n if overall_std < 0.03:\n print(f\" β
Robust (std < 0.03): results are reproducible across seeds\")\n elif overall_std < 0.05:\n print(f\" β οΈ Moderate (0.03 < std < 0.05): some initialization sensitivity\")\n else:\n print(f\" β High variance (std > 0.05): significant initialization sensitivity\")\n\nelse:\n print(f\"\\nNeed at least 2 seeds to generate comparison table.\")\n print(f\"Current seed ({CURRENT_SEED}) results:\")\n if CURRENT_SEED in results_by_seed:\n for task, data in results_by_seed[CURRENT_SEED].items():\n print(f\" {task}: {data['mean']:.3f} Β± {data.get('ci', 0):.3f}\")"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}