# V3 Patch: 3 Changes for Task-Aware Thinking Control
## Overview
These 3 changes go into the v3 notebook. Each change is a precise cell modification.
**Research basis:**
- OptimalThinkingBench (2508.13141): "Don't overthink" → -23% tokens, +7.7pp accuracy on Qwen3
- Mid-Think (2601.07036): task-specific thinking control in GRPO → +2.6pp AIME, -15% train time
- L1 (2503.04697): token budgets in prompts work when trained with RL reward signal
- User's proven extraction prompt: XML-tagged structure + few-shot + schema enforcement
---
## CHANGE 1: Replace SYSTEM_PT with task-aware system prompts (Cell 3)
### REMOVE this block from Cell 3:
```python
SYSTEM_PT = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro."
)
```
### REPLACE with:
```python
# ══════════════════════════════════════════════════════════════════════════════
# v3: TASK-AWARE SYSTEM PROMPTS
# ══════════════════════════════════════════════════════════════════════════════
# Research basis:
# - OptimalThinkingBench (2508.13141): "Don't overthink" → -23% tokens, +7.7pp accuracy on Qwen3
# - Mid-Think (2601.07036): task-specific thinking control in GRPO → +2.6pp AIME, -15% train time
# - L1 (2503.04697): token budgets in prompts work when trained with RL reward signal
# - User's proven extraction prompt: XML-tagged structure + few-shot + schema enforcement
COMPLAINT_CATEGORIES_STR = ", ".join(sorted(VALID_CATEGORIES))

SYSTEM_EXTRACTION = (
    "Você é um motor de extração de dados de e-commerce brasileiro. "
    "Retorne APENAS um objeto JSON válido, sem nenhum texto antes ou depois. "
    "NÃO USE blocos de código markdown (```json). "
    "O primeiro caractere da sua resposta deve ser { e o último deve ser }. "
    "Campos não mencionados na avaliação devem ser null — nunca invente valores. "
    "Sem explicação. Sem comentários. Não pense em excesso."
)

SYSTEM_SQL = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\n\n"
    "Para consultas e análises de dados: pense brevemente sobre a estrutura necessária, "
    "depois apresente a resposta de forma direta com números e dados concretos. "
    "Seja conciso no raciocínio. Não pense em excesso."
)

SYSTEM_INSIGHTS = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\n\n"
    "Para análises estratégicas: raciocine de forma estruturada e concisa, "
    "focando nos pontos principais e recomendações acionáveis. "
    "Use no máximo 500 tokens para raciocinar antes de responder."
)

SYSTEM_PUSH = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro.\n\n"
    "Para notificações push: seja direto e criativo. "
    "A notificação deve ter no máximo 120 caracteres. "
    "Responda diretamente sem pensar em excesso."
)

# Legacy fallback: used only in cells that don't have task context
SYSTEM_PT = (
    "Você é um assistente de IA especializado em análise de e-commerce brasileiro. "
    "Você compreende avaliações de clientes em português e padrões de comércio brasileiro."
)


def get_system_prompt(task_type: str) -> str:
    """Return task-optimized system prompt."""
    return {
        "extraction": SYSTEM_EXTRACTION,
        "sql_qa": SYSTEM_SQL,
        "insights": SYSTEM_INSIGHTS,
        "push": SYSTEM_PUSH,
    }.get(task_type, SYSTEM_PT)


# ── Think token budgets per task (for reward function) ────────────────────────
# These are soft targets: the reward function nudges, it does not enforce
THINK_BUDGETS = {
    "extraction": 150,  # Extraction barely needs thinking: pattern matching
    "push": 100,        # Push is creative writing, not reasoning
    "sql_qa": 400,      # SQL benefits from brief query planning
    "insights": 800,    # Insights need structured multi-step analysis
}

print("✓ v3 Task-aware system prompts defined")
print(f"  extraction: '{SYSTEM_EXTRACTION[:60]}...'")
print(f"  sql_qa:     '{SYSTEM_SQL[:60]}...'")
print(f"  insights:   '{SYSTEM_INSIGHTS[:60]}...'")
print(f"  push:       '{SYSTEM_PUSH[:60]}...'")
```
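
A quick consistency check can catch a task that has a think budget but silently falls back to the legacy prompt. The sketch below is illustrative only: the prompt strings are placeholders, and `TASK_PROMPTS` and `tasks_without_prompt` are hypothetical names that do not exist in the notebook. In Cell 3 the real constants would be used instead.

```python
# Sketch only: placeholder prompts stand in for the real Cell 3 constants.
SYSTEM_PT = "generic prompt"
TASK_PROMPTS = {
    "extraction": "extraction prompt",
    "sql_qa": "sql prompt",
    "insights": "insights prompt",
    "push": "push prompt",
}
THINK_BUDGETS = {"extraction": 150, "push": 100, "sql_qa": 400, "insights": 800}


def get_system_prompt(task_type: str) -> str:
    """Same lookup-with-fallback shape as the Cell 3 helper."""
    return TASK_PROMPTS.get(task_type, SYSTEM_PT)


def tasks_without_prompt(budgets: dict, fallback: str) -> list:
    """Hypothetical helper: budgeted tasks that resolve to the fallback prompt."""
    return sorted(t for t in budgets if get_system_prompt(t) == fallback)


missing = tasks_without_prompt(THINK_BUDGETS, SYSTEM_PT)
print("tasks falling back to SYSTEM_PT:", missing)
print("unknown task ->", get_system_prompt("chitchat"))
```

An empty `missing` list means every budgeted task has its own prompt; any unknown task type still resolves to `SYSTEM_PT`.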
---
## CHANGE 2: Add reward_think_efficiency() to Cell 6 (Reward Functions)
### ADD this function right before `commerce_reward_fn` in Cell 6:
```python
import re  # harmless if Cell 6 already imports it


def reward_think_efficiency(completion: str, task_type: str) -> float:
    """
    Reward concise thinking, penalize bloated <think> blocks.

    v3 NEW. Research basis:
    - OptimalThinkingBench (2508.13141): overthinking hurts accuracy on simple tasks
    - L1 (2503.04697): token budget rewards teach models to control reasoning length
    - Train Long Think Short (2508.08940): triangular length reward around target budget

    Returns: -0.05 to +0.1 (small component: nudge, not dominate)
    """
    # Match a closed <think>...</think> block, or an unclosed <think> that runs to
    # the end of the completion (e.g. generation truncated at the token ceiling),
    # so truncated overthinking is not mistaken for "no thinking".
    think_match = re.search(r"<think>(.*?)(?:</think>|\Z)", completion, re.DOTALL)
    budget = THINK_BUDGETS.get(task_type, 500)

    if not think_match:
        # No think block at all
        if task_type in ("extraction", "push"):
            return 0.1  # Great: these tasks don't need thinking
        else:
            return 0.0  # Neutral for analytical tasks

    think_content = think_match.group(1).strip()
    think_chars = len(think_content)  # chars as proxy (cheaper than tokenizing)
    # Rough conversion: ~4 chars per token for Portuguese
    think_tokens_approx = think_chars / 4

    if think_tokens_approx <= budget:
        # Within budget: full bonus (flat, not scaled by how concise the block is)
        return 0.1
    elif think_tokens_approx <= budget * 2:
        # Over budget but not catastrophic: linear decay
        overshoot = (think_tokens_approx - budget) / budget
        return 0.1 * (1.0 - overshoot)  # 0.1 → 0.0
    else:
        # Way over budget: mild penalty
        return -0.05
```
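
To make the shape of the reward concrete, the sketch below isolates the three length branches for a completion that does contain a think block. `think_reward_from_tokens` is a hypothetical helper written for illustration; it takes the approximate token count directly instead of parsing the completion.

```python
def think_reward_from_tokens(tokens: float, budget: int) -> float:
    """Piecewise schedule: flat bonus, linear decay, then a fixed penalty."""
    if tokens <= budget:
        return 0.1
    if tokens <= 2 * budget:
        # Linear decay from 0.1 at the budget to 0.0 at twice the budget
        return 0.1 * (1.0 - (tokens - budget) / budget)
    return -0.05


budget = 400  # the sql_qa value from THINK_BUDGETS
for tokens in (200, 400, 600, 800, 801):
    print(tokens, round(think_reward_from_tokens(tokens, budget), 3))
```

With a budget of 400, anything up to 400 approximate tokens earns the full 0.1, 600 earns half of it (0.05), and 800 earns 0.0. Just past twice the budget the value steps down from 0.0 to -0.05, so the schedule has a small discontinuity at that point.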
### MODIFY `commerce_reward_fn` dispatch block:
**Current code (REMOVE):**
```python
if task == "extraction":
    rewards.append(reward_extraction(comp_text))
elif task == "sql_qa":
    rewards.append(reward_sql_qa(comp_text))
elif task == "insights":
    rewards.append(reward_insights(comp_text))
elif task == "push":
    rewards.append(reward_push(comp_text))
else:
    r = 0.15 if has_think_block(comp_text) else 0.0
    r += 0.2 if comp_text.strip() else 0.0
    rewards.append(r)
```
**New code (REPLACE WITH):**
```python
if task == "extraction":
    task_r = reward_extraction(comp_text)
elif task == "sql_qa":
    task_r = reward_sql_qa(comp_text)
elif task == "insights":
    task_r = reward_insights(comp_text)
elif task == "push":
    task_r = reward_push(comp_text)
else:
    task_r = 0.15 if has_think_block(comp_text) else 0.0
    task_r += 0.2 if comp_text.strip() else 0.0

# v3: Think efficiency bonus/penalty (small weight: nudge, not dominate)
think_r = reward_think_efficiency(comp_text, task)
rewards.append(task_r + think_r)
```
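
The additive structure (task reward plus think reward) can also be written as a lookup table, which some readers may find easier to extend. This is only a sketch of an equivalent form, not the notebook's code: the reward functions are stubs returning fixed values, and `TASK_REWARDS`, `fallback_reward` and `combined_reward` are hypothetical names.

```python
# Stub rewards: placeholders for the real Cell 6 implementations.
def reward_extraction(text: str) -> float:
    return 0.5


def reward_push(text: str) -> float:
    return 0.3


def has_think_block(text: str) -> bool:
    # Hypothetical stand-in for the notebook's helper
    return "<think>" in text and "</think>" in text


def fallback_reward(text: str) -> float:
    """Same arithmetic as the else-branch of the dispatch block."""
    r = 0.15 if has_think_block(text) else 0.0
    r += 0.2 if text.strip() else 0.0
    return r


TASK_REWARDS = {"extraction": reward_extraction, "push": reward_push}


def combined_reward(text: str, task: str, think_fn) -> float:
    """Task reward plus the think-efficiency term."""
    task_fn = TASK_REWARDS.get(task, fallback_reward)
    return task_fn(text) + think_fn(text, task)


def constant_think_bonus(text: str, task: str) -> float:
    return 0.1  # stub for reward_think_efficiency


print(combined_reward('{"a": 1}', "extraction", constant_think_bonus))
print(combined_reward("<think>x</think>ok", "unknown", constant_think_bonus))
```

Because the think term only ranges from -0.05 to +0.1, it shifts the total slightly without changing which task reward is applied.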
---
## CHANGE 3: Wire system prompts into data preparation and eval
### Cell 7 (Calibration) — add helper + use in loop:
Add this helper function after loading `by_type`:
```python
def inject_task_system_prompt(msgs, task_type):
    """Replace generic system prompt with task-specific one."""
    new_msgs = []
    system_prompt = get_system_prompt(task_type)
    has_system = False
    for m in msgs:
        if m["role"] == "system":
            new_msgs.append({"role": "system", "content": system_prompt})
            has_system = True
        else:
            new_msgs.append(m)
    if not has_system:
        new_msgs.insert(0, {"role": "system", "content": system_prompt})
    return new_msgs
```
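
The helper has two paths: it swaps an existing system message, or prepends one when none is present. The sketch below exercises both on toy messages. It repeats a compact equivalent of the helper so it runs on its own, and `get_system_prompt` here is a stand-in that returns a placeholder string rather than the real Cell 3 prompts.

```python
def get_system_prompt(task_type: str) -> str:
    # Stand-in for the Cell 3 helper; returns a placeholder string
    return f"[{task_type} prompt]"


def inject_task_system_prompt(msgs, task_type):
    """Compact equivalent of the helper, repeated so the sketch is runnable."""
    system_prompt = get_system_prompt(task_type)
    new_msgs = [
        {"role": "system", "content": system_prompt} if m["role"] == "system" else m
        for m in msgs
    ]
    if not any(m["role"] == "system" for m in new_msgs):
        new_msgs.insert(0, {"role": "system", "content": system_prompt})
    return new_msgs


with_system = [
    {"role": "system", "content": "generic"},
    {"role": "user", "content": "Extraia os dados desta avaliação"},
]
without_system = [{"role": "user", "content": "Crie uma notificação push"}]

replaced = inject_task_system_prompt(with_system, "extraction")
prepended = inject_task_system_prompt(without_system, "push")

print(replaced[0])
print(prepended[0])
print(with_system[0])  # the input list is left untouched
```

Since a new list is returned and the original system message dict is not modified, the same source records can be reused with a different task type.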
Then in the calibration loop, inject the task-aware prompt before template application:
```python
for i, msgs in enumerate(cal_samples):
    # Determine task type from user content
    user_text = " ".join(m["content"] for m in msgs if m["role"] == "user")
    task = _classify_task_type(user_text)
    # v3: Inject task-aware system prompt
    msgs = inject_task_system_prompt(msgs, task)
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # ... rest of loop unchanged
```
### Cell 8 (Dataset Preparation) — inject into train/eval records:
In `prepare_grpo_datasets_v3`, after building `train_records` and `eval_records` (before creating the HF Datasets), add:
```python
# v3: Inject task-aware system prompts into each training record
for i, record in enumerate(train_records):
    user_text = " ".join(m["content"] for m in record if m["role"] == "user")
    task = _classify_task_type(user_text)
    train_records[i] = inject_task_system_prompt(record, task)

# Same for eval records
for i, record in enumerate(eval_records):
    user_text = " ".join(m["content"] for m in record if m["role"] == "user")
    task = _classify_task_type(user_text)
    eval_records[i] = inject_task_system_prompt(record, task)

print("  ✓ Task-aware system prompts injected")
```
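
After injection, it can be worth auditing how many records ended up with each system prompt, since a misclassified task would show up as an unexpected count. The sketch below is one possible check on toy records; `system_prompt_counts` is a hypothetical helper, not part of the notebook.

```python
from collections import Counter


def system_prompt_counts(records, prefix_len: int = 40) -> Counter:
    """Count records by the first characters of their system prompt."""
    counts = Counter()
    for record in records:
        system = next((m["content"] for m in record if m["role"] == "system"), None)
        key = system[:prefix_len] if system is not None else "<no system prompt>"
        counts[key] += 1
    return counts


# Toy records standing in for train_records after injection
records = [
    [{"role": "system", "content": "extraction prompt"}, {"role": "user", "content": "a"}],
    [{"role": "system", "content": "extraction prompt"}, {"role": "user", "content": "b"}],
    [{"role": "user", "content": "c"}],
]

counts = system_prompt_counts(records)
for prompt_prefix, n in counts.most_common():
    print(f"{n:4d}  {prompt_prefix}")
```

A non-zero `<no system prompt>` count after Cell 8 would indicate that some records skipped the injection step.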
### Cell 11 (EvalRewardCallback) — no change needed:
System prompts were injected in Cell 8, so eval data already has the right prompts.
### Cell 13 (Validation) — use task-aware selection:
Replace:
```python
system_msg = {"role": "system", "content": SYSTEM_PT}
```
With task-aware selection inside the loop:
```python
# REMOVE the fixed system_msg line above the loop
# Inside the loop, before generating:
task = _classify_task_type(prompt["content"])
system_msg = {"role": "system", "content": get_system_prompt(task)}
messages = [system_msg, prompt]
```
---
## Summary
| Cell | What changes | Lines affected |
|------|-------------|---------------|
| Cell 3 | Replace `SYSTEM_PT` with 4 task prompts + `get_system_prompt()` + `THINK_BUDGETS` | ~50 lines added |
| Cell 6 | Add `reward_think_efficiency()`, modify `commerce_reward_fn` dispatch | ~35 lines added, ~10 modified |
| Cell 7 | Add `inject_task_system_prompt()`, use in calibration loop | ~15 lines added |
| Cell 8 | Inject task-aware system prompts into train/eval records | ~10 lines added |
| Cell 13 | Use `get_system_prompt(task)` instead of fixed `SYSTEM_PT` | ~3 lines modified |
## Expected impact
| Task | Current think tokens | Expected after patch | Mechanism |
|------|---------------------|---------------------|-----------|
| Extraction | 2000-3000 (100% ceiling) | ~300-800 (-60% to -70%) | "Não pense em excesso" ("Don't overthink") + think penalty reward |
| Push | 1000-2000 | ~100-300 (-70% to -80%) | "Responda diretamente" ("Answer directly") + think penalty reward |
| SQL Q&A | 1500-2500 | ~400-800 (-50%) | "Seja conciso no raciocínio" ("Be concise in your reasoning") + think budget reward |
| Insights | 2000-3200 (ceiling) | ~800-1500 (-30% to -40%) | "Use no máximo 500 tokens" ("Use at most 500 tokens") + higher think budget (800) |
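
These expectations can be checked after training by measuring think length per task on a sample of completions. The sketch below is one way to do that, under stated assumptions: it reuses the ~4 characters per token approximation from the reward function, treats an unclosed `<think>` as running to the end of the completion, and uses hypothetical helper names (`approx_think_tokens`, `mean_think_tokens_by_task`) with toy data.

```python
import re
from collections import defaultdict
from statistics import mean

# Closed <think>...</think>, or an unclosed <think> running to end of text
THINK_RE = re.compile(r"<think>(.*?)(?:</think>|\Z)", re.DOTALL)


def approx_think_tokens(completion: str, chars_per_token: float = 4.0) -> float:
    """Approximate think length in tokens; 0.0 when there is no think block."""
    match = THINK_RE.search(completion)
    if match is None:
        return 0.0
    return len(match.group(1).strip()) / chars_per_token


def mean_think_tokens_by_task(samples) -> dict:
    """samples: iterable of (task_type, completion) pairs."""
    by_task = defaultdict(list)
    for task, completion in samples:
        by_task[task].append(approx_think_tokens(completion))
    return {task: mean(values) for task, values in by_task.items()}


# Toy completions; real ones would come from the eval generations
samples = [
    ("push", "<think>" + "a" * 40 + "</think>Oferta relâmpago!"),
    ("push", "Frete grátis hoje!"),
    ("extraction", "<think>" + "b" * 80),  # truncated, never closed
]

for task, avg in mean_think_tokens_by_task(samples).items():
    print(f"{task}: ~{avg:.1f} think tokens on average")
```

For exact numbers, the character heuristic could be replaced with the model tokenizer; the character count keeps the check cheap and consistent with the reward term.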