"""Task Cost Classifier - Module 2. Classifies incoming tasks by expected cost, risk, model strength needed, and predicts whether retrieval/verifier is required. """ import re from typing import Dict, List, Tuple, Optional from dataclasses import dataclass from .trace_schema import TaskType from .config import ACOConfig @dataclass class TaskPrediction: task_type: TaskType expected_cost: float expected_model_tier: int # 1-5 expected_tools_needed: List[str] risk_of_failure: float # 0-1 retrieval_required: bool verifier_required: bool expected_latency_ms: float confidence: float class TaskCostClassifier: """Classifies agent tasks into cost/risk categories.""" # Keywords mapped to task types with base cost estimates KEYWORD_MAP: Dict[str, Tuple[TaskType, float, int]] = { # quick_answer: low cost, tier 1-2 "what is": (TaskType.QUICK_ANSWER, 0.001, 1), "define": (TaskType.QUICK_ANSWER, 0.001, 1), "explain briefly": (TaskType.QUICK_ANSWER, 0.002, 1), "summarize": (TaskType.QUICK_ANSWER, 0.005, 2), "short answer": (TaskType.QUICK_ANSWER, 0.001, 1), # coding: medium-high cost, tier 3-4 "write code": (TaskType.CODING, 0.05, 3), "fix bug": (TaskType.CODING, 0.08, 4), "refactor": (TaskType.CODING, 0.03, 3), "implement": (TaskType.CODING, 0.05, 3), "test": (TaskType.CODING, 0.04, 3), "debug": (TaskType.CODING, 0.06, 4), "python": (TaskType.CODING, 0.03, 3), "javascript": (TaskType.CODING, 0.03, 3), "function": (TaskType.CODING, 0.02, 2), # research: high cost, tier 3-4 "research": (TaskType.RESEARCH, 0.15, 4), "find sources": (TaskType.RESEARCH, 0.1, 3), "literature review": (TaskType.RESEARCH, 0.2, 4), "compare": (TaskType.RESEARCH, 0.08, 3), "analyze": (TaskType.RESEARCH, 0.1, 3), "investigate": (TaskType.RESEARCH, 0.12, 4), # document_drafting: medium cost, tier 3 "draft": (TaskType.DOCUMENT_DRAFTING, 0.05, 3), "write a document": (TaskType.DOCUMENT_DRAFTING, 0.06, 3), "proposal": (TaskType.DOCUMENT_DRAFTING, 0.08, 3), "report": (TaskType.DOCUMENT_DRAFTING, 0.1, 4), "email": (TaskType.DOCUMENT_DRAFTING, 0.01, 2), # legal_regulated: high cost, tier 4-5 "contract": (TaskType.LEGAL_REGULATED, 0.15, 5), "legal": (TaskType.LEGAL_REGULATED, 0.15, 5), "compliance": (TaskType.LEGAL_REGULATED, 0.12, 5), "regulatory": (TaskType.LEGAL_REGULATED, 0.12, 5), "privacy policy": (TaskType.LEGAL_REGULATED, 0.1, 5), "terms of service": (TaskType.LEGAL_REGULATED, 0.1, 5), # tool_heavy "search for": (TaskType.TOOL_HEAVY, 0.05, 3), "look up": (TaskType.TOOL_HEAVY, 0.03, 2), "fetch": (TaskType.TOOL_HEAVY, 0.04, 3), "api": (TaskType.TOOL_HEAVY, 0.06, 3), "database": (TaskType.TOOL_HEAVY, 0.05, 3), "scrape": (TaskType.TOOL_HEAVY, 0.04, 3), # retrieval_heavy "based on the document": (TaskType.RETRIEVAL_HEAVY, 0.08, 3), "from my files": (TaskType.RETRIEVAL_HEAVY, 0.05, 3), "rag": (TaskType.RETRIEVAL_HEAVY, 0.06, 3), "retrieve": (TaskType.RETRIEVAL_HEAVY, 0.05, 3), # long_horizon "plan": (TaskType.LONG_HORIZON, 0.1, 4), "project": (TaskType.LONG_HORIZON, 0.15, 4), "over the next": (TaskType.LONG_HORIZON, 0.1, 4), "multi-step": (TaskType.LONG_HORIZON, 0.08, 4), "orchestrate": (TaskType.LONG_HORIZON, 0.12, 4), } # Complexity multipliers based on length and structure COMPLEXITY_PATTERNS = [ (r"\b(AND|and)\b.*\b(AND|and)\b.*\b(AND|and)\b", 1.5), # multiple sub-tasks (r"\bstep\s+\d+\b", 1.3), (r"\d+\+\s*(pages|files|functions|tests)", 1.4), (r"\b(entire|whole|all|every)\b", 1.2), (r"\b(critical|production|live|deployed)\b", 1.5), ] def __init__(self, config: Optional[ACOConfig] = None): self.config = config or ACOConfig() self.history: List[Dict] = [] def classify(self, user_request: str) -> TaskPrediction: """Classify a user request into task type, cost, risk, etc.""" request_lower = user_request.lower() # Find best matching keywords matched_types: Dict[TaskType, List[float]] = {} for keyword, (task_type, base_cost, tier) in self.KEYWORD_MAP.items(): if keyword in request_lower: matched_types.setdefault(task_type, []).append(base_cost) # Default to unknown if no match if not matched_types: task_type = TaskType.UNKNOWN_AMBIGUOUS base_cost = 0.05 base_tier = 2 else: # Pick task type with highest cumulative base cost (most specific) task_type = max(matched_types.keys(), key=lambda t: sum(matched_types[t])) base_cost = max(matched_types[task_type]) base_tier = self.KEYWORD_MAP[ max( (k for k, (tt, _, _) in self.KEYWORD_MAP.items() if tt == task_type), key=lambda k: base_cost if k in request_lower else 0, ) ][2] # Apply complexity multipliers complexity_mult = 1.0 for pattern, mult in self.COMPLEXITY_PATTERNS: if re.search(pattern, user_request, re.IGNORECASE): complexity_mult = max(complexity_mult, mult) # Length factor word_count = len(request_lower.split()) length_mult = 1.0 + min(word_count / 500, 0.5) expected_cost = base_cost * complexity_mult * length_mult expected_tier = min(base_tier + int(complexity_mult > 1.2), 5) # Determine tool needs expected_tools = [] if task_type in (TaskType.RESEARCH, TaskType.TOOL_HEAVY, TaskType.RETRIEVAL_HEAVY): expected_tools = ["search", "retrieve", "fetch"] elif task_type == TaskType.CODING: expected_tools = ["code_execution", "linter", "test_runner"] elif task_type == TaskType.LEGAL_REGULATED: expected_tools = ["document_retrieval", "compliance_check"] # Risk estimation risk = 0.3 if task_type == TaskType.LEGAL_REGULATED: risk = 0.8 elif task_type == TaskType.LONG_HORIZON: risk = 0.6 elif task_type == TaskType.CODING: risk = 0.5 elif task_type == TaskType.UNKNOWN_AMBIGUOUS: risk = 0.7 # Adjust risk by complexity risk = min(risk * complexity_mult, 1.0) # Verifier required for high-risk or complex tasks verifier_required = risk > 0.6 or task_type == TaskType.LEGAL_REGULATED # Retrieval required for research, document, retrieval-heavy retrieval_required = task_type in ( TaskType.RESEARCH, TaskType.RETRIEVAL_HEAVY, TaskType.DOCUMENT_DRAFTING, TaskType.LEGAL_REGULATED, ) expected_latency = expected_cost * 10000 # rough heuristic: $0.001 ~ 10s return TaskPrediction( task_type=task_type, expected_cost=expected_cost, expected_model_tier=expected_tier, expected_tools_needed=expected_tools, risk_of_failure=risk, retrieval_required=retrieval_required, verifier_required=verifier_required, expected_latency_ms=expected_latency, confidence=0.7 if matched_types else 0.4, ) def classify_with_history(self, user_request: str, past_traces: List[Dict]) -> TaskPrediction: """Classify using historical trace data for this user/task pattern.""" base = self.classify(user_request) if not past_traces: return base # Find similar past requests similar = [ t for t in past_traces if self._similarity(user_request, t.get("user_request", "")) > 0.5 ] if len(similar) >= 3: # Adjust predictions based on history avg_cost = sum(t.get("total_cost", base.expected_cost) for t in similar) / len(similar) success_rate = sum(1 for t in similar if t.get("final_outcome") == "success") / len(similar) avg_retries = sum(t.get("total_retries", 0) for t in similar) / len(similar) # If history shows high failure, bump tier and require verifier if success_rate < 0.5: base = TaskPrediction( task_type=base.task_type, expected_cost=avg_cost * 1.2, expected_model_tier=min(base.expected_model_tier + 1, 5), expected_tools_needed=base.expected_tools_needed, risk_of_failure=min(base.risk_of_failure * 1.3, 1.0), retrieval_required=True, verifier_required=True, expected_latency_ms=base.expected_latency_ms * 1.2, confidence=min(base.confidence + 0.1, 1.0), ) else: base = TaskPrediction( task_type=base.task_type, expected_cost=avg_cost * 0.9, # history suggests we can be cheaper expected_model_tier=max(base.expected_model_tier - 1, 1), expected_tools_needed=base.expected_tools_needed, risk_of_failure=base.risk_of_failure * 0.8, retrieval_required=base.retrieval_required, verifier_required=base.verifier_required and avg_retries > 1, expected_latency_ms=base.expected_latency_ms * 0.9, confidence=min(base.confidence + 0.2, 1.0), ) return base @staticmethod def _similarity(a: str, b: str) -> float: """Simple Jaccard similarity on words.""" words_a = set(a.lower().split()) words_b = set(b.lower().split()) if not words_a or not words_b: return 0.0 return len(words_a & words_b) / len(words_a | words_b)