File size: 90,795 Bytes
c5f1d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41eb15f
63b1c86
41eb15f
c5f1d2d
 
 
 
 
 
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
63b1c86
 
 
 
 
41eb15f
63b1c86
 
 
 
 
41eb15f
63b1c86
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1be31c
 
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1be31c
 
 
 
 
 
 
 
 
 
 
 
 
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1be31c
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1be31c
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1be31c
41eb15f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f1d2d
 
 
 
0fc9042
c5f1d2d
 
 
 
 
 
b1be31c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f1d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
080fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f1d2d
 
 
 
 
 
 
 
 
 
 
c95e44c
c5f1d2d
 
 
 
 
 
 
 
 
 
 
c95e44c
c5f1d2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "# Tucano2 Commerce — GRPO Training V4.2 (Gold Standard, 0.5B)\n\n**Reference:** `docs/v4_2-handoff.md`  \n**Base:** V4.1 notebook with 8 targeted changes\n\n## V4.2 Changes from V4.1\n\n| # | Change | V4.1 | V4.2 | Why |\n|---|--------|------|------|-----|\n| 1 | Eval suite | 15 mixed samples | **65 stratified** (20 ext + 15 sql + 15 ins + 15 push) | Insights swing ±0.22 was eval noise on n≈2 |\n| 2 | Reward audit | None | **Spearman ρ > 0.70 gate** (20 completions, human-scored) | Parser bug persisted 3 versions |\n| 3 | SQL reward | Heuristic vocabulary | **Validation-aware** (SQL syntax + query/explanation + numerics + domain) | SQL stagnant +3.8% — reward was the ceiling |\n| 4 | Max steps | 600 | **1,500** (~2.5 epochs) | Only 40% data seen; eval still improving at step 500 |\n| 5 | GDPO normalization | Batch-level reward | **Per-component normalization** before aggregation | GDPO §3.1: preserves 4× more advantage groups |\n| 6 | Task weighting | Equal (0.25 each) | **Dynamic IWU** (upweight stagnating tasks) | MT-GRPO §3.2: prevents easy-task collapse |\n| 7 | Seeds | Single run (42) | **3 seeds** (42, 123, 456) with reported CIs | Minimum for credible ML result |\n| 8 | Best checkpoint | Save at end | **Explicit best_checkpoint/** on eval improvement | GRPOTrainer lacks load_best_model_at_end |\n\n**Prerequisites:**\n- Upload `data/pairs/train.jsonl` and `data/pairs/eval.jsonl` to `./data/pairs/`\n- Hardware: L4 (24GB), PyTorch kernel, bf16 supported\n- Estimated runtime: ~12h per seed (1,500 steps × ~30s/step)\n- Run 3 times, changing only `CURRENT_SEED` in Cell 3\n\n---\n\n*V4.2 is the last 0.5B run. Its purpose is not to find more improvement — it is to know exactly what was found and why, with enough statistical rigor to say so in writing.*"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 1: Dependencies\n\n**Gate:** No errors. Verify TRL 0.24.0 installed.\n\n**V4.2 additions:** `scipy` for Spearman ρ in reward audit."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Cell 1 — Clean install\n# Run after kernel restart\n\n!pip install \"unsloth\"\n!pip install \"trl==0.24.0\" --no-deps\n!pip install \"rich\" \"wandb\"\n!pip install \"json-repair\"  # V4.1: robust JSON parser for Portuguese LLM output\n!pip install \"scipy\"        # V4.2: Spearman ρ for reward audit"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 2: GPU + Unsloth Verification\n\n**Gate:** CUDA available, bf16=True, VRAM > 20GB, TRL 0.24.0."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import torch\n\nprint(f\"CUDA available: {torch.cuda.is_available()}\")\nprint(f\"GPU: {torch.cuda.get_device_name(0)}\")\nprint(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\nprint(f\"bf16 support: {torch.cuda.is_bf16_supported()}\")\n\nfrom unsloth import FastLanguageModel\nprint(f\"\\nβœ“ Unsloth loaded\")\n\nimport trl\nassert trl.__version__ == \"0.24.0\", f\"Expected TRL 0.24.0, got {trl.__version__}\"\nprint(f\"βœ“ TRL {trl.__version__}\")\n\nimport transformers\nprint(f\"βœ“ Transformers {transformers.__version__}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 3: Config Constants\n\n**V4.2 changes:**\n- `MAX_STEPS`: 600 → **1,500** (multi-epoch, ~2.5× full dataset)\n- `EVAL_STEPS`: 20 → **50** (interval scaled with the longer run: 1,500 / 50 = 30 evals, the same count as 600 / 20)\n- `SAVE_STEPS`: 50 → **100** (scaled for longer run)\n- `SEEDS`: Added multi-seed support. Change `CURRENT_SEED` per run.\n- `EVAL_TOTAL = 65`: Stratified eval set (20 ext + 15 sql + 15 ins + 15 push)\n\n**Everything else UNCHANGED from V4.1** (validated config)."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import os\nimport json\nimport re\nimport time\nimport random\nimport gc\nimport math\nimport warnings\nfrom pathlib import Path\n\n# ── Suppress noisy deprecation warnings from Transformers 5.5.0 ──────────────\nwarnings.filterwarnings(\"ignore\", message=\".*AttentionMaskConverter.*\")\nwarnings.filterwarnings(\"ignore\", message=\".*Passing `generation_config` together with.*\")\nwarnings.filterwarnings(\"ignore\", message=\".*max_new_tokens.*max_length.*\")\nwarnings.filterwarnings(\"ignore\", category=FutureWarning)\n\n# ── Disable Unsloth kernel recompilation ─────────────────────────────────────\nos.environ[\"UNSLOTH_COMPILE_DISABLE\"] = \"1\"\nos.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n\n# ── V4.2: Multi-seed support ────────────────────────────────────────────────\nSEEDS = [42, 123, 456]\nCURRENT_SEED = 42   # ← CHANGE THIS PER RUN (42, 123, 456)\n\n# Set all random seeds\nrandom.seed(CURRENT_SEED)\ntorch.manual_seed(CURRENT_SEED)\nif torch.cuda.is_available():\n    torch.cuda.manual_seed_all(CURRENT_SEED)\n\n# ── Model ────────────────────────────────────────────────────────────────────\nMODEL_ID       = \"Polygl0t/Tucano2-qwen-0.5B-Instruct\"\nMAX_SEQ_LENGTH = 2048   # model supports 4096, but 2048 is plenty for Instruct (no <think> overhead)\nMODELS_DIR     = Path(\"/home/jupyter/tucano2/models\")\nADAPTER_DIR    = MODELS_DIR / f\"tucano2-0.5B-instruct-grpo-v4.2-seed{CURRENT_SEED}\"\nCHECKPOINT_DIR = ADAPTER_DIR / \"checkpoints\"\n\n# ── Data ─────────────────────────────────────────────────────────────────────\nDATA_DIR       = Path(\"/home/jupyter/tucano2/data\")\nTRAIN_FILE     = DATA_DIR / \"pairs\" / \"train.jsonl\"\nEVAL_FILE      = DATA_DIR / \"pairs\" / \"eval.jsonl\"   # V4.2: separate eval source\n\n# V4.2: Stratified eval set specification (Change 1)\nEVAL_SAMPLES_PER_TASK = {\n    \"extraction\": 20,\n    \"sql_qa\":     15,\n    \"insights\":   15,\n    \"push\":       15,\n}\nEVAL_TOTAL = 
sum(EVAL_SAMPLES_PER_TASK.values())  # 65\n\n# ── GRPO Hyperparameters ─────────────────────────────────────────────────────\n# V4.2 CHANGES: MAX_STEPS 600β†’1500, EVAL_STEPS 20β†’50, SAVE_STEPS 50β†’100\n# Everything else UNCHANGED from V4.1 (validated)\nNUM_GENERATIONS        = 16     # 0.5B + short completions = VRAM allows G=16\nMAX_COMPLETION_LENGTH  = 512    # Instruct: no <think> overhead\nTEMPERATURE            = 1.0    # Skywork-OR1: Ο„=1.0 for exploration\nLEARNING_RATE          = 5e-6   # V4.1: validated at 5e-6\nBETA                   = 0.0    # Dr. GRPO Β§3.2: Ξ²=0 optimal for rule-based rewards\nSCALE_REWARDS          = False  # Dr. GRPO: remove std normalization bias\nBATCH_SIZE             = 2      # per-device batch size\nGRAD_ACCUM             = 1      # effective batch = 2 * 1 = 2 prompts * 16 gen = 32 completions\nMAX_STEPS              = 1500   # V4.2: was 600. ~2.5 full epochs with shuffling\nSAVE_STEPS             = 100    # V4.2: was 50. Scaled for longer run\nEVAL_STEPS             = 50     # V4.2: was 20. 
More frequent per-epoch boundary\nEARLY_STOPPING_PATIENCE = 15    # 15 Γ— 50 = 750 steps without improvement\nEARLY_STOPPING_DELTA   = 0.005\nLR_SCHEDULER_TYPE      = \"constant_with_warmup\"  # V4.1: validated\nWARMUP_RATIO           = 0.05   # V4.1: validated (5% of 1500 = 75 warmup steps)\n\n# ── LoRA ─────────────────────────────────────────────────────────────────────\nLORA_R     = 16\nLORA_ALPHA = 32\n\n# ── Monitoring ───────────────────────────────────────────────────────────────\nWANDB_PROJECT     = \"tucano2-commerce\"\nEVAL_MAX_TOKENS   = 512   # match training completion length\n\n# ── Task Classification (inherited from V2/V3) ──────────────────────────────\nVALID_SENTIMENTS  = {\"positive\", \"negative\", \"neutral\"}\nVALID_CATEGORIES  = {\n    \"delivery_delay\", \"product_quality\", \"product_not_received\",\n    \"wrong_product\", \"seller_communication\", \"app_issue\",\n    \"price_value\", \"other\", \"none\",\n}\nVALID_CHURN  = {\"low\", \"medium\", \"high\"}\nVALID_REPEAT = {\"yes\", \"no\", \"maybe\"}\nEXTRACTION_FIELDS = [\n    \"sentiment\", \"sentiment_score\", \"churn_risk\", \"delivery_issue\",\n    \"product_issue\", \"seller_issue\", \"main_complaint\",\n    \"complaint_category\", \"repeat_intent\", \"would_recommend\",\n]\n\n# ── Verified Special Token IDs (from tokenizer_config.json) ─────────────────\nTOKEN_ID_BOS       = 1      # <|im_start|>\nTOKEN_ID_EOS       = 2      # <|im_end|>\nTOKEN_ID_PAD       = 49109  # <|pad|>\nTOKEN_ID_THINK     = 49116  # <think>\nTOKEN_ID_THINK_END = 49117  # </think>\n\n# ══════════════════════════════════════════════════════════════════════════════\n# TASK-AWARE SYSTEM PROMPTS (inherited from V3/V4)\n# ══════════════════════════════════════════════════════════════════════════════\n\nSYSTEM_EXTRACTION = (\n    \"VocΓͺ Γ© um motor de extraΓ§Γ£o de dados de e-commerce brasileiro. \"\n    \"Retorne APENAS um objeto JSON vΓ‘lido, sem nenhum texto antes ou depois. 
\"\n    \"NΓƒO USE blocos de cΓ³digo markdown (```json). \"\n    \"O primeiro caractere da sua resposta deve ser { e o ΓΊltimo deve ser }. \"\n    \"Campos nΓ£o mencionados na avaliaΓ§Γ£o devem ser null β€” nunca invente valores. \"\n    \"Sem explicaΓ§Γ£o. Sem comentΓ‘rios.\"\n)\n\nSYSTEM_SQL = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n    \"Para consultas e anΓ‘lises de dados: apresente a resposta de forma direta \"\n    \"com nΓΊmeros e dados concretos. Seja conciso.\"\n)\n\nSYSTEM_INSIGHTS = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n    \"Para anΓ‘lises estratΓ©gicas: raciocine de forma estruturada e concisa, \"\n    \"focando nos pontos principais e recomendaΓ§Γ΅es acionΓ‘veis.\"\n)\n\nSYSTEM_PUSH = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. \"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\\n\\n\"\n    \"Para notificaΓ§Γ΅es push: seja direto e criativo. \"\n    \"A notificaΓ§Γ£o deve ter no mΓ‘ximo 120 caracteres. \"\n    \"Responda diretamente.\"\n)\n\nSYSTEM_PT = (\n    \"VocΓͺ Γ© um assistente de IA especializado em anΓ‘lise de e-commerce brasileiro. 
\"\n    \"VocΓͺ compreende avaliaΓ§Γ΅es de clientes em portuguΓͺs e padrΓ΅es de comΓ©rcio brasileiro.\"\n)\n\ndef get_system_prompt(task_type: str) -> str:\n    return {\n        \"extraction\": SYSTEM_EXTRACTION,\n        \"sql_qa\": SYSTEM_SQL,\n        \"insights\": SYSTEM_INSIGHTS,\n        \"push\": SYSTEM_PUSH,\n    }.get(task_type, SYSTEM_PT)\n\ndef inject_task_system_prompt(msgs, task_type):\n    new_msgs = []\n    system_prompt = get_system_prompt(task_type)\n    has_system = False\n    for m in msgs:\n        if m[\"role\"] == \"system\":\n            new_msgs.append({\"role\": \"system\", \"content\": system_prompt})\n            has_system = True\n        else:\n            new_msgs.append(m)\n    if not has_system:\n        new_msgs.insert(0, {\"role\": \"system\", \"content\": system_prompt})\n    return new_msgs\n\nprint(\"βœ“ Task-aware system prompts defined\")\n\nprint(\"βœ“ Config loaded\")\nprint(f\"  Model: {MODEL_ID}\")\nprint(f\"  Seed: {CURRENT_SEED} (run {SEEDS.index(CURRENT_SEED)+1}/{len(SEEDS)})\")\nprint(f\"  G={NUM_GENERATIONS}, max_comp={MAX_COMPLETION_LENGTH}, temp={TEMPERATURE}\")\nprint(f\"  LR={LEARNING_RATE}, Ξ²={BETA}, scale_rewards={SCALE_REWARDS}\")\nprint(f\"  LR schedule: {LR_SCHEDULER_TYPE}, warmup={WARMUP_RATIO}\")\nprint(f\"  LoRA r={LORA_R}, Ξ±={LORA_ALPHA}\")\nprint(f\"  Max steps: {MAX_STEPS} (~{MAX_STEPS * BATCH_SIZE / 1480:.1f} epochs)\")\nprint(f\"  Eval: {EVAL_TOTAL} stratified samples, every {EVAL_STEPS} steps\")\nprint(f\"  Save every {SAVE_STEPS} steps\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 4: Load Model + Apply Critical Overrides\n\n**Gate:** Model loaded, `use_cache=True`, `repetition_penalty=1.0`, `temperature=1.0`.\n\nUnchanged from V4.1."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from unsloth import FastLanguageModel\n\nprint(\"Loading model...\")\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n    model_name=MODEL_ID,\n    max_seq_length=MAX_SEQ_LENGTH,\n    load_in_4bit=True,\n    dtype=None,\n)\n\nfrom peft import LoraConfig, get_peft_model\n\nlora_config = LoraConfig(\n    r=LORA_R,\n    lora_alpha=LORA_ALPHA,\n    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n                    \"gate_proj\", \"up_proj\", \"down_proj\"],\n    lora_dropout=0,\n    bias=\"none\",\n    task_type=\"CAUSAL_LM\",\n)\nmodel = get_peft_model(model, lora_config)\nmodel.print_trainable_parameters()\n\n# ═══════════════════════════════════════════════════════════════════════════════\n# CRITICAL OVERRIDES\n# ═══════════════════════════════════════════════════════════════════════════════\n\nmodel.config.use_cache = True\nmodel.generation_config.use_cache = True\nmodel.generation_config.temperature = TEMPERATURE\nmodel.generation_config.repetition_penalty = 1.0\nmodel.generation_config.do_sample = True\nmodel.generation_config.top_k = 0\nmodel.generation_config.top_p = 1.0\nmodel.generation_config.max_length = None\n\nif tokenizer.pad_token is None:\n    tokenizer.pad_token = tokenizer.eos_token\n\nprint(f\"βœ“ Model loaded on {model.device}\")\nprint(f\"  use_cache: {model.config.use_cache}\")\nprint(f\"  temperature: {model.generation_config.temperature}\")\nprint(f\"  repetition_penalty: {model.generation_config.repetition_penalty}\")\nprint(f\"  top_k: {model.generation_config.top_k}\")\nprint(f\"  Params: {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M\")\n\ntry:\n    lm_ptr = model.base_model.model.lm_head.weight.data_ptr()\n    embed_ptr = model.base_model.model.model.embed_tokens.weight.data_ptr()\n    tied = lm_ptr == embed_ptr\n    print(f\"  Tied embeddings intact: {tied}\")\n    if not tied:\n        print(\"  ⚠️ WARNING: Tied embeddings broken after LoRA patching.\")\nexcept AttributeError as e:\n    
print(f\"  ⚠️ Could not check tied embeddings: {e}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 5: Token ID Verification\n\n**Gate:** All token IDs match. Unchanged from V4.1."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "tok_tests = {\n    \"<|im_start|>\": TOKEN_ID_BOS,\n    \"<|im_end|>\":   TOKEN_ID_EOS,\n    \"<|pad|>\":      TOKEN_ID_PAD,\n    \"<think>\":      TOKEN_ID_THINK,\n    \"</think>\":     TOKEN_ID_THINK_END,\n}\n\nall_pass = True\nfor text, expected_id in tok_tests.items():\n    ids = tokenizer.encode(text, add_special_tokens=False)\n    actual_id = ids[0] if len(ids) == 1 else ids\n    match = (len(ids) == 1 and ids[0] == expected_id)\n    status = \"βœ“\" if match else \"βœ—\"\n    print(f\"  {status} '{text}' β†’ expected {expected_id}, got {actual_id}\")\n    if not match:\n        all_pass = False\n\nassert all_pass, \"Token ID mismatch detected. Update constants in Cell 3 before proceeding.\"\nprint(\"\\nβœ“ All token IDs verified\")\n\nassert tokenizer.eos_token_id == TOKEN_ID_EOS, f\"eos_token_id mismatch: {tokenizer.eos_token_id}\"\nprint(f\"βœ“ eos_token_id = {tokenizer.eos_token_id}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 6: KV Cache Diagnostic\n\n**Gate:** Ratio < 3× → KV cache OK (the cell's assert only hard-fails at a ratio ≥ 5×). Unchanged from V4.1."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "FastLanguageModel.for_inference(model)\n\n_kv_msgs = [{\"role\": \"user\", \"content\": \"Qual a categoria de reclamaΓ§Γ£o mais frequente?\"}]\n_kv_text = tokenizer.apply_chat_template(_kv_msgs, tokenize=False, add_generation_prompt=True)\n_kv_inputs = tokenizer(_kv_text, return_tensors=\"pt\").to(model.device)\n\n_token_times, _past, _generated = [], None, _kv_inputs[\"input_ids\"]\nwith torch.no_grad():\n    for _step in range(50):\n        _t0 = time.time()\n        seq_len = _generated.shape[1]\n        if _past is None:\n            _position_ids = torch.arange(seq_len, dtype=torch.long, device=model.device).unsqueeze(0)\n        else:\n            _position_ids = torch.tensor([[seq_len - 1]], dtype=torch.long, device=model.device)\n        _out = model(\n            input_ids=_generated[:, -1:] if _past else _generated,\n            position_ids=_position_ids,\n            attention_mask=torch.ones(1, seq_len, device=model.device),\n            past_key_values=_past,\n            use_cache=True,\n            return_dict=True,\n        )\n        _past = _out.past_key_values\n        _next = _out.logits[:, -1, :].argmax(dim=-1, keepdim=True)\n        _generated = torch.cat([_generated, _next], dim=1)\n        _token_times.append(time.time() - _t0)\n\n_ratio = sum(_token_times[45:]) / max(sum(_token_times[:5]), 1e-9)\nprint(f\"First 5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[:5]]}\")\nprint(f\"Last  5 tok: {[f'{t*1000:.0f}ms' for t in _token_times[45:]]}\")\nprint(f\"Ratio last/first: {_ratio:.1f}x\")\nassert _ratio < 5, f\"KV cache BROKEN (ratio {_ratio:.1f}Γ—). Check model.config.use_cache.\"\nprint(\"βœ“ KV cache working correctly\")\n\ndel _past, _generated, _kv_inputs, _token_times, _out\ngc.collect()\ntorch.cuda.empty_cache()"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n\n## Cell 7: Reward Functions V2\n\n**V4.2 changes (Change 3 + Change 5 + Change 6):**\n\n### SQL Reward Overhaul (Change 3)\n- **Tier 1 (0.30):** SQL structure detected — requires ≥3 SQL keywords (SELECT, FROM, WHERE, etc.)\n- **Tier 2 (0.25):** Answer has BOTH query AND explanation (not just domain vocabulary)\n- **Tier 3 (0.25):** Numerical specificity (concrete data in answer)\n- **Tier 4 (0.20):** Portuguese business domain coherence\n\n### GDPO Per-Component Normalization (Change 5) — ACTIVE IN TRAINING\n- `commerce_reward_fn` applies per-task z-score normalization INSIDE the reward call\n- TRL 0.24.0 calls reward_fn with the full batch → we normalize per-component before returning\n- No trainer modification needed — normalized rewards flow through standard GRPO advantage computation\n- Preserves ~4× more distinct advantage groups (GDPO §3.1)\n\n### Dynamic Task Weights (Change 6) — ACTIVE IN TRAINING\n- `_task_weights` dict tracks per-task weights, updated by `update_task_weights()` in eval callback\n- Weights are applied as multiplicative scaling INSIDE `commerce_reward_fn` after GDPO normalization\n- Effect: stagnating tasks (e.g. SQL) get amplified reward signal → larger GRPO advantages → more gradient\n- MT-GRPO IWU §3.2: prevents easy-task collapse without requiring custom sampling\n\n### V4.2.1 Fixes (Cell 8 Audit)\n- **Push reward:** Steep length penalty (hard 0 above 200 chars) + formal email penalty (-0.20 for \"Prezado\"/\"Atenciosamente\")\n- **SQL reward Tier 4:** Expanded domain word list (+20 words: compradores, sentimentos, reclamações, taxa, distribuição, etc.)\n- **Extraction reward:** `sentiment_score` validator requires `isinstance(v, int) and not isinstance(v, bool)` — rejects floats from PT decimal normalization\n- **Task classifier:** Reordered `_classify_task_type` — insights checked before push to prevent 'reengajamento' misclassification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json_repair  # V4.1: robust JSON parser for LLM output\n",
    "\n",
    "\n",
    "def strip_think(text: str) -> str:\n",
    "    \"\"\"Remove <think>...</think> block, return the answer portion.\"\"\"\n",
    "    return re.sub(r\"<think>.*?</think>\", \"\", text, flags=re.DOTALL).strip()\n",
    "\n",
    "\n",
    "def has_think_block(text: str) -> bool:\n",
    "    return bool(re.search(r\"<think>.+</think>\", text, flags=re.DOTALL))\n",
    "\n",
    "\n",
    "def _classify_task_type(prompt_text: str) -> str:\n",
    "    \"\"\"V4.2.1: reordered — insights before push to prevent misclassification.\n",
    "    \n",
    "    \"notificação de reengajamento\" in a customer profile context is insights,\n",
    "    not push. Check insights keywords first.\n",
    "    \"\"\"\n",
    "    p = prompt_text.lower()\n",
    "    # 1. Insights FIRST — customer profile questions mentioning reengagement are insights\n",
    "    if \"perfil do cliente\" in p or \"retenção\" in p or \"análise\" in p or \"insight\" in p:\n",
    "        return \"insights\"\n",
    "    # 2. Extraction\n",
    "    elif \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n",
    "        return \"extraction\"\n",
    "    # 3. Push — only after insights is ruled out\n",
    "    elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
    "        return \"push\"\n",
    "    else:\n",
    "        return \"sql_qa\"\n",
    "\n",
    "\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "# V4.1: ROBUST JSON PARSER (unchanged)\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "\n",
    "def _normalize_pt_decimals(s: str) -> str:\n",
    "    \"\"\"Convert PT-BR decimals (4,5) to JSON-valid (4.5), only outside quoted strings.\"\"\"\n",
    "    result, in_string, escape_next = [], False, False\n",
    "    i = 0\n",
    "    while i < len(s):\n",
    "        c = s[i]\n",
    "        if escape_next:\n",
    "            result.append(c); escape_next = False; i += 1; continue\n",
    "        if c == '\\\\' and in_string:\n",
    "            result.append(c); escape_next = True; i += 1; continue\n",
    "        if c == '\"':\n",
    "            in_string = not in_string; result.append(c); i += 1; continue\n",
    "        if not in_string:\n",
    "            m = re.match(r'(\\d+),(\\d+)', s[i:])\n",
    "            if m:\n",
    "                result.append(m.group(1) + '.' + m.group(2))\n",
    "                i += len(m.group(0)); continue\n",
    "        result.append(c); i += 1\n",
    "    return ''.join(result)\n",
    "\n",
    "\n",
    "def _extract_json(text: str) -> dict | None:\n",
    "    \"\"\"Robust JSON extraction for Portuguese LLM output.\"\"\"\n",
    "    stripped = re.sub(r'^```(?:json)?\\s*|\\s*```$', '', text.strip(), flags=re.MULTILINE).strip()\n",
    "    for attempt in [stripped, _normalize_pt_decimals(stripped)]:\n",
    "        try:\n",
    "            result = json.loads(attempt)\n",
    "            if isinstance(result, dict):\n",
    "                return result\n",
    "        except (json.JSONDecodeError, TypeError):\n",
    "            pass\n",
    "    normalized = _normalize_pt_decimals(stripped)\n",
    "    try:\n",
    "        result = json_repair.repair_json(normalized, return_objects=True)\n",
    "        if isinstance(result, dict):\n",
    "            return result\n",
    "    except Exception:\n",
    "        pass\n",
    "    try:\n",
    "        result = json_repair.repair_json(stripped, return_objects=True)\n",
    "        if isinstance(result, dict):\n",
    "            return result\n",
    "    except Exception:\n",
    "        pass\n",
    "    return None\n",
    "\n",
    "\n",
    "def reward_extraction(completion: str, prompt_text: str = \"\") -> float:\n",
    "    \"\"\"Continuous reward for extraction tasks (max 1.0).\"\"\"\n",
    "    answer = strip_think(completion)\n",
    "    data = _extract_json(answer)\n",
    "\n",
    "    if data is None:\n",
    "        if \"{\" in answer and \"}\" in answer:\n",
    "            return 0.05\n",
    "        return 0.0\n",
    "\n",
    "    if not isinstance(data, dict):\n",
    "        return 0.1\n",
    "\n",
    "    score = 0.3  # valid JSON object\n",
    "\n",
    "    # Schema completeness (0.3 total)\n",
    "    present = sum(1 for f in EXTRACTION_FIELDS if f in data)\n",
    "    score += 0.3 * (present / len(EXTRACTION_FIELDS))\n",
    "\n",
    "    # Value validity (0.4 total)\n",
    "    checks_passed = 0\n",
    "    checks_total = 0\n",
    "\n",
    "    for field, validator in [\n",
    "        (\"sentiment\", lambda v: isinstance(v, str) and v in VALID_SENTIMENTS),\n",
    "        (\"complaint_category\", lambda v: isinstance(v, str) and v in VALID_CATEGORIES),\n",
    "        (\"churn_risk\", lambda v: isinstance(v, str) and v in VALID_CHURN),\n",
    "        (\"repeat_intent\", lambda v: isinstance(v, str) and v in VALID_REPEAT),\n",
    "        # V4.2.1: must be int, not float/bool. PT normalizer turns \"0,5\"→0.5 (float)\n",
    "        (\"sentiment_score\", lambda v: isinstance(v, int) and not isinstance(v, bool) and 1 <= v <= 5),\n",
    "    ]:\n",
    "        checks_total += 1\n",
    "        if field in data and validator(data[field]):\n",
    "            checks_passed += 1\n",
    "\n",
    "    for bool_field in (\"delivery_issue\", \"product_issue\", \"seller_issue\", \"would_recommend\"):\n",
    "        checks_total += 1\n",
    "        if bool_field in data and isinstance(data[bool_field], bool):\n",
    "            checks_passed += 1\n",
    "\n",
    "    if checks_total > 0:\n",
    "        score += 0.4 * (checks_passed / checks_total)\n",
    "\n",
    "    # nota=1-2 on a 5-star scale → negative review; nota=4-5 → positive.\n",
    "    # Penalize clear sentiment mismatches to break reward hacking.\n",
    "    import re as _re\n",
    "    nota_match = _re.search(r\"nota=(\\d)/5\", prompt_text)\n",
    "    if nota_match and \"sentiment\" in data:\n",
    "        nota = int(nota_match.group(1))\n",
    "        sentiment = data.get(\"sentiment\", \"\")\n",
    "        if nota <= 2 and sentiment == \"positive\":\n",
    "            score -= 0.20\n",
    "        elif nota >= 4 and sentiment == \"negative\":\n",
    "            score -= 0.20\n",
    "\n",
    "    return max(0.0, min(score, 1.0))\n",
    "\n",
    "\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "# V4.2: SQL REWARD V2 — Validation-aware (Change 3)\n",
    "# Replaces heuristic vocabulary matching with structural analysis.\n",
    "# Expected: distinguishes \"mentions SQL keywords\" from \"produces correct answer\"\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "\n",
    "def reward_sql_qa(completion: str) -> float:\n",
    "    \"\"\"V4.2: Validation-aware SQL Q&A reward (max 1.0).\n",
    "    \n",
    "    Tier 1 (0.30): SQL structure detected (β‰₯3 keywords or code block)\n",
    "    Tier 2 (0.25): Answer has both query and explanation\n",
    "    Tier 3 (0.25): Numerical specificity (concrete data)\n",
    "    Tier 4 (0.20): Portuguese business domain coherence\n",
    "    \"\"\"\n",
    "    answer = strip_think(completion)\n",
    "    if not answer.strip():\n",
    "        return 0.0\n",
    "\n",
    "    score = 0.0\n",
    "\n",
    "    # Tier 1 (0.30): SQL structure detected\n",
    "    sql_keywords = [\"SELECT\", \"FROM\", \"WHERE\", \"GROUP BY\", \"ORDER BY\",\n",
    "                    \"JOIN\", \"HAVING\", \"COUNT\", \"AVG\", \"SUM\"]\n",
    "    sql_found = sum(1 for kw in sql_keywords if kw in answer.upper())\n",
    "    if sql_found >= 3:\n",
    "        score += 0.30\n",
    "    elif sql_found >= 1:\n",
    "        score += 0.15\n",
    "\n",
    "    # Tier 2 (0.25): Answer has both query AND explanation\n",
    "    has_query = bool(re.search(r\"```sql|SELECT.{5,}FROM\", answer, re.IGNORECASE | re.DOTALL))\n",
    "    has_answer = any(w in answer.lower() for w in [\"resultado\", \"total\", \"mΓ©dia\", \"mostra\", \"portanto\"])\n",
    "    if has_query and has_answer:\n",
    "        score += 0.25\n",
    "    elif has_query or has_answer:\n",
    "        score += 0.12\n",
    "\n",
    "    # Tier 3 (0.25): Numerical specificity\n",
    "    numbers = re.findall(r\"\\d+(?:[.,]\\d+)?(?:\\s*%)?\", answer)\n",
    "    score += min(0.25, 0.05 * len(numbers))\n",
    "\n",
    "    # Tier 4 (0.20): Portuguese business domain coherence β€” EXPANDED (V4.2.1)\n",
    "    pt_domain = [\n",
    "        # Original 10\n",
    "        \"pedidos\", \"clientes\", \"vendedores\", \"produtos\", \"avaliaΓ§Γ£o\",\n",
    "        \"entrega\", \"reclamaΓ§Γ£o\", \"satisfaΓ§Γ£o\", \"categoria\", \"perΓ­odo\",\n",
    "        # V4.2.1: broader e-commerce vocabulary (Cell 8 audit: samples 6, 10)\n",
    "        \"compradores\", \"sentimentos\", \"reclamaΓ§Γ΅es\", \"taxa\", \"distribuiΓ§Γ£o\",\n",
    "        \"vendas\", \"faturamento\", \"estoque\", \"logΓ­stica\", \"marketplace\",\n",
    "        \"consumidores\", \"fornecedores\", \"devoluΓ§Γ΅es\", \"reembolso\", \"frete\",\n",
    "        \"pagamento\", \"cancelamento\", \"atraso\", \"qualidade\", \"nota\",\n",
    "        \"positivos\", \"negativos\", \"neutros\", \"tendΓͺncia\", \"desempenho\",\n",
    "    ]\n",
    "    score += min(0.20, 0.04 * sum(1 for w in pt_domain if w in answer.lower()))\n",
    "\n",
    "    return min(score, 1.0)\n",
    "\n",
    "\n",
    "def reward_insights(completion: str) -> float:\n",
    "    \"\"\"Continuous reward for insights (max 1.0). Unchanged from V4.1.\"\"\"\n",
    "    answer = strip_think(completion)\n",
    "    if not answer.strip():\n",
    "        return 0.0\n",
    "\n",
    "    score = 0.0\n",
    "\n",
    "    action_words = [\"recomend\", \"implement\", \"melhor\", \"reduzir\", \"aumentar\",\n",
    "                    \"priorizar\", \"investir\", \"otimizar\", \"estratΓ©gi\", \"aΓ§Γ£o\"]\n",
    "    matches = sum(1 for w in action_words if w in answer.lower())\n",
    "    score += min(0.4, 0.08 * matches)\n",
    "\n",
    "    length = len(answer)\n",
    "    if 100 <= length <= 800:\n",
    "        score += 0.3\n",
    "    elif length > 0:\n",
    "        score += 0.3 * max(0, 1 - abs(length - 450) / 450)\n",
    "\n",
    "    structure_marks = len(re.findall(r\"^[-β€’*]\\s|^\\d+[.)]\\s|^#{1,3}\\s\", answer, re.MULTILINE))\n",
    "    score += min(0.2, 0.04 * structure_marks)\n",
    "\n",
    "    if any(w in answer.lower() for w in [\"cliente\", \"produto\", \"serviΓ§o\", \"empresa\"]):\n",
    "        score += 0.1\n",
    "\n",
    "    return min(score, 1.0)\n",
    "\n",
    "\n",
    "def reward_push(completion: str) -> float:\n",
    "    \"\"\"Continuous reward for push notifications (max 1.0).\n",
    "    \n",
    "    V4.2.1 fixes (Cell 8 audit):\n",
    "    - Steep length penalty: hard zero above 200 chars (was linear decay to 240)\n",
    "    - Formal email penalty: -0.20 for \"Prezado\"/\"Atenciosamente\"/etc.\n",
    "    \"\"\"\n",
    "    answer = strip_think(completion).strip()\n",
    "    if not answer:\n",
    "        return 0.0\n",
    "\n",
    "    length = len(answer)\n",
    "\n",
    "    # Length score (0.50 max) β€” steep decay above 120 chars\n",
    "    if length <= 120:\n",
    "        length_score = 0.50\n",
    "    elif length <= 160:\n",
    "        length_score = 0.50 - 0.40 * ((length - 120) / 40)  # 0.50 β†’ 0.10\n",
    "    elif length <= 200:\n",
    "        length_score = 0.10 - 0.10 * ((length - 160) / 40)  # 0.10 β†’ 0.00\n",
    "    else:\n",
    "        length_score = 0.0\n",
    "\n",
    "    pt_markers = re.findall(r\"[ãçéΓͺΓ³ΓΊΓ’Γ΅]|vocΓͺ|para|como|seu|sua|oferta|desconto|produto\",\n",
    "                            answer, re.IGNORECASE)\n",
    "    lang_score = min(0.3, 0.03 * len(pt_markers))\n",
    "\n",
    "    generic = [\"olΓ‘\", \"obrigado pela compra\", \"agradecemos\"]\n",
    "    is_generic = any(g in answer.lower() for g in generic)\n",
    "    creativity_score = 0.0 if is_generic else 0.2\n",
    "\n",
    "    # Formal email penalty β€” push notifications should NOT be formal emails\n",
    "    formal_markers = [\n",
    "        \"prezado\", \"prezada\", \"atenciosamente\", \"cordialmente\",\n",
    "        \"att,\", \"att.\", \"respeitosamente\", \"caro cliente\", \"cara cliente\",\n",
    "    ]\n",
    "    has_formal = any(fm in answer.lower() for fm in formal_markers)\n",
    "    formal_penalty = -0.20 if has_formal else 0.0\n",
    "\n",
    "    return max(0.0, min(length_score + lang_score + creativity_score + formal_penalty, 1.0))\n",
    "\n",
    "\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "# V4.2: GDPO PER-COMPONENT NORMALIZATION (Change 5)\n",
    "# Normalize each reward component independently before aggregation.\n",
    "# GDPO (2601.05242) Β§3.1: preserves ~4Γ— more distinct advantage groups.\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "\n",
    "def gdpo_normalize(component_rewards: dict) -> list:\n",
    "    \"\"\"Per-component normalization before aggregation (GDPO 2601.05242 Β§3.1).\n",
    "    \n",
    "    Args:\n",
    "        component_rewards: {task_name: [reward_per_sample, ...]} for each component\n",
    "    \n",
    "    Returns:\n",
    "        List of normalized summed rewards, one per sample.\n",
    "    \"\"\"\n",
    "    normalized = {}\n",
    "    for task, rewards in component_rewards.items():\n",
    "        rewards_t = torch.tensor(rewards, dtype=torch.float32)\n",
    "        std = rewards_t.std()\n",
    "        if std > 1e-8:\n",
    "            normalized[task] = ((rewards_t - rewards_t.mean()) / std).tolist()\n",
    "        else:\n",
    "            normalized[task] = [0.0] * len(rewards)  # zero-variance group\n",
    "    # Sum normalized components per sample\n",
    "    n = len(next(iter(normalized.values())))\n",
    "    return [sum(normalized[t][i] for t in normalized) for i in range(n)]\n",
    "\n",
    "\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "# V4.2: DYNAMIC TASK WEIGHTING β€” MT-GRPO IWU (Change 6)\n",
    "# Track per-task reward improvement rates, upweight stagnating tasks.\n",
    "# MT-GRPO (2602.05547) Β§3.2: prevents easy-task collapse.\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "\n",
    "_task_weights = {\n",
    "    \"extraction\": 0.40,   # matches training data distribution (40%)\n",
    "    \"sql_qa\":     0.40,   # matches training data distribution (40%)\n",
    "    \"insights\":   0.10,   # matches training data distribution (10%)\n",
    "    \"push\":       0.10,   # matches training data distribution (10%)\n",
    "}\n",
    "_task_reward_history = {t: [] for t in _task_weights}\n",
    "\n",
    "def update_task_weights(step: int, per_task_rewards: dict, update_interval: int = 50):\n",
    "    \"\"\"MT-GRPO IWU: update task sampling weights based on improvement rate.\n",
    "    \n",
    "    Args:\n",
    "        step: Current training step\n",
    "        per_task_rewards: {task: mean_reward} from latest eval\n",
    "        update_interval: Only update every N steps\n",
    "    \"\"\"\n",
    "    global _task_weights\n",
    "    if step % update_interval != 0 or step == 0:\n",
    "        return\n",
    "    \n",
    "    for task, reward in per_task_rewards.items():\n",
    "        if task not in _task_reward_history:\n",
    "            continue\n",
    "        _task_reward_history[task].append(reward)\n",
    "        if len(_task_reward_history[task]) >= 2:\n",
    "            improvement = _task_reward_history[task][-1] - _task_reward_history[task][-2]\n",
    "            if improvement < 0.01:         # stagnating\n",
    "                _task_weights[task] = min(0.50, _task_weights[task] * 1.3)\n",
    "            elif improvement > 0.05:       # improving fast\n",
    "                _task_weights[task] = max(0.10, _task_weights[task] * 0.85)\n",
    "    \n",
    "    # Normalize to sum to 1\n",
    "    total = sum(_task_weights.values())\n",
    "    _task_weights = {t: w / total for t, w in _task_weights.items()}\n",
    "\n",
    "\n",
    "def get_task_weighted_indices(dataset, n_samples: int) -> list:\n",
    "    \"\"\"Sample indices with probability proportional to task weights.\"\"\"\n",
    "    task_indices = {t: [] for t in _task_weights}\n",
    "    for i, record in enumerate(dataset):\n",
    "        user_txt = \" \".join(m[\"content\"] for m in record[\"prompt\"] if m[\"role\"] == \"user\")\n",
    "        task = _classify_task_type(user_txt)\n",
    "        if task in task_indices:\n",
    "            task_indices[task].append(i)\n",
    "    \n",
    "    sampled = []\n",
    "    for task, weight in _task_weights.items():\n",
    "        n = max(1, int(n_samples * weight))\n",
    "        pool = task_indices.get(task, [])\n",
    "        if pool:\n",
    "            sampled.extend(random.sample(pool, min(n, len(pool))))\n",
    "    random.shuffle(sampled)\n",
    "    return sampled[:n_samples]\n",
    "\n",
    "\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "# MASTER REWARD FUNCTION β€” V4.2: returns per-component rewards for GDPO\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "\n",
    "def commerce_reward_fn(completions, prompts, **kwargs) -> list:\n",
    "    \"\"\"Master reward function with GDPO normalization + dynamic task weighting.\n",
    "    \n",
    "    V4.2 integration with TRL 0.24.0:\n",
    "    TRL calls this once per step with the full batch (batch_size Γ— G completions).\n",
    "    We exploit this to apply batch-level per-component normalization (GDPO Β§3.1)\n",
    "    and dynamic task weighting (MT-GRPO IWU Β§3.2) INSIDE the reward function,\n",
    "    so the trainer receives pre-normalized, weighted rewards without modification.\n",
    "    \n",
    "    Pipeline:\n",
    "      1. Score each completion with its task-specific reward function (raw)\n",
    "      2. Group raw rewards by task type\n",
    "      3. GDPO: z-score normalize each task group independently\n",
    "      4. IWU: multiply normalized rewards by current _task_weights\n",
    "      5. Shift back to [0, 1] range (GRPO with scale_rewards=False expects non-negative)\n",
    "      6. Return flat list in original sample order\n",
    "    \"\"\"\n",
    "    n = len(completions)\n",
    "    raw_rewards = [0.0] * n\n",
    "    task_labels = [\"\"] * n\n",
    "    \n",
    "    # ── Step 1: Compute raw per-sample rewards ──────────────────────────────\n",
    "    for i, (completion, prompt) in enumerate(zip(completions, prompts)):\n",
    "        if isinstance(completion, list):\n",
    "            comp_text = completion[-1][\"content\"] if completion else \"\"\n",
    "        else:\n",
    "            comp_text = str(completion)\n",
    "\n",
    "        if isinstance(prompt, list):\n",
    "            prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
    "        else:\n",
    "            prompt_text = str(prompt)\n",
    "\n",
    "        task = _classify_task_type(prompt_text)\n",
    "        task_labels[i] = task\n",
    "\n",
    "        if task == \"extraction\":\n",
    "            raw_rewards[i] = reward_extraction(comp_text, prompt_text)\n",
    "        elif task == \"sql_qa\":\n",
    "            raw_rewards[i] = reward_sql_qa(comp_text)\n",
    "        elif task == \"insights\":\n",
    "            raw_rewards[i] = reward_insights(comp_text)\n",
    "        elif task == \"push\":\n",
    "            raw_rewards[i] = reward_push(comp_text)\n",
    "        else:\n",
    "            raw_rewards[i] = 0.2 if comp_text.strip() else 0.0\n",
    "\n",
    "    # ── Step 2-4: GDPO per-component normalization + IWU weighting ──────────\n",
    "    # Group indices by task\n",
    "    task_indices = {}\n",
    "    for i, task in enumerate(task_labels):\n",
    "        if task not in task_indices:\n",
    "            task_indices[task] = []\n",
    "        task_indices[task].append(i)\n",
    "    \n",
    "    final_rewards = [0.0] * n\n",
    "    \n",
    "    for task, indices in task_indices.items():\n",
    "        task_raw = [raw_rewards[i] for i in indices]\n",
    "        \n",
    "        # GDPO: z-score normalize within this task group\n",
    "        if len(task_raw) > 1:\n",
    "            t_mean = sum(task_raw) / len(task_raw)\n",
    "            t_var = sum((r - t_mean) ** 2 for r in task_raw) / (len(task_raw) - 1)\n",
    "            t_std = t_var ** 0.5\n",
    "            if t_std > 1e-8:\n",
    "                normed = [(r - t_mean) / t_std for r in task_raw]\n",
    "            else:\n",
    "                normed = [0.0] * len(task_raw)\n",
    "        else:\n",
    "            # Single sample in this task group β€” can't normalize, use raw\n",
    "            normed = [0.0]\n",
    "        \n",
    "        # IWU: scale by dynamic task weight\n",
    "        weight = _task_weights.get(task, 0.25)\n",
    "        weighted = [v * weight for v in normed]\n",
    "        \n",
    "        for idx_in_group, global_idx in enumerate(indices):\n",
    "            final_rewards[global_idx] = weighted[idx_in_group]\n",
    "    \n",
    "    # ── Step 5: Shift to non-negative range ─────────────────────────────────\n",
    "    # GRPO with scale_rewards=False computes advantages as reward - mean(rewards).\n",
    "    # Normalized rewards are already zero-centered per-task, so the advantage\n",
    "    # computation will work correctly. But TRL may log negative rewards as warnings.\n",
    "    # Shift so minimum is 0 to keep logging clean, without changing advantage ordering.\n",
    "    min_r = min(final_rewards) if final_rewards else 0.0\n",
    "    if min_r < 0:\n",
    "        final_rewards = [r - min_r for r in final_rewards]\n",
    "    \n",
    "    return final_rewards\n",
    "\n",
    "\n",
    "def commerce_reward_fn_raw(completions, prompts, **kwargs) -> list:\n",
    "    \"\"\"Raw reward function WITHOUT GDPO/IWU β€” used for eval metrics.\n",
    "    \n",
    "    Eval should report raw task-specific rewards for interpretability.\n",
    "    The GDPO+IWU normalization is only for shaping the training gradient signal.\n",
    "    \"\"\"\n",
    "    rewards = []\n",
    "    for completion, prompt in zip(completions, prompts):\n",
    "        if isinstance(completion, list):\n",
    "            comp_text = completion[-1][\"content\"] if completion else \"\"\n",
    "        else:\n",
    "            comp_text = str(completion)\n",
    "\n",
    "        if isinstance(prompt, list):\n",
    "            prompt_text = \" \".join(m.get(\"content\", \"\") for m in prompt)\n",
    "        else:\n",
    "            prompt_text = str(prompt)\n",
    "\n",
    "        task = _classify_task_type(prompt_text)\n",
    "\n",
    "        if task == \"extraction\":\n",
    "            rewards.append(reward_extraction(comp_text, prompt_text))\n",
    "        elif task == \"sql_qa\":\n",
    "            rewards.append(reward_sql_qa(comp_text))\n",
    "        elif task == \"insights\":\n",
    "            rewards.append(reward_insights(comp_text))\n",
    "        elif task == \"push\":\n",
    "            rewards.append(reward_push(comp_text))\n",
    "        else:\n",
    "            r = 0.2 if comp_text.strip() else 0.0\n",
    "            rewards.append(r)\n",
    "    return rewards\n",
    "\n",
    "\n",
    "print(\"βœ“ Reward functions defined (V4.2: SQL v2 + GDPO active + IWU active)\")\n",
    "print(f\"  Task weights: {_task_weights}\")\n",
    "print(f\"  commerce_reward_fn: GDPO+IWU normalized (for training)\")\n",
    "print(f\"  commerce_reward_fn_raw: raw scores (for eval/audit)\")\n",
    "print(f\"  Task weights: {_task_weights}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 8: Reward Function Audit (Change 2)\n\n**V4.2 addition: 30-minute audit protocol.**\n\nGenerate 20 completions (5 per task) at temp=0.1, score them with the reward function,\nthen have the human score them 0-10. Compute Spearman ρ.\n\n**Gate:** ρ > 0.70. If below, reward function is miscalibrated β€” fix before training.\n\n**Why:** The V1-V4 parser bug would have been caught in 30 minutes with this protocol.\n\n### Instructions\n1. Run this cell β€” it generates 20 completions and scores them automatically\n2. For each completion: read the FULL output (no truncation), then enter your 0-10 score at the prompt\n3. After all 20 scores, the cell computes Spearman ρ automatically\n4. If ρ < 0.70, investigate discrepancies (marked ⚠️) before proceeding to training"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import spearmanr\n",
    "\n",
    "AUDIT_PROMPTS_PER_TASK = 5\n",
    "\n",
    "# ── Collect audit prompts (5 per task) ───────────────────────────────────────\n",
    "audit_by_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n",
    "with open(TRAIN_FILE) as f:\n",
    "    for line in f:\n",
    "        row = json.loads(line)\n",
    "        convs = row[\"conversations\"]\n",
    "        prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
    "        if not prompt_msgs:\n",
    "            continue\n",
    "        user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n",
    "        task = _classify_task_type(user_text)\n",
    "        if len(audit_by_type[task]) < AUDIT_PROMPTS_PER_TASK:\n",
    "            audit_by_type[task].append(prompt_msgs)\n",
    "\n",
    "print(f\"Audit prompts collected: {', '.join(f'{k}={len(v)}' for k, v in audit_by_type.items())}\")\n",
    "\n",
    "# ── Generate completions and score automatically ─────────────────────────────\n",
    "FastLanguageModel.for_inference(model)\n",
    "\n",
    "audit_auto_scores = []\n",
    "audit_tasks = []\n",
    "audit_completions = []\n",
    "\n",
    "audit_prompts_text = []  # store original user message for display\n",
    "\n",
    "for task_type in [\"extraction\", \"sql_qa\", \"insights\", \"push\"]:\n",
    "    for msgs in audit_by_type[task_type]:\n",
    "        # Extract original user message BEFORE injecting system prompt\n",
    "        user_content = \"\\n\".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n",
    "        audit_prompts_text.append(user_content)\n",
    "        \n",
    "        msgs = inject_task_system_prompt(msgs, task_type)\n",
    "        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
    "        inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
    "        with torch.no_grad():\n",
    "            out = model.generate(\n",
    "                **inputs,\n",
    "                max_new_tokens=MAX_COMPLETION_LENGTH,\n",
    "                temperature=0.1,  # near-deterministic for audit\n",
    "                do_sample=True,\n",
    "                repetition_penalty=1.0,\n",
    "            )\n",
    "        resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
    "        r = commerce_reward_fn_raw([resp], [text])[0]  # Raw rewards for audit (not GDPO-normalized)\n",
    "        audit_auto_scores.append(r)\n",
    "        audit_tasks.append(task_type)\n",
    "        audit_completions.append(resp)\n",
    "\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "# INTERACTIVE REWARD AUDIT\n",
    "# Shows each completion in FULL (no truncation), prompts for a 0-10 score.\n",
    "# ══════════════════════════════════════════════════════════════════════════════\n",
    "\n",
    "print(f\"\\n{'='*80}\")\n",
    "print(\"REWARD FUNCTION AUDIT β€” 20 Completions (interactive scoring)\")\n",
    "print(\"Score each completion 0-10:  0=garbage, 5=acceptable, 10=perfect\")\n",
    "print(f\"{'='*80}\")\n",
    "\n",
    "audit_human_scores = []\n",
    "\n",
    "for i, (task, auto_r, comp, prompt_txt) in enumerate(zip(audit_tasks, audit_auto_scores, audit_completions, audit_prompts_text)):\n",
    "    answer = strip_think(comp)  # full completion, no truncation\n",
    "    print(f\"\\n{'─'*80}\")\n",
    "    print(f\"  Sample {i+1}/{len(audit_auto_scores)} [{task}]  auto_reward={auto_r:.3f}\")\n",
    "    print(f\"{'─'*80}\")\n",
    "    print(f\"\\nINPUT REVIEW:\\n{prompt_txt}\\n\")\n",
    "    print(f\"MODEL OUTPUT:\\n{answer}\")\n",
    "    print()\n",
    "    while True:\n",
    "        try:\n",
    "            score = float(input(f\"  Your score (0-10): \"))\n",
    "            if 0 <= score <= 10:\n",
    "                break\n",
    "            print(\"  ⚠️ Score must be between 0 and 10\")\n",
    "        except (ValueError, EOFError):\n",
    "            print(\"  ⚠️ Enter a number between 0 and 10\")\n",
    "    audit_human_scores.append(score)\n",
    "    print(f\"  β†’ Recorded: human={score:.0f}, auto={auto_r:.3f}\")\n",
    "\n",
    "# ── Compute Spearman ρ ───────────────────────────────────────────────────────\n",
    "human_normalized = [s / 10.0 for s in audit_human_scores]\n",
    "rho, p_value = spearmanr(human_normalized, audit_auto_scores)\n",
    "\n",
    "print(f\"\\n{'='*80}\")\n",
    "print(f\"AUDIT RESULTS\")\n",
    "print(f\"{'='*80}\")\n",
    "print(f\"  Spearman ρ = {rho:.3f}  (p = {p_value:.4f})\")\n",
    "print()\n",
    "print(f\"  {'#':>3s}  {'Task':12s}  {'Human':>6s}  {'Auto':>6s}  {'Ξ”':>6s}\")\n",
    "print(f\"  {'─'*40}\")\n",
    "for i, (task, h, a) in enumerate(zip(audit_tasks, human_normalized, audit_auto_scores)):\n",
    "    delta = abs(h - a)\n",
    "    flag = \" ⚠️\" if delta > 0.3 else \"\"\n",
    "    print(f\"  {i+1:3d}  {task:12s}  {h:6.2f}  {a:6.3f}  {delta:6.3f}{flag}\")\n",
    "\n",
    "if rho > 0.70:\n",
    "    print(f\"\\n  βœ… PASS: ρ={rho:.3f} > 0.70 β€” reward function is calibrated\")\n",
    "else:\n",
    "    print(f\"\\n  ❌ FAIL: ρ={rho:.3f} < 0.70 β€” reward function is miscalibrated\")\n",
    "    print(\"  β†’ Investigate samples marked ⚠️ before training. Check:\")\n",
    "    print(\"    1. Is the JSON parser handling all output formats?\")\n",
    "    print(\"    2. Are SQL reward tiers appropriate for this model's output style?\")\n",
    "    print(\"    3. Are insights/push length penalties calibrated?\")\n",
    "\n",
    "assert rho > 0.65, f\"Reward function miscalibrated (ρ={rho:.3f} < 0.65). Fix before training.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 9: Build Stratified Eval Set (Change 1)\n\n**V4.2: 65 stratified samples** (20 extraction + 15 sql_qa + 15 insights + 15 push).\n\nSampled from `data/pairs/eval.jsonl`, saved as `data/pairs/eval_v2_stratified.jsonl`.\n**This file is fixed across all seeds.** Never resample.\n\n**Gate:** Exactly 65 samples with correct per-task counts."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "eval_v2_stratified_path = DATA_DIR / \"pairs\" / \"eval_v2_stratified.jsonl\"\n\n# ── Check if already built (idempotent across seeds) ────────────────────────\nif eval_v2_stratified_path.exists():\n    existing = []\n    with open(eval_v2_stratified_path) as f:\n        for line in f:\n            existing.append(json.loads(line))\n    # Verify counts\n    task_counts = {}\n    for rec in existing:\n        task_counts[rec[\"task_type\"]] = task_counts.get(rec[\"task_type\"], 0) + 1\n    print(f\"βœ“ Stratified eval set already exists: {eval_v2_stratified_path}\")\n    print(f\"  Counts: {task_counts}\")\n    print(f\"  Total: {len(existing)}\")\n    assert len(existing) == EVAL_TOTAL, f\"Expected {EVAL_TOTAL}, got {len(existing)}\"\nelse:\n    # ── Build from eval.jsonl (or train.jsonl fallback) ──────────────────────\n    eval_source = EVAL_FILE if EVAL_FILE.exists() else TRAIN_FILE\n    print(f\"Building stratified eval set from: {eval_source}\")\n    \n    # Collect all records by task\n    eval_by_task = {t: [] for t in EVAL_SAMPLES_PER_TASK}\n    with open(eval_source) as f:\n        for line in f:\n            row = json.loads(line)\n            convs = row[\"conversations\"]\n            prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n            if not prompt_msgs:\n                continue\n            user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n            task = _classify_task_type(user_text)\n            if task in eval_by_task:\n                eval_by_task[task].append({\n                    \"conversations\": convs,\n                    \"prompt_msgs\": prompt_msgs,\n                    \"task_type\": task,\n                })\n    \n    print(f\"  Available: {', '.join(f'{k}={len(v)}' for k, v in eval_by_task.items())}\")\n    \n    # Stratified sampling with FIXED seed (42 always, regardless of CURRENT_SEED)\n    eval_rng = random.Random(42)\n    stratified = 
[]\n    for task, target_n in EVAL_SAMPLES_PER_TASK.items():\n        pool = eval_by_task[task]\n        if len(pool) < target_n:\n            print(f\"  ⚠️ {task}: only {len(pool)} available, wanted {target_n}. Using all.\")\n            sampled = pool\n        else:\n            sampled = eval_rng.sample(pool, target_n)\n        stratified.extend(sampled)\n    \n    # Save as JSONL\n    eval_v2_stratified_path.parent.mkdir(parents=True, exist_ok=True)\n    with open(eval_v2_stratified_path, \"w\") as f:\n        for rec in stratified:\n            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n    \n    # Verify\n    task_counts = {}\n    for rec in stratified:\n        task_counts[rec[\"task_type\"]] = task_counts.get(rec[\"task_type\"], 0) + 1\n    \n    print(f\"\\nβœ“ Stratified eval set saved: {eval_v2_stratified_path}\")\n    print(f\"  Counts: {task_counts}\")\n    print(f\"  Total: {len(stratified)}\")\n\nprint(f\"\\nExpected: {EVAL_SAMPLES_PER_TASK} = {EVAL_TOTAL} total\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 10: Dataset Preparation + DynamicTaskSampler Init\n\n**V4.2 changes:**\n- Eval loaded from `eval_v2_stratified.jsonl` (65 fixed samples) instead of random split\n- Train still from `train.jsonl` with task-aware system prompt injection\n- DynamicTaskSampler initialized for IWU (Change 6)\n\n**Gate:** Train has ~1,480+ prompts. Eval has exactly 65 stratified samples. All 4 task types present in both."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from datasets import Dataset\n\ndef prepare_datasets_v42(train_file, eval_stratified_file, seed=CURRENT_SEED):\n    \"\"\"V4.2: Load train from JSONL + eval from stratified file.\"\"\"\n    rng = random.Random(seed)\n\n    # ── Train set ────────────────────────────────────────────────────────────\n    train_records = []\n    with open(train_file) as f:\n        for line in f:\n            row = json.loads(line)\n            convs = row[\"conversations\"]\n            prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n            if prompt_msgs:\n                train_records.append(prompt_msgs)\n    rng.shuffle(train_records)\n    \n    # Inject task-aware system prompts\n    for i, msgs in enumerate(train_records):\n        user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n        task = _classify_task_type(user_text)\n        train_records[i] = inject_task_system_prompt(msgs, task)\n    \n    # ── Eval set (V4.2: from stratified file, fixed across seeds) ────────────\n    eval_records = []\n    with open(eval_stratified_file) as f:\n        for line in f:\n            rec = json.loads(line)\n            prompt_msgs = rec[\"prompt_msgs\"]\n            user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n            task = _classify_task_type(user_text)\n            eval_records.append(inject_task_system_prompt(prompt_msgs, task))\n    \n    print(\"  βœ“ Task-aware system prompts injected\")\n    \n    # Print distributions\n    for label, records in [(\"train\", train_records), (\"eval\", eval_records)]:\n        dist = {}\n        for msgs in records:\n            user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n            task = _classify_task_type(user_text)\n            dist[task] = dist.get(task, 0) + 1\n        print(f\"  {label}: {len(records)} prompts β€” {dist}\")\n    \n    train_ds = Dataset.from_list([{\"prompt\": 
msgs} for msgs in train_records])\n    eval_ds = Dataset.from_list([{\"prompt\": msgs} for msgs in eval_records])\n    return train_ds, eval_ds\n\ntrain_dataset, eval_dataset = prepare_datasets_v42(TRAIN_FILE, eval_v2_stratified_path)\nprint(f\"\\nβœ“ Datasets: train={len(train_dataset)}, eval={len(eval_dataset)}\")\nassert len(eval_dataset) == EVAL_TOTAL, f\"Expected {EVAL_TOTAL} eval samples, got {len(eval_dataset)}\"\nprint(f\"βœ“ Eval is stratified: {EVAL_TOTAL} samples (fixed across seeds)\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 11: Smoke Test (1 Step)\n\n**Gate:** No OOM. Peak VRAM < 20GB. Step time < 180s."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "from trl import GRPOConfig, GRPOTrainer\n\nFastLanguageModel.for_training(model)\n\nsmoke_config = GRPOConfig(\n    output_dir=str(CHECKPOINT_DIR / \"smoke\"),\n    num_generations=NUM_GENERATIONS,\n    scale_rewards=SCALE_REWARDS,\n    max_completion_length=MAX_COMPLETION_LENGTH,\n    max_steps=1,\n    temperature=TEMPERATURE,\n    beta=BETA,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=1,\n    learning_rate=LEARNING_RATE,\n    lr_scheduler_type=LR_SCHEDULER_TYPE,\n    warmup_ratio=WARMUP_RATIO,\n    fp16=False,\n    bf16=True,\n    logging_steps=1,\n    save_steps=999,\n    report_to=\"none\",\n    max_prompt_length=MAX_SEQ_LENGTH // 2,\n    seed=CURRENT_SEED,\n    remove_unused_columns=False,\n)\n\n\nclass UnslothGRPOTrainer(GRPOTrainer):\n    def _generate(self, prompts, images):\n        FastLanguageModel.for_inference(self.model)\n        try:\n            result = super()._generate(prompts, images)\n        finally:\n            FastLanguageModel.for_training(self.model)\n        return result\n\n\nsmoke_trainer = UnslothGRPOTrainer(\n    model=model,\n    reward_funcs=commerce_reward_fn,\n    args=smoke_config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n)\n\nt0 = time.time()\nsmoke_trainer.train()\nstep_time = time.time() - t0\n\npeak_vram = torch.cuda.max_memory_allocated() / 1e9\nprint(f\"\\nβœ“ Smoke test passed!\")\nprint(f\"  Step time: {step_time:.0f}s\")\nprint(f\"  Peak VRAM: {peak_vram:.1f}GB / {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f}GB\")\nprint(f\"  Estimated full run ({MAX_STEPS} steps): {step_time * MAX_STEPS / 3600:.1f}h\")\n\ndel smoke_trainer\ngc.collect(); torch.cuda.empty_cache()"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 12: Probe Run (10 Steps)\n\n**V4.2:** Uses `CURRENT_SEED` for reproducibility. No hard clip_ratio gate (expected=0 for LoRA)."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FastLanguageModel.for_training(model)\n",
    "\n",
    "probe_config = GRPOConfig(\n",
    "    output_dir=str(CHECKPOINT_DIR / \"probe\"),\n",
    "    num_generations=NUM_GENERATIONS,\n",
    "    scale_rewards=SCALE_REWARDS,\n",
    "    max_completion_length=MAX_COMPLETION_LENGTH,\n",
    "    max_steps=10,\n",
    "    temperature=TEMPERATURE,\n",
    "    beta=BETA,\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    gradient_accumulation_steps=GRAD_ACCUM,\n",
    "    learning_rate=LEARNING_RATE,\n",
    "    lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
    "    warmup_ratio=WARMUP_RATIO,\n",
    "    fp16=False,\n",
    "    bf16=True,\n",
    "    logging_steps=1,\n",
    "    save_steps=999,\n",
    "    report_to=\"none\",\n",
    "    max_prompt_length=MAX_SEQ_LENGTH // 2,\n",
    "    seed=CURRENT_SEED,\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "\n",
    "probe_trainer = UnslothGRPOTrainer(\n",
    "    model=model,\n",
    "    reward_funcs=commerce_reward_fn,\n",
    "    args=probe_config,\n",
    "    train_dataset=train_dataset,\n",
    "    processing_class=tokenizer,\n",
    ")\n",
    "\n",
    "t0 = time.time()\n",
    "result = probe_trainer.train()\n",
    "elapsed = time.time() - t0\n",
    "\n",
    "# ── Extract metrics from log history ─────────────────────────────────────────\n",
    "# V4.2.1: TRL 0.24.0 logs under \"reward\" / \"rewards/commerce_reward_fn/mean\"\n",
    "# and \"grad_norm\" (no \"train/\" prefix in log_history entries).\n",
    "rewards = []\n",
    "reward_stds = []\n",
    "grad_norms = []\n",
    "for entry in probe_trainer.state.log_history:\n",
    "    if \"rewards/commerce_reward_fn/mean\" in entry:\n",
    "        rewards.append(entry[\"rewards/commerce_reward_fn/mean\"])\n",
    "    if \"rewards/commerce_reward_fn/std\" in entry:\n",
    "        reward_stds.append(entry[\"rewards/commerce_reward_fn/std\"])\n",
    "    if \"grad_norm\" in entry:\n",
    "        grad_norms.append(entry[\"grad_norm\"])\n",
    "\n",
    "print(f\"\\n{'='*60}\")\n",
    "print(f\"PROBE RESULTS ({elapsed:.0f}s, {elapsed/10:.0f}s/step)\")\n",
    "print(f\"  Rewards:     {[f'{r:.3f}' for r in rewards]}\")\n",
    "print(f\"  Reward stds: {[f'{s:.3f}' for s in reward_stds]}\")\n",
    "print(f\"  Grad norms:  {[f'{g:.4f}' for g in grad_norms]}\")\n",
    "print(f\"  Train loss:  {result.training_loss:.4f}\")\n",
    "print(f\"{'='*60}\")\n",
    "\n",
    "if rewards and max(rewards) > 0:\n",
    "    print(\"✓ Model generates scoreable output\")\n",
    "else:\n",
    "    print(\"✗ WARNING: All rewards are 0. Check reward functions.\")\n",
    "\n",
    "if grad_norms and max(grad_norms) > 0:\n",
    "    print(\"✓ Gradients are flowing\")\n",
    "else:\n",
    "    print(\"✗ WARNING: All grad_norms are 0. Check model/LoRA setup.\")\n",
    "\n",
    "if reward_stds and min(reward_stds) > 0:\n",
    "    print(\"✓ Batches have reward variance (GRPO has signal)\")\n",
    "elif reward_stds:\n",
    "    n_zero = sum(1 for s in reward_stds if s < 1e-6)\n",
    "    print(f\"⚠️ WARNING: {n_zero}/{len(reward_stds)} steps had zero reward std. Consider increasing G.\")\n",
    "else:\n",
    "    print(\"⚠️ WARNING: No reward_std logged. Check TRL version.\")\n",
    "\n",
    "print(\"\\n→ Proceed to full training (Cell 13)\")\n",
    "\n",
    "del probe_trainer\n",
    "gc.collect(); torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 13: W&B Init + Full Training (1,500 Steps)\n\n**V4.2 changes:**\n- **MAX_STEPS=1,500** (multi-epoch, ~2.5× full dataset) (Change 4)\n- **EvalRewardCallback v2:** 65 stratified samples, per-task 95% CIs, GDPO normalization logging, dynamic task weight updates, **best checkpoint saving** (Changes 1, 5, 6, 8)\n- **`SAVE_STEPS=100`, `EVAL_STEPS=50`** (scaled for longer run)\n- **Seed in W&B config** for multi-seed tracking (Change 7)\n- **Best checkpoint saved explicitly** when eval improves (Change 8)"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import wandb\nfrom transformers import TrainerCallback\n\nwandb.login()\nwandb.init(\n    project=WANDB_PROJECT,\n    name=f\"grpo-v4.2-instruct-0.5B-seed{CURRENT_SEED}-{time.strftime('%Y%m%d-%H%M')}\",\n    config={\n        \"model_id\": MODEL_ID,\n        \"version\": \"v4.2\",\n        \"seed\": CURRENT_SEED,\n        \"seeds_planned\": SEEDS,\n        \"num_generations\": NUM_GENERATIONS,\n        \"max_completion_length\": MAX_COMPLETION_LENGTH,\n        \"temperature\": TEMPERATURE,\n        \"learning_rate\": LEARNING_RATE,\n        \"lr_scheduler_type\": LR_SCHEDULER_TYPE,\n        \"warmup_ratio\": WARMUP_RATIO,\n        \"beta\": BETA,\n        \"scale_rewards\": SCALE_REWARDS,\n        \"batch_size\": BATCH_SIZE,\n        \"grad_accum\": GRAD_ACCUM,\n        \"max_steps\": MAX_STEPS,\n        \"lora_r\": LORA_R,\n        \"lora_alpha\": LORA_ALPHA,\n        \"train_prompts\": len(train_dataset),\n        \"eval_prompts\": len(eval_dataset),\n        \"eval_stratified\": True,\n        \"eval_per_task\": EVAL_SAMPLES_PER_TASK,\n        \"repetition_penalty_override\": 1.0,\n        \"json_parser\": \"json-repair + PT-BR decimal normalizer\",\n        \"sql_reward\": \"v2 (validation-aware, 4-tier)\",\n        \"gdpo_normalization\": True,\n        \"dynamic_task_weighting\": \"MT-GRPO IWU\",\n        \"changes_from_v41\": \"stratified eval 65, reward audit, SQL v2, 1500 steps, GDPO, IWU, 3 seeds, best ckpt\",\n    },\n)\nprint(f\"βœ“ W&B run: {wandb.run.url}\")\n\n\n# ══════════════════════════════════════════════════════════════════════════════\n# V4.2: EvalRewardCallback v2\n# - Uses 65 stratified eval samples (Change 1)\n# - Reports per-task means with 95% CIs (Change 1)\n# - Runs GDPO normalization and logs component stats (Change 5)\n# - Updates dynamic task weights via IWU (Change 6)\n# - Saves best checkpoint explicitly (Change 8)\n# ══════════════════════════════════════════════════════════════════════════════\n\nclass 
EvalRewardCallbackV2(TrainerCallback):\n    def __init__(self, eval_records, reward_fn, patience, delta):\n        self.eval_records = eval_records\n        self.reward_fn = reward_fn\n        self.patience = patience\n        self.delta = delta\n        self.best_reward = -float(\"inf\")\n        self.best_step = 0\n        self.no_improve_count = 0\n\n    def on_step_end(self, args, state, control, model=None, processing_class=None, **kwargs):\n        if state.global_step == 0 or state.global_step % EVAL_STEPS != 0:\n            return control\n\n        tokenizer_local = processing_class\n        if tokenizer_local is None:\n            print(\"[EvalRewardCallback] WARNING: tokenizer is None, skipping eval\")\n            return control\n\n        mean_reward, per_task, per_task_all = self._run_eval(model, tokenizer_local, args)\n        improved = mean_reward > self.best_reward + self.delta\n\n        # ── Per-task 95% CIs (Change 1) ──────────────────────────────────────\n        log_data = {\n            \"eval/mean_reward\": mean_reward,\n            \"eval/best_reward\": max(self.best_reward, mean_reward),\n            \"eval/no_improve_count\": self.no_improve_count,\n        }\n        \n        ci_strs = []\n        for task_name, task_rewards in per_task_all.items():\n            if task_rewards:\n                n = len(task_rewards)\n                task_mean = sum(task_rewards) / n\n                if n > 1:\n                    task_std = (sum((r - task_mean)**2 for r in task_rewards) / (n - 1)) ** 0.5\n                    ci_half = 1.96 * task_std / math.sqrt(n)\n                else:\n                    ci_half = 0.0\n                log_data[f\"eval/{task_name}\"] = task_mean\n                log_data[f\"eval/{task_name}_ci\"] = ci_half\n                log_data[f\"eval/{task_name}_n\"] = n\n                ci_strs.append(f\"{task_name}={task_mean:.3f}Β±{ci_half:.3f} (n={n})\")\n        \n        # ── GDPO per-component stats (Change 5) 
─────────────────────────────\n        if per_task_all and all(len(v) > 0 for v in per_task_all.values()):\n            try:\n                gdpo_rewards = gdpo_normalize(per_task_all)\n                log_data[\"eval/gdpo_mean\"] = sum(gdpo_rewards) / len(gdpo_rewards)\n                log_data[\"eval/gdpo_std\"] = (sum((r - sum(gdpo_rewards)/len(gdpo_rewards))**2 for r in gdpo_rewards) / len(gdpo_rewards)) ** 0.5\n            except Exception as e:\n                print(f\"  [GDPO] normalization error: {e}\")\n        \n        # ── Dynamic task weight update (Change 6) ───────────────────────────\n        per_task_means = {}\n        for task_name, task_rewards in per_task_all.items():\n            if task_rewards:\n                per_task_means[task_name] = sum(task_rewards) / len(task_rewards)\n        \n        update_task_weights(state.global_step, per_task_means, update_interval=EVAL_STEPS)\n        \n        for task_name, weight in _task_weights.items():\n            log_data[f\"sampler/{task_name}_weight\"] = weight\n        \n        wandb.log(log_data, step=state.global_step)\n\n        status = \"↑ improved\" if improved else f\"↔ no gain ({self.no_improve_count + 1}/{self.patience})\"\n        print(f\"\\n[EvalReward] step={state.global_step} | mean={mean_reward:.4f} | best={self.best_reward:.4f} | {status}\")\n        for cs in ci_strs:\n            print(f\"  {cs}\")\n        print(f\"  Task weights: {', '.join(f'{t}={w:.3f}' for t, w in _task_weights.items())}\")\n\n        if improved:\n            self.best_reward = mean_reward\n            self.best_step = state.global_step\n            self.no_improve_count = 0\n            # ── V4.2: Save best checkpoint explicitly (Change 8) ─────────────\n            best_path = ADAPTER_DIR / \"best_checkpoint\"\n            best_path.mkdir(parents=True, exist_ok=True)\n            model.save_pretrained(str(best_path))\n            tokenizer_local.save_pretrained(str(best_path))\n            print(f\"  
βœ“ Best checkpoint saved β†’ {best_path} (reward={mean_reward:.4f})\")\n        else:\n            self.no_improve_count += 1\n            if self.no_improve_count >= self.patience:\n                print(f\"[EarlyStopping] No improvement for {self.patience} evals. Halting.\")\n                control.should_training_stop = True\n        return control\n\n    def _run_eval(self, model, tokenizer_local, args):\n        FastLanguageModel.for_inference(model)\n        rewards = []\n        per_task_summary = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n        per_task_all = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n        \n        # V4.2: Use ALL stratified eval samples (65), not just 15\n        for record in self.eval_records:\n            msgs = record[\"prompt\"]\n            text = tokenizer_local.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n            user_txt = \" \".join(m.get(\"content\", \"\") for m in msgs if m[\"role\"] == \"user\")\n            task = _classify_task_type(user_txt)\n\n            inputs = tokenizer_local(text, return_tensors=\"pt\", truncation=True, max_length=args.max_prompt_length).to(model.device)\n            with torch.no_grad():\n                out = model.generate(\n                    **inputs,\n                    max_new_tokens=EVAL_MAX_TOKENS,\n                    temperature=0.1,\n                    do_sample=True,\n                    repetition_penalty=1.0,\n                )\n            resp = tokenizer_local.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n            r = self.reward_fn([resp], [text])[0]\n            rewards.append(r)\n            if task in per_task_all:\n                per_task_all[task].append(r)\n                per_task_summary[task].append(r)\n        \n        FastLanguageModel.for_training(model)\n        mean_r = sum(rewards) / len(rewards) if rewards else 0.0\n        return mean_r, 
per_task_summary, per_task_all\n\n\n# ── Training ────────────────────────────────────────────────────────────────\nFastLanguageModel.for_training(model)\n\ngrpo_config = GRPOConfig(\n    output_dir=str(CHECKPOINT_DIR),\n    num_generations=NUM_GENERATIONS,\n    scale_rewards=SCALE_REWARDS,\n    max_completion_length=MAX_COMPLETION_LENGTH,\n    max_steps=MAX_STEPS,                  # V4.2: 1,500\n    temperature=TEMPERATURE,\n    beta=BETA,\n    num_train_epochs=1,\n    per_device_train_batch_size=BATCH_SIZE,\n    gradient_accumulation_steps=GRAD_ACCUM,\n    learning_rate=LEARNING_RATE,\n    lr_scheduler_type=LR_SCHEDULER_TYPE,\n    warmup_ratio=WARMUP_RATIO,\n    fp16=False,\n    bf16=True,\n    logging_steps=1,\n    save_steps=SAVE_STEPS,               # V4.2: 100\n    save_total_limit=5,\n    save_only_model=True,\n    report_to=\"wandb\",\n    max_prompt_length=MAX_SEQ_LENGTH // 2,\n    seed=CURRENT_SEED,                    # V4.2: per-seed\n    remove_unused_columns=False,\n    disable_tqdm=True,\n    logging_first_step=True,\n)\n\neval_cb = EvalRewardCallbackV2(\n    eval_records=list(eval_dataset),\n    reward_fn=commerce_reward_fn_raw,  # V4.2: raw rewards for eval (no GDPO/IWU distortion)\n    patience=EARLY_STOPPING_PATIENCE,\n    delta=EARLY_STOPPING_DELTA,\n)\n\ntrainer = UnslothGRPOTrainer(\n    model=model,\n    reward_funcs=commerce_reward_fn,\n    args=grpo_config,\n    train_dataset=train_dataset,\n    processing_class=tokenizer,\n    callbacks=[eval_cb],\n)\n\nt_start = time.time()\nresult = trainer.train()\nelapsed = time.time() - t_start\n\nwandb.log({\n    \"train/final_loss\": result.training_loss,\n    \"train/duration_hours\": elapsed / 3600,\n    \"train/total_steps\": result.global_step,\n    \"eval/best_reward_final\": eval_cb.best_reward,\n    \"eval/best_step\": eval_cb.best_step,\n    \"final/task_weights\": _task_weights,\n})\nwandb.finish()\n\nprint(f\"\\n{'='*60}\")\nprint(f\"V4.2 Training Complete 
(seed={CURRENT_SEED})\")\nprint(f\"  Loss:        {result.training_loss:.4f}\")\nprint(f\"  Steps:       {result.global_step}\")\nprint(f\"  Duration:    {elapsed/3600:.1f}h\")\nprint(f\"  Best eval:   {eval_cb.best_reward:.4f} (step {eval_cb.best_step})\")\nprint(f\"  Final task weights: {_task_weights}\")\nprint(f\"{'='*60}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 14: Post-Training Validation (65 Stratified Samples)\n\n**V4.2:** Full stratified eval with per-task 95% CIs.\n\nReports `mean ± 1.96 × std/√n` for each task.\n\n**The four questions V4.2 must answer:**\n1. Does SQL reward improve with the new reward function?\n2. Is the insights regression noise or forgetting?\n3. Does multi-epoch training push eval above 0.70?\n4. Are results reproducible across seeds?"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "FastLanguageModel.for_inference(model)\n\nval_samples = list(eval_dataset)  # All 65 stratified samples\nval_results = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n\nprint(f\"Post-training validation on {len(val_samples)} stratified samples (seed={CURRENT_SEED})\")\nprint(\"-\" * 80)\n\nfor i, record in enumerate(val_samples):\n    msgs = record[\"prompt\"]\n    user_text = \" \".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n    task = _classify_task_type(user_text)\n\n    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n    inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n    with torch.no_grad():\n        out = model.generate(\n            **inputs,\n            max_new_tokens=MAX_COMPLETION_LENGTH,\n            temperature=0.1,\n            do_sample=True,\n            repetition_penalty=1.0,\n        )\n    resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n    r = commerce_reward_fn_raw([resp], [text])[0]  # Raw rewards for reporting\n    val_results[task].append(r)\n    if i < 10 or r < 0.2:  # Print first 10 and all low-scoring\n        print(f\"  [{task:12s}] reward={r:.3f} | {strip_think(resp)[:80]}\")\n\n# ── Results with 95% CIs ────────────────────────────────────────────────────\nprint(f\"\\n{'='*80}\")\nprint(f\"VALIDATION RESULTS β€” V4.2 Seed {CURRENT_SEED}\")\nprint(f\"{'='*80}\")\nprint(f\"{'Task':15s} {'Mean':>8s} {'Β± 95% CI':>10s} {'Min':>6s} {'Max':>6s} {'N':>4s}\")\nprint(\"-\" * 55)\n\noverall = []\nresults_by_seed = {}  # Store for cross-seed comparison\n\nfor task in [\"extraction\", \"sql_qa\", \"insights\", \"push\"]:\n    rewards = val_results[task]\n    overall.extend(rewards)\n    if rewards:\n        n = len(rewards)\n        mean_r = sum(rewards) / n\n        if n > 1:\n            std_r = (sum((r - mean_r)**2 for r in rewards) / (n - 1)) ** 0.5\n            ci_half = 1.96 * std_r / 
math.sqrt(n)\n        else:\n            std_r = 0.0\n            ci_half = 0.0\n        print(f\"{task:15s} {mean_r:8.3f} {'Β±':>2s}{ci_half:7.3f} {min(rewards):6.3f} {max(rewards):6.3f} {n:4d}\")\n        results_by_seed[task] = {\"mean\": mean_r, \"ci\": ci_half, \"n\": n, \"std\": std_r}\n\nif overall:\n    n_total = len(overall)\n    mean_total = sum(overall) / n_total\n    std_total = (sum((r - mean_total)**2 for r in overall) / (n_total - 1)) ** 0.5\n    ci_total = 1.96 * std_total / math.sqrt(n_total)\n    print(\"-\" * 55)\n    print(f\"{'OVERALL':15s} {mean_total:8.3f} {'Β±':>2s}{ci_total:7.3f} {min(overall):6.3f} {max(overall):6.3f} {n_total:4d}\")\n    results_by_seed[\"overall\"] = {\"mean\": mean_total, \"ci\": ci_total, \"n\": n_total, \"std\": std_total}\n\n# ── Save results for cross-seed comparison ──────────────────────────────────\nresults_file = ADAPTER_DIR / f\"eval_results_seed{CURRENT_SEED}.json\"\nresults_file.parent.mkdir(parents=True, exist_ok=True)\nwith open(results_file, \"w\") as f:\n    json.dump(results_by_seed, f, indent=2)\nprint(f\"\\nβœ“ Results saved to {results_file}\")\n\n# ── V4.2 Decision ───────────────────────────────────────────────────────────\nprint(f\"\\n--- V4.2 Questions ---\")\nsql_mean = results_by_seed.get(\"sql_qa\", {}).get(\"mean\", 0)\ninsights_mean = results_by_seed.get(\"insights\", {}).get(\"mean\", 0)\noverall_mean = results_by_seed.get(\"overall\", {}).get(\"mean\", 0)\n\nprint(f\"Q1 SQL reward: {sql_mean:.3f} ({'improved' if sql_mean > 0.60 else 'still stagnant' if sql_mean < 0.56 else 'modest gain'})\")\nprint(f\"Q2 Insights:   {insights_mean:.3f} ({'stable' if insights_mean > 0.70 else 'regressed' if insights_mean < 0.60 else 'mixed'})\")\nprint(f\"Q3 Overall:    {overall_mean:.3f} ({'above 0.70 target' if overall_mean > 0.70 else 'below target'})\")\nprint(f\"Q4 Seeds:      Seed {CURRENT_SEED} done. Run seeds {[s for s in SEEDS if s != CURRENT_SEED]} next.\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 15: Save Adapter\n\n**V4.2:** Saves from `best_checkpoint/` (peak eval reward) instead of last training step."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# V4.2: Save the best checkpoint, not the last training step\nbest_checkpoint_path = ADAPTER_DIR / \"best_checkpoint\"\n\nif best_checkpoint_path.exists():\n    print(f\"✓ Best checkpoint exists at {best_checkpoint_path}\")\n    print(f\"  Best eval reward: {eval_cb.best_reward:.4f} (step {eval_cb.best_step})\")\n    # Copy best checkpoint to main adapter dir for easy loading\n    import shutil\n    final_path = ADAPTER_DIR / \"final\"\n    if final_path.exists():\n        shutil.rmtree(final_path)\n    shutil.copytree(str(best_checkpoint_path), str(final_path))\n    print(f\"  → Copied to {final_path}\")\nelse:\n    print(\"⚠️ No best_checkpoint found. Saving current model state.\")\n    ADAPTER_DIR.mkdir(parents=True, exist_ok=True)\n    model.save_pretrained(str(ADAPTER_DIR / \"final\"))\n    tokenizer.save_pretrained(str(ADAPTER_DIR / \"final\"))\n\nprint(f\"\\n✓ Adapter saved for seed {CURRENT_SEED}\")\nprint(f\"  Location: {ADAPTER_DIR / 'final'}\")\nprint(f\"  Best eval: {eval_cb.best_reward:.4f} at step {eval_cb.best_step}\")"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "---\n\n## Cell 16: Results Table Generation (Change 7)\n\n**Run this after ALL 3 seeds are complete.**\n\nReads `eval_results_seedN.json` from each seed directory and produces the cross-seed\nresults table with mean ± 95% CI.\n\n```\n| Task         | Seed 42 | Seed 123 | Seed 456 | Mean ± 95% CI |\n|---|---|---|---|---|\n| Extraction   | ...     | ...      | ...      | X.XX ± 0.0X   |\n| SQL Q&A      | ...     | ...      | ...      | X.XX ± 0.0X   |\n| Insights     | ...     | ...      | ...      | X.XX ± 0.0X   |\n| Push         | ...     | ...      | ...      | X.XX ± 0.0X   |\n| **Mean**     | ...     | ...      | ...      | **X.XX ± 0.0X** |\n```\n\n**Note:** the code cell below fills the last column with mean ± sample standard deviation across seeds (not a t-based 95% CI), even though the column header reads `Mean ± 95% CI`."
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# ══════════════════════════════════════════════════════════════════════════════\n# V4.2: Cross-Seed Results Table (Change 7)\n# Run this AFTER completing all 3 seeds (42, 123, 456)\n# ══════════════════════════════════════════════════════════════════════════════\n\nresults_by_seed = {}\nmissing_seeds = []\n\nfor seed in SEEDS:\n    seed_dir = MODELS_DIR / f\"tucano2-0.5B-instruct-grpo-v4.2-seed{seed}\"\n    results_file = seed_dir / f\"eval_results_seed{seed}.json\"\n    if results_file.exists():\n        with open(results_file) as f:\n            results_by_seed[seed] = json.load(f)\n        print(f\"βœ“ Loaded results for seed {seed}\")\n    else:\n        missing_seeds.append(seed)\n        print(f\"⚠️ Missing results for seed {seed} (run the notebook with CURRENT_SEED={seed})\")\n\nif missing_seeds:\n    print(f\"\\n⚠️ Missing seeds: {missing_seeds}\")\n    print(f\"  Complete these runs before generating the final table.\")\n    print(f\"  Change CURRENT_SEED in Cell 3 and re-run Cells 3-15.\")\n\nif len(results_by_seed) >= 2:\n    print(f\"\\n{'='*90}\")\n    print(f\"V4.2 CROSS-SEED RESULTS TABLE\")\n    print(f\"{'='*90}\")\n    \n    tasks = [\"extraction\", \"sql_qa\", \"insights\", \"push\", \"overall\"]\n    \n    # Header\n    header = f\"{'Task':15s}\"\n    for seed in SEEDS:\n        if seed in results_by_seed:\n            header += f\" {'Seed '+str(seed):>10s}\"\n    header += f\" {'Mean Β± 95% CI':>18s}\"\n    print(header)\n    print(\"-\" * len(header))\n    \n    for task in tasks:\n        row = f\"{task:15s}\"\n        seed_means = []\n        for seed in SEEDS:\n            if seed in results_by_seed and task in results_by_seed[seed]:\n                m = results_by_seed[seed][task][\"mean\"]\n                seed_means.append(m)\n                row += f\" {m:10.3f}\"\n            elif seed in results_by_seed:\n                row += f\" {'β€”':>10s}\"\n        \n        if len(seed_means) >= 2:\n            cross_mean = 
sum(seed_means) / len(seed_means)\n            cross_std = (sum((m - cross_mean)**2 for m in seed_means) / (len(seed_means) - 1)) ** 0.5\n            # With 3 seeds, use t-distribution critical value (t=4.303 for 95% CI, df=2)\n            # But for consistency with the handoff, use Β±std\n            row += f\" {cross_mean:7.3f} Β± {cross_std:.3f}\"\n        elif len(seed_means) == 1:\n            row += f\" {seed_means[0]:7.3f} (1 seed)\"\n        \n        if task == \"overall\":\n            row = f\"**{row.strip()}**\"\n        print(row)\n    \n    print(f\"\\n{'='*90}\")\n    \n    # ── Reproducibility assessment ──────────────────────────────────────────\n    if len(results_by_seed) == 3:\n        overall_means = [results_by_seed[s][\"overall\"][\"mean\"] for s in SEEDS if s in results_by_seed]\n        overall_std = (sum((m - sum(overall_means)/len(overall_means))**2 for m in overall_means) / (len(overall_means) - 1)) ** 0.5\n        print(f\"\\nReproducibility: overall std = {overall_std:.4f}\")\n        if overall_std < 0.03:\n            print(f\"  βœ… Robust (std < 0.03): results are reproducible across seeds\")\n        elif overall_std < 0.05:\n            print(f\"  ⚠️ Moderate (0.03 < std < 0.05): some initialization sensitivity\")\n        else:\n            print(f\"  ❌ High variance (std > 0.05): significant initialization sensitivity\")\n\nelse:\n    print(f\"\\nNeed at least 2 seeds to generate comparison table.\")\n    print(f\"Current seed ({CURRENT_SEED}) results:\")\n    if CURRENT_SEED in results_by_seed:\n        for task, data in results_by_seed[CURRENT_SEED].items():\n            print(f\"  {task}: {data['mean']:.3f} Β± {data.get('ci', 0):.3f}\")"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}