agent-cost-optimizer / aco /classifier.py
narcolepticchicken's picture
Upload aco/classifier.py
a7c7186 verified
raw
history blame
10.3 kB
"""Task Cost Classifier - Module 2.
Classifies incoming tasks by expected cost, risk, model strength needed,
and predicts whether retrieval/verifier is required.
"""
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from .trace_schema import TaskType
from .config import ACOConfig
@dataclass
class TaskPrediction:
    """Predicted cost/risk profile for a single incoming task.

    Produced by ``TaskCostClassifier.classify`` /
    ``classify_with_history``; downstream routing can use it to choose a
    model tier and decide whether retrieval or a verifier pass is needed.
    """

    task_type: TaskType                 # classified category of the request
    expected_cost: float                # estimated dollar cost of serving the task
    expected_model_tier: int            # 1-5, model strength needed (5 = strongest)
    expected_tools_needed: List[str]    # tool names the task is predicted to invoke
    risk_of_failure: float              # 0-1, estimated probability of failure
    retrieval_required: bool            # True if document/context retrieval is expected
    verifier_required: bool             # True if an output-verification pass is warranted
    expected_latency_ms: float          # rough end-to-end latency estimate, milliseconds
    confidence: float                   # 0-1, classifier's confidence in this prediction
class TaskCostClassifier:
    """Classifies agent tasks into cost/risk categories.

    Classification is keyword-driven: the request is matched against
    ``KEYWORD_MAP``, the winning task type's base cost/tier are scaled by
    complexity and length heuristics, and tool needs, failure risk, and
    retrieval/verifier requirements are derived from the resulting type.
    """

    # Keywords mapped to (task type, base cost estimate in dollars, base model tier 1-5).
    KEYWORD_MAP: Dict[str, Tuple[TaskType, float, int]] = {
        # quick_answer: low cost, tier 1-2
        "what is": (TaskType.QUICK_ANSWER, 0.001, 1),
        "define": (TaskType.QUICK_ANSWER, 0.001, 1),
        "explain briefly": (TaskType.QUICK_ANSWER, 0.002, 1),
        "summarize": (TaskType.QUICK_ANSWER, 0.005, 2),
        "short answer": (TaskType.QUICK_ANSWER, 0.001, 1),
        # coding: medium-high cost, tier 3-4
        "write code": (TaskType.CODING, 0.05, 3),
        "fix bug": (TaskType.CODING, 0.08, 4),
        "refactor": (TaskType.CODING, 0.03, 3),
        "implement": (TaskType.CODING, 0.05, 3),
        "test": (TaskType.CODING, 0.04, 3),
        "debug": (TaskType.CODING, 0.06, 4),
        "python": (TaskType.CODING, 0.03, 3),
        "javascript": (TaskType.CODING, 0.03, 3),
        "function": (TaskType.CODING, 0.02, 2),
        # research: high cost, tier 3-4
        "research": (TaskType.RESEARCH, 0.15, 4),
        "find sources": (TaskType.RESEARCH, 0.1, 3),
        "literature review": (TaskType.RESEARCH, 0.2, 4),
        "compare": (TaskType.RESEARCH, 0.08, 3),
        "analyze": (TaskType.RESEARCH, 0.1, 3),
        "investigate": (TaskType.RESEARCH, 0.12, 4),
        # document_drafting: medium cost, tier 3
        "draft": (TaskType.DOCUMENT_DRAFTING, 0.05, 3),
        "write a document": (TaskType.DOCUMENT_DRAFTING, 0.06, 3),
        "proposal": (TaskType.DOCUMENT_DRAFTING, 0.08, 3),
        "report": (TaskType.DOCUMENT_DRAFTING, 0.1, 4),
        "email": (TaskType.DOCUMENT_DRAFTING, 0.01, 2),
        # legal_regulated: high cost, tier 4-5
        "contract": (TaskType.LEGAL_REGULATED, 0.15, 5),
        "legal": (TaskType.LEGAL_REGULATED, 0.15, 5),
        "compliance": (TaskType.LEGAL_REGULATED, 0.12, 5),
        "regulatory": (TaskType.LEGAL_REGULATED, 0.12, 5),
        "privacy policy": (TaskType.LEGAL_REGULATED, 0.1, 5),
        "terms of service": (TaskType.LEGAL_REGULATED, 0.1, 5),
        # tool_heavy
        "search for": (TaskType.TOOL_HEAVY, 0.05, 3),
        "look up": (TaskType.TOOL_HEAVY, 0.03, 2),
        "fetch": (TaskType.TOOL_HEAVY, 0.04, 3),
        "api": (TaskType.TOOL_HEAVY, 0.06, 3),
        "database": (TaskType.TOOL_HEAVY, 0.05, 3),
        "scrape": (TaskType.TOOL_HEAVY, 0.04, 3),
        # retrieval_heavy
        "based on the document": (TaskType.RETRIEVAL_HEAVY, 0.08, 3),
        "from my files": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
        "rag": (TaskType.RETRIEVAL_HEAVY, 0.06, 3),
        "retrieve": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
        # long_horizon
        "plan": (TaskType.LONG_HORIZON, 0.1, 4),
        "project": (TaskType.LONG_HORIZON, 0.15, 4),
        "over the next": (TaskType.LONG_HORIZON, 0.1, 4),
        "multi-step": (TaskType.LONG_HORIZON, 0.08, 4),
        "orchestrate": (TaskType.LONG_HORIZON, 0.12, 4),
    }

    # Complexity multipliers based on length and structure: (regex, multiplier).
    # The largest single matching multiplier is applied (they do not stack).
    COMPLEXITY_PATTERNS = [
        (r"\b(AND|and)\b.*\b(AND|and)\b.*\b(AND|and)\b", 1.5),  # multiple sub-tasks
        (r"\bstep\s+\d+\b", 1.3),
        (r"\d+\+\s*(pages|files|functions|tests)", 1.4),
        (r"\b(entire|whole|all|every)\b", 1.2),
        (r"\b(critical|production|live|deployed)\b", 1.5),
    ]

    def __init__(self, config: Optional[ACOConfig] = None):
        """Create a classifier.

        Args:
            config: Optional ACO configuration; a default ``ACOConfig`` is
                constructed when omitted.
        """
        self.config = config or ACOConfig()
        self.history: List[Dict] = []  # reserved for recording past classifications

    def classify(self, user_request: str) -> TaskPrediction:
        """Classify a user request into task type, cost, risk, etc.

        Args:
            user_request: Raw natural-language request text.

        Returns:
            A ``TaskPrediction`` with heuristic cost/risk/tooling estimates.
        """
        request_lower = user_request.lower()

        # Collect the base costs, grouped by task type, of every keyword
        # that appears in the request.
        matched_types: Dict[TaskType, List[float]] = {}
        for keyword, (task_type, base_cost, tier) in self.KEYWORD_MAP.items():
            if keyword in request_lower:
                matched_types.setdefault(task_type, []).append(base_cost)

        if not matched_types:
            # No keyword hit: fall back to a conservative "unknown" profile.
            task_type = TaskType.UNKNOWN_AMBIGUOUS
            base_cost = 0.05
            base_tier = 2
        else:
            # Pick task type with highest cumulative base cost (most specific).
            task_type = max(matched_types.keys(), key=lambda t: sum(matched_types[t]))
            # Among the matched keywords of the winning type, take cost and
            # tier jointly from the costliest one (ties broken by higher
            # tier).  Selecting them together keeps the tier consistent
            # with the keyword that supplied the cost, rather than an
            # arbitrary keyword of the same type.
            base_cost, base_tier = max(
                (cost, tier)
                for kw, (tt, cost, tier) in self.KEYWORD_MAP.items()
                if tt == task_type and kw in request_lower
            )

        # Apply the single strongest matching complexity multiplier.
        complexity_mult = 1.0
        for pattern, mult in self.COMPLEXITY_PATTERNS:
            if re.search(pattern, user_request, re.IGNORECASE):
                complexity_mult = max(complexity_mult, mult)

        # Longer requests cost more, capped at +50% (reached at 250 words).
        word_count = len(request_lower.split())
        length_mult = 1.0 + min(word_count / 500, 0.5)

        expected_cost = base_cost * complexity_mult * length_mult
        # Bump the tier by one for notably complex requests, capped at 5.
        expected_tier = min(base_tier + int(complexity_mult > 1.2), 5)

        # Predicted tool needs follow directly from the task type.
        expected_tools = []
        if task_type in (TaskType.RESEARCH, TaskType.TOOL_HEAVY, TaskType.RETRIEVAL_HEAVY):
            expected_tools = ["search", "retrieve", "fetch"]
        elif task_type == TaskType.CODING:
            expected_tools = ["code_execution", "linter", "test_runner"]
        elif task_type == TaskType.LEGAL_REGULATED:
            expected_tools = ["document_retrieval", "compliance_check"]

        # Base failure risk per task type, then scaled by complexity.
        risk = 0.3
        if task_type == TaskType.LEGAL_REGULATED:
            risk = 0.8
        elif task_type == TaskType.LONG_HORIZON:
            risk = 0.6
        elif task_type == TaskType.CODING:
            risk = 0.5
        elif task_type == TaskType.UNKNOWN_AMBIGUOUS:
            risk = 0.7
        risk = min(risk * complexity_mult, 1.0)

        # Verifier required for high-risk tasks; legal work always verifies.
        verifier_required = risk > 0.6 or task_type == TaskType.LEGAL_REGULATED
        # Retrieval required for research, document, retrieval-heavy, legal.
        retrieval_required = task_type in (
            TaskType.RESEARCH,
            TaskType.RETRIEVAL_HEAVY,
            TaskType.DOCUMENT_DRAFTING,
            TaskType.LEGAL_REGULATED,
        )

        expected_latency = expected_cost * 10000  # rough heuristic: $0.001 ~ 10s

        return TaskPrediction(
            task_type=task_type,
            expected_cost=expected_cost,
            expected_model_tier=expected_tier,
            expected_tools_needed=expected_tools,
            risk_of_failure=risk,
            retrieval_required=retrieval_required,
            verifier_required=verifier_required,
            expected_latency_ms=expected_latency,
            confidence=0.7 if matched_types else 0.4,
        )

    def classify_with_history(self, user_request: str, past_traces: List[Dict]) -> TaskPrediction:
        """Classify using historical trace data for this user/task pattern.

        Args:
            user_request: Raw natural-language request text.
            past_traces: Past trace dicts; entries are expected to carry
                ``user_request``, ``total_cost``, ``final_outcome`` and
                ``total_retries`` keys (missing keys fall back to defaults).

        Returns:
            The base ``classify`` prediction, adjusted when at least three
            sufficiently similar past traces exist.
        """
        base = self.classify(user_request)
        if not past_traces:
            return base

        # Keep only traces whose request text is Jaccard-similar enough.
        similar = [
            t for t in past_traces
            if self._similarity(user_request, t.get("user_request", "")) > 0.5
        ]
        if len(similar) >= 3:
            avg_cost = sum(t.get("total_cost", base.expected_cost) for t in similar) / len(similar)
            success_rate = sum(1 for t in similar if t.get("final_outcome") == "success") / len(similar)
            avg_retries = sum(t.get("total_retries", 0) for t in similar) / len(similar)

            if success_rate < 0.5:
                # History shows frequent failure: escalate tier, raise cost
                # and risk estimates, and force retrieval + verification.
                base = TaskPrediction(
                    task_type=base.task_type,
                    expected_cost=avg_cost * 1.2,
                    expected_model_tier=min(base.expected_model_tier + 1, 5),
                    expected_tools_needed=base.expected_tools_needed,
                    risk_of_failure=min(base.risk_of_failure * 1.3, 1.0),
                    retrieval_required=True,
                    verifier_required=True,
                    expected_latency_ms=base.expected_latency_ms * 1.2,
                    confidence=min(base.confidence + 0.1, 1.0),
                )
            else:
                # History shows success: we can afford a cheaper, faster
                # configuration, and drop the verifier when retries are rare.
                base = TaskPrediction(
                    task_type=base.task_type,
                    expected_cost=avg_cost * 0.9,  # history suggests we can be cheaper
                    expected_model_tier=max(base.expected_model_tier - 1, 1),
                    expected_tools_needed=base.expected_tools_needed,
                    risk_of_failure=base.risk_of_failure * 0.8,
                    retrieval_required=base.retrieval_required,
                    verifier_required=base.verifier_required and avg_retries > 1,
                    expected_latency_ms=base.expected_latency_ms * 0.9,
                    confidence=min(base.confidence + 0.2, 1.0),
                )
        return base

    @staticmethod
    def _similarity(a: str, b: str) -> float:
        """Simple Jaccard similarity on lowercased word sets (0.0-1.0)."""
        words_a = set(a.lower().split())
        words_b = set(b.lower().split())
        if not words_a or not words_b:
            return 0.0
        return len(words_a & words_b) / len(words_a | words_b)