"""Task Cost Classifier: Predicts task type, difficulty, and cost requirements.""" from typing import Dict, Tuple, Optional import re CODE_PATTERNS = [r'\b(code|function|bug|debug|refactor|implement|compile|runtime|segfault|thread|async|class|module|python|javascript|typescript|go|rust|java)\b'] LEGAL_PATTERNS = [r'\b(contract|legal|compliance|gdpr|privacy|policy|regulatory|liability|indemnif|clause)\b'] RESEARCH_PATTERNS = [r'\b(research|sources?|literature|investigate|compare|analy[sz]e|survey|paper|arxiv|find)\b'] TOOL_PATTERNS = [r'\b(search|fetch|retrieve|query|api|database|scrape|aggregate|list|download)\b'] LONG_PATTERNS = [r'\b(plan|roadmap|orchestrat|migrate|pipeline|deploy|architecture|multi-step|end.to.end|entire)\b'] MATH_PATTERNS = [r'\b(calculat|comput|solve|equation|formula|optim[iy]|probability|integral|derivative)\b'] SIMPLE_PATTERNS = [r'\b(typo|simple|quick|brief|just|minor|small|easy|trivial|clarif|only)\b'] CRITICAL_PATTERNS = [r'\b(critical|production|urgent|now|emergency|live|deployed|safety|security|important)\b'] DOC_PATTERNS = [r'\b(draft|write|compose|email|proposal|report|memo|letter|document|create)\b'] RETRIEVAL_PATTERNS = [r'\b(find all|search.*for|look up|based on|according to|in the document|in the file)\b'] TASK_TYPES = [ "quick_answer", "coding", "research", "document_drafting", "legal_regulated", "tool_heavy", "retrieval_heavy", "long_horizon", "unknown_ambiguous" ] TASK_DIFFICULTY_BASE = { "quick_answer": 1, "document_drafting": 2, "tool_heavy": 2, "retrieval_heavy": 2, "research": 3, "coding": 3, "unknown_ambiguous": 3, "long_horizon": 4, "legal_regulated": 5, } TASK_RISK = { "quick_answer": "low", "document_drafting": "low", "tool_heavy": "medium", "retrieval_heavy": "medium", "research": "medium", "coding": "medium", "unknown_ambiguous": "medium", "long_horizon": "high", "legal_regulated": "critical", } class TaskCostClassifier: def __init__(self): self.task_types = TASK_TYPES def classify(self, request: str) -> Dict: task_type = self._classify_type(request) difficulty = self._estimate_difficulty(request, task_type) risk = TASK_RISK.get(task_type, "medium") needs_tools = self._needs_tools(request, task_type) needs_retrieval = self._needs_retrieval(request, task_type) needs_verifier = self._needs_verifier(request, task_type, risk) expected_cost = self._estimate_cost(difficulty, needs_tools, needs_retrieval, needs_verifier) return { "task_type": task_type, "difficulty": difficulty, "risk": risk, "needs_tools": needs_tools, "needs_retrieval": needs_retrieval, "needs_verifier": needs_verifier, "expected_cost": expected_cost, "expected_tier": min(difficulty + 1, 5), } def _classify_type(self, request: str) -> str: r = request.lower() scores = {} scores["legal_regulated"] = sum(len(re.findall(p, r)) for p in LEGAL_PATTERNS) scores["coding"] = sum(len(re.findall(p, r)) for p in CODE_PATTERNS) scores["research"] = sum(len(re.findall(p, r)) for p in RESEARCH_PATTERNS) scores["tool_heavy"] = sum(len(re.findall(p, r)) for p in TOOL_PATTERNS) scores["long_horizon"] = sum(len(re.findall(p, r)) for p in LONG_PATTERNS) scores["retrieval_heavy"] = sum(len(re.findall(p, r)) for p in RETRIEVAL_PATTERNS) scores["document_drafting"] = sum(len(re.findall(p, r)) for p in DOC_PATTERNS) scores["quick_answer"] = 0.5 if len(r.split()) < 10 else 0 # Check if no strong signal max_score = max(scores.values()) if scores else 0 if max_score == 0: return "unknown_ambiguous" return max(scores, key=scores.get) def _estimate_difficulty(self, request: str, task_type: str) -> int: r = request.lower() base = TASK_DIFFICULTY_BASE.get(task_type, 3) if any(re.findall(p, r) for p in CRITICAL_PATTERNS): base = min(base + 1, 5) if any(re.findall(p, r) for p in SIMPLE_PATTERNS): base = max(base - 1, 1) return base def _needs_tools(self, request: str, task_type: str) -> bool: return task_type in ("tool_heavy", "retrieval_heavy", "coding", "research") def _needs_retrieval(self, request: str, task_type: str) -> bool: return task_type in ("retrieval_heavy", "research") def _needs_verifier(self, request: str, task_type: str, risk: str) -> bool: return risk in ("high", "critical") def _estimate_cost(self, difficulty: int, tools: bool, retrieval: bool, verifier: bool) -> float: base_cost = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}.get(difficulty, 1.0) if tools: base_cost *= 1.3 if retrieval: base_cost *= 1.2 if verifier: base_cost *= 1.1 return base_cost