feat(rewards): add sentiment mismatch penalty to prevent extraction reward hacking
Browse files- Modified reward_extraction to accept prompt_text and cross-check predicted sentiment against nota (review rating)
- Penalize -0.20 when nota ≤ 2 (negative) but sentiment is "positive", or nota ≥ 4 (positive) but sentiment is "negative"
- Reduced task weight cap from 0.60 to 0.50 for more conservative weight updates
- Updated both reward_extraction calls to pass original prompt text
- Reformatted audit cell for better readability
Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
notebooks/v4_2_instruct_grpo.ipynb
CHANGED
|
@@ -173,8 +173,8 @@
|
|
| 173 |
" return None\n",
|
| 174 |
"\n",
|
| 175 |
"\n",
|
| 176 |
-
"def reward_extraction(completion: str) -> float:\n",
|
| 177 |
-
" \"\"\"Continuous reward for extraction tasks (max 1.0).
|
| 178 |
" answer = strip_think(completion)\n",
|
| 179 |
" data = _extract_json(answer)\n",
|
| 180 |
"\n",
|
|
@@ -216,7 +216,19 @@
|
|
| 216 |
" if checks_total > 0:\n",
|
| 217 |
" score += 0.4 * (checks_passed / checks_total)\n",
|
| 218 |
"\n",
|
| 219 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
"\n",
|
| 221 |
"\n",
|
| 222 |
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
|
|
@@ -408,7 +420,7 @@
|
|
| 408 |
" if len(_task_reward_history[task]) >= 2:\n",
|
| 409 |
" improvement = _task_reward_history[task][-1] - _task_reward_history[task][-2]\n",
|
| 410 |
" if improvement < 0.01: # stagnating\n",
|
| 411 |
-
" _task_weights[task] = min(0.
|
| 412 |
" elif improvement > 0.05: # improving fast\n",
|
| 413 |
" _task_weights[task] = max(0.10, _task_weights[task] * 0.85)\n",
|
| 414 |
" \n",
|
|
@@ -477,7 +489,7 @@
|
|
| 477 |
" task_labels[i] = task\n",
|
| 478 |
"\n",
|
| 479 |
" if task == \"extraction\":\n",
|
| 480 |
-
" raw_rewards[i] = reward_extraction(comp_text)\n",
|
| 481 |
" elif task == \"sql_qa\":\n",
|
| 482 |
" raw_rewards[i] = reward_sql_qa(comp_text)\n",
|
| 483 |
" elif task == \"insights\":\n",
|
|
@@ -553,7 +565,7 @@
|
|
| 553 |
" task = _classify_task_type(prompt_text)\n",
|
| 554 |
"\n",
|
| 555 |
" if task == \"extraction\":\n",
|
| 556 |
-
" rewards.append(reward_extraction(comp_text))\n",
|
| 557 |
" elif task == \"sql_qa\":\n",
|
| 558 |
" rewards.append(reward_sql_qa(comp_text))\n",
|
| 559 |
" elif task == \"insights\":\n",
|
|
@@ -583,7 +595,117 @@
|
|
| 583 |
"execution_count": null,
|
| 584 |
"metadata": {},
|
| 585 |
"outputs": [],
|
| 586 |
-
"source":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
},
|
| 588 |
{
|
| 589 |
"cell_type": "markdown",
|
|
|
|
| 173 |
" return None\n",
|
| 174 |
"\n",
|
| 175 |
"\n",
|
| 176 |
+
"def reward_extraction(completion: str, prompt_text: str = \"\") -> float:\n",
|
| 177 |
+
" \"\"\"Continuous reward for extraction tasks (max 1.0).\"\"\"\n",
|
| 178 |
" answer = strip_think(completion)\n",
|
| 179 |
" data = _extract_json(answer)\n",
|
| 180 |
"\n",
|
|
|
|
| 216 |
" if checks_total > 0:\n",
|
| 217 |
" score += 0.4 * (checks_passed / checks_total)\n",
|
| 218 |
"\n",
|
| 219 |
+
" # nota=1-2 on a 5-star scale → negative review; nota=4-5 → positive.\n",
|
| 220 |
+
" # Penalize clear sentiment mismatches to break reward hacking.\n",
|
| 221 |
+
" import re as _re\n",
|
| 222 |
+
" nota_match = _re.search(r\"nota=(\\d)/5\", prompt_text)\n",
|
| 223 |
+
" if nota_match and \"sentiment\" in data:\n",
|
| 224 |
+
" nota = int(nota_match.group(1))\n",
|
| 225 |
+
" sentiment = data.get(\"sentiment\", \"\")\n",
|
| 226 |
+
" if nota <= 2 and sentiment == \"positive\":\n",
|
| 227 |
+
" score -= 0.20\n",
|
| 228 |
+
" elif nota >= 4 and sentiment == \"negative\":\n",
|
| 229 |
+
" score -= 0.20\n",
|
| 230 |
+
"\n",
|
| 231 |
+
" return max(0.0, min(score, 1.0))\n",
|
| 232 |
"\n",
|
| 233 |
"\n",
|
| 234 |
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
|
|
|
|
| 420 |
" if len(_task_reward_history[task]) >= 2:\n",
|
| 421 |
" improvement = _task_reward_history[task][-1] - _task_reward_history[task][-2]\n",
|
| 422 |
" if improvement < 0.01: # stagnating\n",
|
| 423 |
+
" _task_weights[task] = min(0.50, _task_weights[task] * 1.3)\n",
|
| 424 |
" elif improvement > 0.05: # improving fast\n",
|
| 425 |
" _task_weights[task] = max(0.10, _task_weights[task] * 0.85)\n",
|
| 426 |
" \n",
|
|
|
|
| 489 |
" task_labels[i] = task\n",
|
| 490 |
"\n",
|
| 491 |
" if task == \"extraction\":\n",
|
| 492 |
+
" raw_rewards[i] = reward_extraction(comp_text, prompt_text)\n",
|
| 493 |
" elif task == \"sql_qa\":\n",
|
| 494 |
" raw_rewards[i] = reward_sql_qa(comp_text)\n",
|
| 495 |
" elif task == \"insights\":\n",
|
|
|
|
| 565 |
" task = _classify_task_type(prompt_text)\n",
|
| 566 |
"\n",
|
| 567 |
" if task == \"extraction\":\n",
|
| 568 |
+
" rewards.append(reward_extraction(comp_text, prompt_text))\n",
|
| 569 |
" elif task == \"sql_qa\":\n",
|
| 570 |
" rewards.append(reward_sql_qa(comp_text))\n",
|
| 571 |
" elif task == \"insights\":\n",
|
|
|
|
| 595 |
"execution_count": null,
|
| 596 |
"metadata": {},
|
| 597 |
"outputs": [],
|
| 598 |
+
"source": [
|
| 599 |
+
"from scipy.stats import spearmanr\n",
|
| 600 |
+
"\n",
|
| 601 |
+
"AUDIT_PROMPTS_PER_TASK = 5\n",
|
| 602 |
+
"\n",
|
| 603 |
+
"# ββ Collect audit prompts (5 per task) βββββββββββββββββββββββββββββββββββββββ\n",
|
| 604 |
+
"audit_by_type = {\"extraction\": [], \"sql_qa\": [], \"insights\": [], \"push\": []}\n",
|
| 605 |
+
"with open(TRAIN_FILE) as f:\n",
|
| 606 |
+
" for line in f:\n",
|
| 607 |
+
" row = json.loads(line)\n",
|
| 608 |
+
" convs = row[\"conversations\"]\n",
|
| 609 |
+
" prompt_msgs = [m for m in convs if m[\"role\"] in (\"system\", \"user\")]\n",
|
| 610 |
+
" if not prompt_msgs:\n",
|
| 611 |
+
" continue\n",
|
| 612 |
+
" user_text = \" \".join(m[\"content\"] for m in prompt_msgs if m[\"role\"] == \"user\")\n",
|
| 613 |
+
" task = _classify_task_type(user_text)\n",
|
| 614 |
+
" if len(audit_by_type[task]) < AUDIT_PROMPTS_PER_TASK:\n",
|
| 615 |
+
" audit_by_type[task].append(prompt_msgs)\n",
|
| 616 |
+
"\n",
|
| 617 |
+
"print(f\"Audit prompts collected: {', '.join(f'{k}={len(v)}' for k, v in audit_by_type.items())}\")\n",
|
| 618 |
+
"\n",
|
| 619 |
+
"# ββ Generate completions and score automatically βββββββββββββββββββββββββββββ\n",
|
| 620 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 621 |
+
"\n",
|
| 622 |
+
"audit_auto_scores = []\n",
|
| 623 |
+
"audit_tasks = []\n",
|
| 624 |
+
"audit_completions = []\n",
|
| 625 |
+
"\n",
|
| 626 |
+
"audit_prompts_text = [] # store original user message for display\n",
|
| 627 |
+
"\n",
|
| 628 |
+
"for task_type in [\"extraction\", \"sql_qa\", \"insights\", \"push\"]:\n",
|
| 629 |
+
" for msgs in audit_by_type[task_type]:\n",
|
| 630 |
+
" # Extract original user message BEFORE injecting system prompt\n",
|
| 631 |
+
" user_content = \"\\n\".join(m[\"content\"] for m in msgs if m[\"role\"] == \"user\")\n",
|
| 632 |
+
" audit_prompts_text.append(user_content)\n",
|
| 633 |
+
" \n",
|
| 634 |
+
" msgs = inject_task_system_prompt(msgs, task_type)\n",
|
| 635 |
+
" text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
|
| 636 |
+
" inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
|
| 637 |
+
" with torch.no_grad():\n",
|
| 638 |
+
" out = model.generate(\n",
|
| 639 |
+
" **inputs,\n",
|
| 640 |
+
" max_new_tokens=MAX_COMPLETION_LENGTH,\n",
|
| 641 |
+
" temperature=0.1, # near-deterministic for audit\n",
|
| 642 |
+
" do_sample=True,\n",
|
| 643 |
+
" repetition_penalty=1.0,\n",
|
| 644 |
+
" )\n",
|
| 645 |
+
" resp = tokenizer.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
| 646 |
+
" r = commerce_reward_fn_raw([resp], [text])[0] # Raw rewards for audit (not GDPO-normalized)\n",
|
| 647 |
+
" audit_auto_scores.append(r)\n",
|
| 648 |
+
" audit_tasks.append(task_type)\n",
|
| 649 |
+
" audit_completions.append(resp)\n",
|
| 650 |
+
"\n",
|
| 651 |
+
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
|
| 652 |
+
"# INTERACTIVE REWARD AUDIT\n",
|
| 653 |
+
"# Shows each completion in FULL (no truncation), prompts for a 0-10 score.\n",
|
| 654 |
+
"# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
|
| 655 |
+
"\n",
|
| 656 |
+
"print(f\"\\n{'='*80}\")\n",
|
| 657 |
+
"print(\"REWARD FUNCTION AUDIT β 20 Completions (interactive scoring)\")\n",
|
| 658 |
+
"print(\"Score each completion 0-10: 0=garbage, 5=acceptable, 10=perfect\")\n",
|
| 659 |
+
"print(f\"{'='*80}\")\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"audit_human_scores = []\n",
|
| 662 |
+
"\n",
|
| 663 |
+
"for i, (task, auto_r, comp, prompt_txt) in enumerate(zip(audit_tasks, audit_auto_scores, audit_completions, audit_prompts_text)):\n",
|
| 664 |
+
" answer = strip_think(comp) # full completion, no truncation\n",
|
| 665 |
+
" print(f\"\\n{'β'*80}\")\n",
|
| 666 |
+
" print(f\" Sample {i+1}/{len(audit_auto_scores)} [{task}] auto_reward={auto_r:.3f}\")\n",
|
| 667 |
+
" print(f\"{'β'*80}\")\n",
|
| 668 |
+
" print(f\"\\nINPUT REVIEW:\\n{prompt_txt}\\n\")\n",
|
| 669 |
+
" print(f\"MODEL OUTPUT:\\n{answer}\")\n",
|
| 670 |
+
" print()\n",
|
| 671 |
+
" while True:\n",
|
| 672 |
+
" try:\n",
|
| 673 |
+
" score = float(input(f\" Your score (0-10): \"))\n",
|
| 674 |
+
" if 0 <= score <= 10:\n",
|
| 675 |
+
" break\n",
|
| 676 |
+
" print(\" β οΈ Score must be between 0 and 10\")\n",
|
| 677 |
+
" except (ValueError, EOFError):\n",
|
| 678 |
+
" print(\" β οΈ Enter a number between 0 and 10\")\n",
|
| 679 |
+
" audit_human_scores.append(score)\n",
|
| 680 |
+
" print(f\" β Recorded: human={score:.0f}, auto={auto_r:.3f}\")\n",
|
| 681 |
+
"\n",
|
| 682 |
+
"# ββ Compute Spearman Ο βββββββββββββββββββββββββββββββββββββββββββββββββββββββ\n",
|
| 683 |
+
"human_normalized = [s / 10.0 for s in audit_human_scores]\n",
|
| 684 |
+
"rho, p_value = spearmanr(human_normalized, audit_auto_scores)\n",
|
| 685 |
+
"\n",
|
| 686 |
+
"print(f\"\\n{'='*80}\")\n",
|
| 687 |
+
"print(f\"AUDIT RESULTS\")\n",
|
| 688 |
+
"print(f\"{'='*80}\")\n",
|
| 689 |
+
"print(f\" Spearman Ο = {rho:.3f} (p = {p_value:.4f})\")\n",
|
| 690 |
+
"print()\n",
|
| 691 |
+
"print(f\" {'#':>3s} {'Task':12s} {'Human':>6s} {'Auto':>6s} {'Ξ':>6s}\")\n",
|
| 692 |
+
"print(f\" {'β'*40}\")\n",
|
| 693 |
+
"for i, (task, h, a) in enumerate(zip(audit_tasks, human_normalized, audit_auto_scores)):\n",
|
| 694 |
+
" delta = abs(h - a)\n",
|
| 695 |
+
" flag = \" β οΈ\" if delta > 0.3 else \"\"\n",
|
| 696 |
+
" print(f\" {i+1:3d} {task:12s} {h:6.2f} {a:6.3f} {delta:6.3f}{flag}\")\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"if rho > 0.70:\n",
|
| 699 |
+
"print(f\"\\n ✅ PASS: ρ={rho:.3f} > 0.70 → reward function is calibrated\")\n",
|
| 700 |
+
"else:\n",
|
| 701 |
+
" print(f\"\\n ❌ FAIL: ρ={rho:.3f} < 0.70 → reward function is miscalibrated\")\n",
|
| 702 |
+
" print(\" β Investigate samples marked β οΈ before training. Check:\")\n",
|
| 703 |
+
" print(\" 1. Is the JSON parser handling all output formats?\")\n",
|
| 704 |
+
" print(\" 2. Are SQL reward tiers appropriate for this model's output style?\")\n",
|
| 705 |
+
" print(\" 3. Are insights/push length penalties calibrated?\")\n",
|
| 706 |
+
"\n",
|
| 707 |
+
"assert rho > 0.65, f\"Reward function miscalibrated (ρ={rho:.3f} < 0.65). Fix before training.\""
|
| 708 |
+
]
|
| 709 |
},
|
| 710 |
{
|
| 711 |
"cell_type": "markdown",
|