| """Task Cost Classifier - Module 2. |
| |
| Classifies incoming tasks by expected cost, risk, model strength needed, |
| and predicts whether retrieval/verifier is required. |
| """ |
|
|
| import re |
| from typing import Dict, List, Tuple, Optional |
| from dataclasses import dataclass |
|
|
| from .trace_schema import TaskType |
| from .config import ACOConfig |
|
|
|
|
@dataclass
class TaskPrediction:
    """Predicted execution profile for a single user request.

    Produced by ``TaskCostClassifier.classify`` /
    ``classify_with_history`` and consumed by downstream routing logic.
    """

    task_type: TaskType  # classified category (see trace_schema.TaskType)
    expected_cost: float  # estimated cost of the task (currency units -- TODO confirm unit)
    expected_model_tier: int  # model strength tier, clamped to 1..5 by the classifier
    expected_tools_needed: List[str]  # tool names the task is expected to invoke
    risk_of_failure: float  # estimated failure probability, capped at 1.0
    retrieval_required: bool  # whether document/context retrieval is expected
    verifier_required: bool  # whether a verifier pass is recommended
    expected_latency_ms: float  # rough latency estimate derived from expected_cost
    confidence: float  # classifier confidence in this prediction, capped at 1.0
|
|
|
|
class TaskCostClassifier:
    """Classifies agent tasks into cost/risk categories.

    Matches keywords in the request against ``KEYWORD_MAP`` to pick a task
    type, a base cost, and a base model tier, then scales the estimate by
    complexity and request-length heuristics.  ``classify_with_history``
    further adjusts the prediction using outcomes of similar past traces.
    """

    # keyword -> (task type, base cost, minimum model tier 1-5)
    KEYWORD_MAP: Dict[str, Tuple[TaskType, float, int]] = {

        "what is": (TaskType.QUICK_ANSWER, 0.001, 1),
        "define": (TaskType.QUICK_ANSWER, 0.001, 1),
        "explain briefly": (TaskType.QUICK_ANSWER, 0.002, 1),
        "summarize": (TaskType.QUICK_ANSWER, 0.005, 2),
        "short answer": (TaskType.QUICK_ANSWER, 0.001, 1),

        "write code": (TaskType.CODING, 0.05, 3),
        "fix bug": (TaskType.CODING, 0.08, 4),
        "refactor": (TaskType.CODING, 0.03, 3),
        "implement": (TaskType.CODING, 0.05, 3),
        "test": (TaskType.CODING, 0.04, 3),
        "debug": (TaskType.CODING, 0.06, 4),
        "python": (TaskType.CODING, 0.03, 3),
        "javascript": (TaskType.CODING, 0.03, 3),
        "function": (TaskType.CODING, 0.02, 2),

        "research": (TaskType.RESEARCH, 0.15, 4),
        "find sources": (TaskType.RESEARCH, 0.1, 3),
        "literature review": (TaskType.RESEARCH, 0.2, 4),
        "compare": (TaskType.RESEARCH, 0.08, 3),
        "analyze": (TaskType.RESEARCH, 0.1, 3),
        "investigate": (TaskType.RESEARCH, 0.12, 4),

        "draft": (TaskType.DOCUMENT_DRAFTING, 0.05, 3),
        "write a document": (TaskType.DOCUMENT_DRAFTING, 0.06, 3),
        "proposal": (TaskType.DOCUMENT_DRAFTING, 0.08, 3),
        "report": (TaskType.DOCUMENT_DRAFTING, 0.1, 4),
        "email": (TaskType.DOCUMENT_DRAFTING, 0.01, 2),

        "contract": (TaskType.LEGAL_REGULATED, 0.15, 5),
        "legal": (TaskType.LEGAL_REGULATED, 0.15, 5),
        "compliance": (TaskType.LEGAL_REGULATED, 0.12, 5),
        "regulatory": (TaskType.LEGAL_REGULATED, 0.12, 5),
        "privacy policy": (TaskType.LEGAL_REGULATED, 0.1, 5),
        "terms of service": (TaskType.LEGAL_REGULATED, 0.1, 5),

        "search for": (TaskType.TOOL_HEAVY, 0.05, 3),
        "look up": (TaskType.TOOL_HEAVY, 0.03, 2),
        "fetch": (TaskType.TOOL_HEAVY, 0.04, 3),
        "api": (TaskType.TOOL_HEAVY, 0.06, 3),
        "database": (TaskType.TOOL_HEAVY, 0.05, 3),
        "scrape": (TaskType.TOOL_HEAVY, 0.04, 3),

        "based on the document": (TaskType.RETRIEVAL_HEAVY, 0.08, 3),
        "from my files": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
        "rag": (TaskType.RETRIEVAL_HEAVY, 0.06, 3),
        "retrieve": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),

        "plan": (TaskType.LONG_HORIZON, 0.1, 4),
        "project": (TaskType.LONG_HORIZON, 0.15, 4),
        "over the next": (TaskType.LONG_HORIZON, 0.1, 4),
        "multi-step": (TaskType.LONG_HORIZON, 0.08, 4),
        "orchestrate": (TaskType.LONG_HORIZON, 0.12, 4),
    }

    # (regex, multiplier) -- signals of extra complexity; the largest
    # multiplier among matching patterns is applied to cost and risk.
    COMPLEXITY_PATTERNS = [
        (r"\b(AND|and)\b.*\b(AND|and)\b.*\b(AND|and)\b", 1.5),
        (r"\bstep\s+\d+\b", 1.3),
        (r"\d+\+\s*(pages|files|functions|tests)", 1.4),
        (r"\b(entire|whole|all|every)\b", 1.2),
        (r"\b(critical|production|live|deployed)\b", 1.5),
    ]

    def __init__(self, config: Optional[ACOConfig] = None):
        """Create a classifier; ``config`` defaults to a fresh ``ACOConfig``."""
        self.config = config or ACOConfig()
        self.history: List[Dict] = []

    def classify(self, user_request: str) -> TaskPrediction:
        """Classify a user request into task type, cost, risk, etc.

        Args:
            user_request: Raw natural-language request text.

        Returns:
            A ``TaskPrediction`` with heuristic cost/tier/risk estimates.
        """
        request_lower = user_request.lower()

        # Collect (cost, tier) for every matched keyword, grouped by type.
        matched: Dict[TaskType, List[Tuple[float, int]]] = {}
        for keyword, (ktype, kcost, ktier) in self.KEYWORD_MAP.items():
            if keyword in request_lower:
                matched.setdefault(ktype, []).append((kcost, ktier))

        if not matched:
            task_type = TaskType.UNKNOWN_AMBIGUOUS
            base_cost = 0.05
            base_tier = 2
        else:
            # Winning type = the one whose matched keywords sum to the
            # largest total cost.
            task_type = max(matched, key=lambda t: sum(c for c, _ in matched[t]))
            # BUGFIX: take the tier paired with the most expensive matched
            # keyword of the winning type.  The previous code re-scanned
            # KEYWORD_MAP with a max() key that scored every matched keyword
            # identically, so the tier could come from an arbitrary keyword
            # rather than the one that set base_cost.
            base_cost, base_tier = max(matched[task_type], key=lambda ct: ct[0])

        # Apply the strongest complexity multiplier among triggered patterns.
        complexity_mult = 1.0
        for pattern, mult in self.COMPLEXITY_PATTERNS:
            if re.search(pattern, user_request, re.IGNORECASE):
                complexity_mult = max(complexity_mult, mult)

        # Longer requests scale cost up, capped at +50%.
        word_count = len(request_lower.split())
        length_mult = 1.0 + min(word_count / 500, 0.5)

        expected_cost = base_cost * complexity_mult * length_mult
        # Bump the tier by one for clearly complex requests, capped at 5.
        expected_tier = min(base_tier + int(complexity_mult > 1.2), 5)

        # Tool needs follow directly from the task type.
        expected_tools: List[str] = []
        if task_type in (TaskType.RESEARCH, TaskType.TOOL_HEAVY, TaskType.RETRIEVAL_HEAVY):
            expected_tools = ["search", "retrieve", "fetch"]
        elif task_type == TaskType.CODING:
            expected_tools = ["code_execution", "linter", "test_runner"]
        elif task_type == TaskType.LEGAL_REGULATED:
            expected_tools = ["document_retrieval", "compliance_check"]

        # Baseline failure risk per task type, scaled by complexity,
        # capped at 1.0.
        risk = 0.3
        if task_type == TaskType.LEGAL_REGULATED:
            risk = 0.8
        elif task_type == TaskType.LONG_HORIZON:
            risk = 0.6
        elif task_type == TaskType.CODING:
            risk = 0.5
        elif task_type == TaskType.UNKNOWN_AMBIGUOUS:
            risk = 0.7
        risk = min(risk * complexity_mult, 1.0)

        verifier_required = risk > 0.6 or task_type == TaskType.LEGAL_REGULATED

        retrieval_required = task_type in (
            TaskType.RESEARCH,
            TaskType.RETRIEVAL_HEAVY,
            TaskType.DOCUMENT_DRAFTING,
            TaskType.LEGAL_REGULATED,
        )

        # Crude latency model: linear in expected cost -- TODO calibrate.
        expected_latency = expected_cost * 10000

        return TaskPrediction(
            task_type=task_type,
            expected_cost=expected_cost,
            expected_model_tier=expected_tier,
            expected_tools_needed=expected_tools,
            risk_of_failure=risk,
            retrieval_required=retrieval_required,
            verifier_required=verifier_required,
            expected_latency_ms=expected_latency,
            confidence=0.7 if matched else 0.4,
        )

    def classify_with_history(self, user_request: str, past_traces: List[Dict]) -> TaskPrediction:
        """Classify using historical trace data for this user/task pattern.

        Falls back to the plain keyword-based prediction unless at least
        three sufficiently similar past traces (Jaccard > 0.5) exist.

        Args:
            user_request: Raw natural-language request text.
            past_traces: Trace dicts with optional "user_request",
                "total_cost", "final_outcome", and "total_retries" keys.

        Returns:
            A ``TaskPrediction``, adjusted by historical outcomes when
            enough similar traces are available.
        """
        base = self.classify(user_request)

        if not past_traces:
            return base

        similar = [
            t for t in past_traces
            if self._similarity(user_request, t.get("user_request", "")) > 0.5
        ]
        if len(similar) < 3:
            return base

        avg_cost = sum(t.get("total_cost", base.expected_cost) for t in similar) / len(similar)
        success_rate = sum(1 for t in similar if t.get("final_outcome") == "success") / len(similar)
        avg_retries = sum(t.get("total_retries", 0) for t in similar) / len(similar)

        if success_rate < 0.5:
            # Similar tasks mostly failed: escalate cost, tier, risk, and
            # force both retrieval and verification.
            return TaskPrediction(
                task_type=base.task_type,
                expected_cost=avg_cost * 1.2,
                expected_model_tier=min(base.expected_model_tier + 1, 5),
                expected_tools_needed=base.expected_tools_needed,
                risk_of_failure=min(base.risk_of_failure * 1.3, 1.0),
                retrieval_required=True,
                verifier_required=True,
                expected_latency_ms=base.expected_latency_ms * 1.2,
                confidence=min(base.confidence + 0.1, 1.0),
            )

        # Similar tasks mostly succeeded: relax the estimate and only keep
        # the verifier when past runs actually needed retries.
        return TaskPrediction(
            task_type=base.task_type,
            expected_cost=avg_cost * 0.9,
            expected_model_tier=max(base.expected_model_tier - 1, 1),
            expected_tools_needed=base.expected_tools_needed,
            risk_of_failure=base.risk_of_failure * 0.8,
            retrieval_required=base.retrieval_required,
            verifier_required=base.verifier_required and avg_retries > 1,
            expected_latency_ms=base.expected_latency_ms * 0.9,
            confidence=min(base.confidence + 0.2, 1.0),
        )

    @staticmethod
    def _similarity(a: str, b: str) -> float:
        """Jaccard similarity of the lowercased word sets of *a* and *b*."""
        words_a = set(a.lower().split())
        words_b = set(b.lower().split())
        if not words_a or not words_b:
            return 0.0
        return len(words_a & words_b) / len(words_a | words_b)
|
|