fix(probe): use TRL 0.24.0 log keys — rewards/commerce_reward_fn/mean, grad_norm (not train/ prefix)
notebooks/v4_2_instruct_grpo.ipynb CHANGED

@@ -631,7 +631,90 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source":
+   "source": [
+    "FastLanguageModel.for_training(model)\n",
+    "\n",
+    "probe_config = GRPOConfig(\n",
+    "    output_dir=str(CHECKPOINT_DIR / \"probe\"),\n",
+    "    num_generations=NUM_GENERATIONS,\n",
+    "    scale_rewards=SCALE_REWARDS,\n",
+    "    max_completion_length=MAX_COMPLETION_LENGTH,\n",
+    "    max_steps=10,\n",
+    "    temperature=TEMPERATURE,\n",
+    "    beta=BETA,\n",
+    "    num_train_epochs=1,\n",
+    "    per_device_train_batch_size=BATCH_SIZE,\n",
+    "    gradient_accumulation_steps=GRAD_ACCUM,\n",
+    "    learning_rate=LEARNING_RATE,\n",
+    "    lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
+    "    warmup_ratio=WARMUP_RATIO,\n",
+    "    fp16=False,\n",
+    "    bf16=True,\n",
+    "    logging_steps=1,\n",
+    "    save_steps=999,\n",
+    "    report_to=\"none\",\n",
+    "    max_prompt_length=MAX_SEQ_LENGTH // 2,\n",
+    "    seed=CURRENT_SEED,\n",
+    "    remove_unused_columns=False,\n",
+    ")\n",
+    "\n",
+    "probe_trainer = UnslothGRPOTrainer(\n",
+    "    model=model,\n",
+    "    reward_funcs=commerce_reward_fn,\n",
+    "    args=probe_config,\n",
+    "    train_dataset=train_dataset,\n",
+    "    processing_class=tokenizer,\n",
+    ")\n",
+    "\n",
+    "t0 = time.time()\n",
+    "result = probe_trainer.train()\n",
+    "elapsed = time.time() - t0\n",
+    "\n",
+    "# ── Extract metrics from log history ─────────────────────────────────────────\n",
+    "# V4.2.1: TRL 0.24.0 logs under \"reward\" / \"rewards/commerce_reward_fn/mean\"\n",
+    "# and \"grad_norm\" (no \"train/\" prefix in log_history entries).\n",
+    "rewards = []\n",
+    "reward_stds = []\n",
+    "grad_norms = []\n",
+    "for entry in probe_trainer.state.log_history:\n",
+    "    if \"rewards/commerce_reward_fn/mean\" in entry:\n",
+    "        rewards.append(entry[\"rewards/commerce_reward_fn/mean\"])\n",
+    "    if \"rewards/commerce_reward_fn/std\" in entry:\n",
+    "        reward_stds.append(entry[\"rewards/commerce_reward_fn/std\"])\n",
+    "    if \"grad_norm\" in entry:\n",
+    "        grad_norms.append(entry[\"grad_norm\"])\n",
+    "\n",
+    "print(f\"\\n{'='*60}\")\n",
+    "print(f\"PROBE RESULTS ({elapsed:.0f}s, {elapsed/10:.0f}s/step)\")\n",
+    "print(f\" Rewards: {[f'{r:.3f}' for r in rewards]}\")\n",
+    "print(f\" Reward stds: {[f'{s:.3f}' for s in reward_stds]}\")\n",
+    "print(f\" Grad norms: {[f'{g:.4f}' for g in grad_norms]}\")\n",
+    "print(f\" Train loss: {result.training_loss:.4f}\")\n",
+    "print(f\"{'='*60}\")\n",
+    "\n",
+    "if rewards and max(rewards) > 0:\n",
+    "    print(\"✓ Model generates scoreable output\")\n",
+    "else:\n",
+    "    print(\"✗ WARNING: All rewards are 0. Check reward functions.\")\n",
+    "\n",
+    "if grad_norms and max(grad_norms) > 0:\n",
+    "    print(\"✓ Gradients are flowing\")\n",
+    "else:\n",
+    "    print(\"✗ WARNING: All grad_norms are 0. Check model/LoRA setup.\")\n",
+    "\n",
+    "if reward_stds and min(reward_stds) > 0:\n",
+    "    print(\"✓ Batches have reward variance (GRPO has signal)\")\n",
+    "elif reward_stds:\n",
+    "    n_zero = sum(1 for s in reward_stds if s < 1e-6)\n",
+    "    print(f\"⚠️ WARNING: {n_zero}/{len(reward_stds)} steps had zero reward std. Consider increasing G.\")\n",
+    "else:\n",
+    "    print(\"⚠️ WARNING: No reward_std logged. Check TRL version.\")\n",
+    "\n",
+    "print(\"\\n→ Proceed to full training (Cell 13)\")\n",
+    "\n",
+    "del probe_trainer\n",
+    "gc.collect(); torch.cuda.empty_cache()"
+   ]
  },
  {
   "cell_type": "markdown",