rtferraz committed on
Commit
63b1c86
·
verified ·
1 Parent(s): 41eb15f

fix(classifier): reorder _classify_task_type — insights before push to prevent reengajamento misclassification

  1. notebooks/v4_2_instruct_grpo.ipynb +12 -4
notebooks/v4_2_instruct_grpo.ipynb CHANGED
@@ -81,7 +81,7 @@
81
  "cell_type": "markdown",
82
  "metadata": {},
83
  "source": [
84
- "---\n\n## Cell 7: Reward Functions V2\n\n**V4.2 changes (Change 3 + Change 5):**\n\n### SQL Reward Overhaul (Change 3)\n- **Tier 1 (0.30):** SQL structure detected — requires ≥3 SQL keywords (SELECT, FROM, WHERE, etc.)\n- **Tier 2 (0.25):** Answer has BOTH query AND explanation (not just domain vocabulary)\n- **Tier 3 (0.25):** Numerical specificity (concrete data in answer)\n- **Tier 4 (0.20):** Portuguese business domain coherence\n\n### GDPO Per-Component Normalization (Change 5) — ACTIVE IN TRAINING\n- `commerce_reward_fn` applies per-task z-score normalization INSIDE the reward call\n- TRL 0.24.0 calls reward_fn with the full batch → we normalize per-component before returning\n- No trainer modification needed — normalized rewards flow through standard GRPO advantage computation\n- Preserves ~4× more distinct advantage groups (GDPO §3.1)\n\n### Dynamic Task Weights (Change 6) — ACTIVE IN TRAINING\n- `_task_weights` dict tracks per-task weights, updated by `update_task_weights()` in eval callback\n- Weights are applied as multiplicative scaling INSIDE `commerce_reward_fn` after GDPO normalization\n- Effect: stagnating tasks (e.g. SQL) get amplified reward signal → larger GRPO advantages → more gradient\n- MT-GRPO IWU §3.2: prevents easy-task collapse without requiring custom sampling\n\n### V4.2.1 Fixes (Cell 8 Audit)\n- **Push reward:** Steep length penalty (hard 0 above 200 chars) + formal email penalty (-0.20 for \"Prezado\"/\"Atenciosamente\")\n- **SQL reward Tier 4:** Expanded domain word list (+20 words: compradores, sentimentos, reclamações, taxa, distribuição, etc.)\n- **Extraction reward:** `sentiment_score` validator requires `isinstance(v, int) and not isinstance(v, bool)` — rejects floats from PT decimal normalization"
85
  ]
86
  },
87
  {
@@ -103,13 +103,21 @@
103
  "\n",
104
  "\n",
105
  "def _classify_task_type(prompt_text: str) -> str:\n",
106
  "    p = prompt_text.lower()\n",
107
- "    if \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n",
108
  "        return \"extraction\"\n",
109
  "    elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
110
  "        return \"push\"\n",
111
- "    elif \"perfil do cliente\" in p or \"retenção\" in p or \"análise\" in p or \"insight\" in p:\n",
112
- "        return \"insights\"\n",
113
  "    else:\n",
114
  "        return \"sql_qa\"\n",
115
  "\n",
 
81
  "cell_type": "markdown",
82
  "metadata": {},
83
  "source": [
84
+ "---\n\n## Cell 7: Reward Functions V2\n\n**V4.2 changes (Change 3 + Change 5):**\n\n### SQL Reward Overhaul (Change 3)\n- **Tier 1 (0.30):** SQL structure detected — requires ≥3 SQL keywords (SELECT, FROM, WHERE, etc.)\n- **Tier 2 (0.25):** Answer has BOTH query AND explanation (not just domain vocabulary)\n- **Tier 3 (0.25):** Numerical specificity (concrete data in answer)\n- **Tier 4 (0.20):** Portuguese business domain coherence\n\n### GDPO Per-Component Normalization (Change 5) — ACTIVE IN TRAINING\n- `commerce_reward_fn` applies per-task z-score normalization INSIDE the reward call\n- TRL 0.24.0 calls reward_fn with the full batch → we normalize per-component before returning\n- No trainer modification needed — normalized rewards flow through standard GRPO advantage computation\n- Preserves ~4× more distinct advantage groups (GDPO §3.1)\n\n### Dynamic Task Weights (Change 6) — ACTIVE IN TRAINING\n- `_task_weights` dict tracks per-task weights, updated by `update_task_weights()` in eval callback\n- Weights are applied as multiplicative scaling INSIDE `commerce_reward_fn` after GDPO normalization\n- Effect: stagnating tasks (e.g. SQL) get amplified reward signal → larger GRPO advantages → more gradient\n- MT-GRPO IWU §3.2: prevents easy-task collapse without requiring custom sampling\n\n### V4.2.1 Fixes (Cell 8 Audit)\n- **Push reward:** Steep length penalty (hard 0 above 200 chars) + formal email penalty (-0.20 for \"Prezado\"/\"Atenciosamente\")\n- **SQL reward Tier 4:** Expanded domain word list (+20 words: compradores, sentimentos, reclamações, taxa, distribuição, etc.)\n- **Extraction reward:** `sentiment_score` validator requires `isinstance(v, int) and not isinstance(v, bool)` — rejects floats from PT decimal normalization\n- **Task classifier:** Reordered `_classify_task_type` — insights checked before push to prevent 'reengajamento' misclassification"
85
  ]
86
  },
87
  {
 
103
  "\n",
104
  "\n",
105
  "def _classify_task_type(prompt_text: str) -> str:\n",
106
+ "    \"\"\"V4.2.1: reordered — insights before push to prevent misclassification.\n",
107
+ "    \n",
108
+ "    \"notificação de reengajamento\" in a customer profile context is insights,\n",
109
+ "    not push. Check insights keywords first.\n",
110
+ "    \"\"\"\n",
111
  "    p = prompt_text.lower()\n",
112
+ "    # 1. Insights FIRST: customer profile questions mentioning reengagement are insights\n",
113
+ "    if \"perfil do cliente\" in p or \"retenção\" in p or \"análise\" in p or \"insight\" in p:\n",
114
+ "        return \"insights\"\n",
115
+ "    # 2. Extraction\n",
116
+ "    elif \"retorne um objeto json\" in p or \"extraia dados\" in p or \"json\" in p:\n",
117
  "        return \"extraction\"\n",
118
+ "    # 3. Push — only after insights is ruled out\n",
119
  "    elif \"notificação push\" in p or \"notificação de reengajamento\" in p:\n",
120
  "        return \"push\"\n",
121
  "    else:\n",
122
  "        return \"sql_qa\"\n",
123
  "\n",