fix(classifier): reorder _classify_task_type — insights before push to prevent reengajamento misclassification
Browse files
notebooks/v4_2_instruct_grpo.ipynb
CHANGED
|
@@ -81,7 +81,7 @@
|
|
| 81 |
"cell_type": "markdown",
|
| 82 |
"metadata": {},
|
| 83 |
"source": [
|
| 84 |
-
"---\n\n## Cell 7: Reward Functions V2\n\n**V4.2 changes (Change 3 + Change 5):**\n\n### SQL Reward Overhaul (Change 3)\n- **Tier 1 (0.30):** SQL structure detected — requires ≥3 SQL keywords (SELECT, FROM, WHERE, etc.)\n- **Tier 2 (0.25):** Answer has BOTH query AND explanation (not just domain vocabulary)\n- **Tier 3 (0.25):** Numerical specificity (concrete data in answer)\n- **Tier 4 (0.20):** Portuguese business domain coherence\n\n### GDPO Per-Component Normalization (Change 5) — ACTIVE IN TRAINING\n- `commerce_reward_fn` applies per-task z-score normalization INSIDE the reward call\n- TRL 0.24.0 calls reward_fn with the full batch → we normalize per-component before returning\n- No trainer modification needed — normalized rewards flow through standard GRPO advantage computation\n- Preserves ~4× more distinct advantage groups (GDPO §3.1)\n\n### Dynamic Task Weights (Change 6) — ACTIVE IN TRAINING\n- `_task_weights` dict tracks per-task weights, updated by `update_task_weights()` in eval callback\n- Weights are applied as multiplicative scaling INSIDE `commerce_reward_fn` after GDPO normalization\n- Effect: stagnating tasks (e.g. SQL) get amplified reward signal → larger GRPO advantages → more gradient\n- MT-GRPO IWU §3.2: prevents easy-task collapse without requiring custom sampling\n\n### V4.2.1 Fixes (Cell 8 Audit)\n- **Push reward:** Steep length penalty (hard 0 above 200 chars) + formal email penalty (-0.20 for \"Prezado\"/\"Atenciosamente\")\n- **SQL reward Tier 4:** Expanded domain word list (+20 words: compradores, sentimentos, reclamações, taxa, distribuição, etc.)\n- **Extraction reward:** `sentiment_score` validator requires `isinstance(v, int) and not isinstance(v, bool)` — rejects floats from PT decimal normalization"
|
| 85 |
]
|
| 86 |
},
|
| 87 |
{
|
|
@@ -103,13 +103,21 @@
|
|
| 103 |
"\n",
|
| 104 |
"\n",
|
| 105 |
"def _classify_task_type(prompt_text: str) -> str:\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
" p = prompt_text.lower()\n",
|
| 107 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
" return \"extraction\"\n",
|
|
|
|
| 109 |
" elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
|
| 110 |
" return \"push\"\n",
|
| 111 |
-
" elif \"perfil do cliente\" in p or \"retenção\" in p or \"análise\" in p or \"insight\" in p:\n",
|
| 112 |
-
" return \"insights\"\n",
|
| 113 |
" else:\n",
|
| 114 |
" return \"sql_qa\"\n",
|
| 115 |
"\n",
|
|
|
|
| 81 |
"cell_type": "markdown",
|
| 82 |
"metadata": {},
|
| 83 |
"source": [
|
| 84 |
+
"---\n\n## Cell 7: Reward Functions V2\n\n**V4.2 changes (Change 3 + Change 5):**\n\n### SQL Reward Overhaul (Change 3)\n- **Tier 1 (0.30):** SQL structure detected — requires ≥3 SQL keywords (SELECT, FROM, WHERE, etc.)\n- **Tier 2 (0.25):** Answer has BOTH query AND explanation (not just domain vocabulary)\n- **Tier 3 (0.25):** Numerical specificity (concrete data in answer)\n- **Tier 4 (0.20):** Portuguese business domain coherence\n\n### GDPO Per-Component Normalization (Change 5) — ACTIVE IN TRAINING\n- `commerce_reward_fn` applies per-task z-score normalization INSIDE the reward call\n- TRL 0.24.0 calls reward_fn with the full batch → we normalize per-component before returning\n- No trainer modification needed — normalized rewards flow through standard GRPO advantage computation\n- Preserves ~4× more distinct advantage groups (GDPO §3.1)\n\n### Dynamic Task Weights (Change 6) — ACTIVE IN TRAINING\n- `_task_weights` dict tracks per-task weights, updated by `update_task_weights()` in eval callback\n- Weights are applied as multiplicative scaling INSIDE `commerce_reward_fn` after GDPO normalization\n- Effect: stagnating tasks (e.g. SQL) get amplified reward signal → larger GRPO advantages → more gradient\n- MT-GRPO IWU §3.2: prevents easy-task collapse without requiring custom sampling\n\n### V4.2.1 Fixes (Cell 8 Audit)\n- **Push reward:** Steep length penalty (hard 0 above 200 chars) + formal email penalty (-0.20 for \"Prezado\"/\"Atenciosamente\")\n- **SQL reward Tier 4:** Expanded domain word list (+20 words: compradores, sentimentos, reclamações, taxa, distribuição, etc.)\n- **Extraction reward:** `sentiment_score` validator requires `isinstance(v, int) and not isinstance(v, bool)` — rejects floats from PT decimal normalization\n- **Task classifier:** Reordered `_classify_task_type` — insights checked before push to prevent 'reengajamento' misclassification"
|
| 85 |
]
|
| 86 |
},
|
| 87 |
{
|
|
|
|
| 103 |
"\n",
|
| 104 |
"\n",
|
| 105 |
"def _classify_task_type(prompt_text: str) -> str:\n",
|
| 106 |
+
" \"\"\"V4.2.1: reordered — insights before push to prevent misclassification.\n",
|
| 107 |
+
" \n",
|
| 108 |
+
" \"notificação de reengajamento\" in a customer profile context is insights,\n",
|
| 109 |
+
" not push. Check insights keywords first.\n",
|
| 110 |
+
" \"\"\"\n",
|
| 111 |
" p = prompt_text.lower()\n",
|
| 112 |
+
" # 1. Insights FIRST — customer profile questions mentioning reengagement are insights\n",
|
| 113 |
+
" if \"perfil do cliente\" in p or \"retenção\" in p or \"análise\" in p or \"insight\" in p:\n",
|
| 114 |
+
" return \"insights\"\n",
|
| 115 |
+
" # 2. Extraction\n",
|
| 116 |
+
" elif \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n",
|
| 117 |
" return \"extraction\"\n",
|
| 118 |
+
" # 3. Push — only after insights is ruled out\n",
|
| 119 |
" elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
|
| 120 |
" return \"push\"\n",
|
|
|
|
|
|
|
| 121 |
" else:\n",
|
| 122 |
" return \"sql_qa\"\n",
|
| 123 |
"\n",
|