File size: 4,927 Bytes
1b0e9a1
 
a7c7186
 
1b0e9a1
 
 
 
 
 
 
 
 
 
a7c7186
1b0e9a1
 
 
 
a7c7186
1b0e9a1
 
 
 
a7c7186
1b0e9a1
 
 
 
 
a7c7186
 
1b0e9a1
 
a7c7186
1b0e9a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c7186
1b0e9a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c7186
1b0e9a1
 
 
 
 
 
 
 
a7c7186
1b0e9a1
 
a7c7186
1b0e9a1
 
 
 
 
a7c7186
1b0e9a1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Task Cost Classifier: Predicts task type, difficulty, and cost requirements."""
from typing import Dict, Tuple, Optional
import re

# Keyword tables for lightweight, regex-based request classification.
#
# Each table is a list of regex strings; the classifier below applies them to
# the lower-cased request text. Every pattern is a single alternation wrapped
# in \b...\b so that only whole words count.
#
# Alternatives that are word *stems* (e.g. "indemnif", "orchestrat") carry a
# trailing \w* so that, inside the \b anchors, they also match inflected forms
# such as "indemnify" or "orchestration". Without it the closing \b allowed
# only the bare stem, which never occurs in real text.
CODE_PATTERNS = [r'\b(code|function|bug|debug|refactor|implement|compile|runtime|segfault|thread|async|class|module|python|javascript|typescript|go|rust|java)\b']
LEGAL_PATTERNS = [r'\b(contract|legal|compliance|gdpr|privacy|policy|regulatory|liability|indemnif\w*|clause)\b']
RESEARCH_PATTERNS = [r'\b(research|sources?|literature|investigate|compare|analy[sz]e|survey|paper|arxiv|find)\b']
TOOL_PATTERNS = [r'\b(search|fetch|retrieve|query|api|database|scrape|aggregate|list|download)\b']
# "end.to.end" uses "." so that "end-to-end" and "end to end" both match.
LONG_PATTERNS = [r'\b(plan|roadmap|orchestrat\w*|migrate|pipeline|deploy|architecture|multi-step|end.to.end|entire)\b']
MATH_PATTERNS = [r'\b(calculat\w*|comput\w*|solve|equation|formula|optim[iy]\w*|probability|integral|derivative)\b']
# Words that lower the estimated difficulty.
SIMPLE_PATTERNS = [r'\b(typo|simple|quick|brief|just|minor|small|easy|trivial|clarif\w*|only)\b']
# Words that raise the estimated difficulty.
CRITICAL_PATTERNS = [r'\b(critical|production|urgent|now|emergency|live|deployed|safety|security|important)\b']
DOC_PATTERNS = [r'\b(draft|write|compose|email|proposal|report|memo|letter|document|create)\b']
RETRIEVAL_PATTERNS = [r'\b(find all|search.*for|look up|based on|according to|in the document|in the file)\b']

# Every label the classifier can emit.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]

# Starting difficulty for each task type, before keyword adjustments.
TASK_DIFFICULTY_BASE = dict(
    quick_answer=1,
    document_drafting=2,
    tool_heavy=2,
    retrieval_heavy=2,
    research=3,
    coding=3,
    unknown_ambiguous=3,
    long_horizon=4,
    legal_regulated=5,
)

# Risk label attached to each task type.
TASK_RISK = dict(
    quick_answer="low",
    document_drafting="low",
    tool_heavy="medium",
    retrieval_heavy="medium",
    research="medium",
    coding="medium",
    unknown_ambiguous="medium",
    long_horizon="high",
    legal_regulated="critical",
)

class TaskCostClassifier:
    """Keyword-based classifier that labels a request and estimates its cost.

    The request text is lower-cased and scanned with the module-level regex
    tables. The best-scoring task type then determines difficulty, risk,
    tool/retrieval/verifier needs, and a cost estimate.
    """

    def __init__(self):
        # Expose the known labels on the instance (same list object as the
        # module-level constant).
        self.task_types = TASK_TYPES

    def classify(self, request: str) -> Dict:
        """Classify *request* and return all derived attributes.

        Args:
            request: Free-text task description.

        Returns:
            A dict with keys ``task_type``, ``difficulty``, ``risk``,
            ``needs_tools``, ``needs_retrieval``, ``needs_verifier``,
            ``expected_cost`` and ``expected_tier`` (difficulty + 1, capped
            at 5).
        """
        kind = self._classify_type(request)
        level = self._estimate_difficulty(request, kind)
        risk_label = TASK_RISK.get(kind, "medium")
        uses_tools = self._needs_tools(request, kind)
        uses_retrieval = self._needs_retrieval(request, kind)
        uses_verifier = self._needs_verifier(request, kind, risk_label)

        summary = {
            "task_type": kind,
            "difficulty": level,
            "risk": risk_label,
            "needs_tools": uses_tools,
            "needs_retrieval": uses_retrieval,
            "needs_verifier": uses_verifier,
        }
        summary["expected_cost"] = self._estimate_cost(
            level, uses_tools, uses_retrieval, uses_verifier
        )
        summary["expected_tier"] = min(level + 1, 5)
        return summary

    def _classify_type(self, request: str) -> str:
        """Return the task type whose keyword table matches *request* most often.

        Requests under ten words get a 0.5 score for ``quick_answer``, so any
        real keyword hit outranks it. If nothing scores above zero the result
        is ``unknown_ambiguous``.
        """
        text = request.lower()
        # Order matters: on a tie, the entry listed first wins.
        keyword_tables = (
            ("legal_regulated", LEGAL_PATTERNS),
            ("coding", CODE_PATTERNS),
            ("research", RESEARCH_PATTERNS),
            ("tool_heavy", TOOL_PATTERNS),
            ("long_horizon", LONG_PATTERNS),
            ("retrieval_heavy", RETRIEVAL_PATTERNS),
            ("document_drafting", DOC_PATTERNS),
        )
        tally = {}
        for label, patterns in keyword_tables:
            hits = 0
            for pattern in patterns:
                hits += len(re.findall(pattern, text))
            tally[label] = hits
        if len(text.split()) < 10:
            tally["quick_answer"] = 0.5
        else:
            tally["quick_answer"] = 0

        winner = max(tally, key=tally.get)
        if tally[winner] == 0:
            # No table matched and the request is not short.
            return "unknown_ambiguous"
        return winner

    def _estimate_difficulty(self, request: str, task_type: str) -> int:
        """Return a 1-5 difficulty: the type's base value adjusted by keywords.

        A critical-sounding keyword adds one (capped at 5); a simple-sounding
        keyword then subtracts one (floored at 1). Both may apply.
        """
        text = request.lower()
        level = TASK_DIFFICULTY_BASE.get(task_type, 3)

        def mentions(patterns):
            # True when at least one pattern occurs anywhere in the text.
            for pattern in patterns:
                if re.search(pattern, text) is not None:
                    return True
            return False

        if mentions(CRITICAL_PATTERNS):
            level = min(level + 1, 5)
        if mentions(SIMPLE_PATTERNS):
            level = max(level - 1, 1)
        return level

    def _needs_tools(self, request: str, task_type: str) -> bool:
        """Return True for task types that are treated as tool-using."""
        tool_using_types = ("tool_heavy", "retrieval_heavy", "coding", "research")
        return task_type in tool_using_types

    def _needs_retrieval(self, request: str, task_type: str) -> bool:
        """Return True for task types that are treated as retrieval-using."""
        retrieval_types = ("retrieval_heavy", "research")
        return task_type in retrieval_types

    def _needs_verifier(self, request: str, task_type: str, risk: str) -> bool:
        """Return True when *risk* is ``"high"`` or ``"critical"``."""
        verified_risks = ("high", "critical")
        return risk in verified_risks

    def _estimate_cost(self, difficulty: int, tools: bool, retrieval: bool, verifier: bool) -> float:
        """Return the base cost for *difficulty* scaled by each enabled extra.

        Unknown difficulty values fall back to a base of 1.0. Tools multiply
        the cost by 1.3, retrieval by 1.2 and a verifier by 1.1.
        """
        price_by_difficulty = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        cost = price_by_difficulty.get(difficulty, 1.0)
        surcharges = ((tools, 1.3), (retrieval, 1.2), (verifier, 1.1))
        for enabled, factor in surcharges:
            if enabled:
                cost *= factor
        return cost