narcolepticchicken committed on
Commit
1b0e9a1
·
verified ·
1 Parent(s): 5569d72

Upload aco/classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/classifier.py +82 -230
aco/classifier.py CHANGED
@@ -1,243 +1,95 @@
1
- """Task Cost Classifier - Module 2.
2
-
3
- Classifies incoming tasks by expected cost, risk, model strength needed,
4
- and predicts whether retrieval/verifier is required.
5
- """
6
-
7
  import re
8
- from typing import Dict, List, Tuple, Optional
9
- from dataclasses import dataclass
10
 
11
- from .trace_schema import TaskType
12
- from .config import ACOConfig
 
 
 
 
 
 
 
 
13
 
 
 
 
 
14
 
15
- @dataclass
16
- class TaskPrediction:
17
- task_type: TaskType
18
- expected_cost: float
19
- expected_model_tier: int # 1-5
20
- expected_tools_needed: List[str]
21
- risk_of_failure: float # 0-1
22
- retrieval_required: bool
23
- verifier_required: bool
24
- expected_latency_ms: float
25
- confidence: float
26
 
 
 
 
 
 
27
 
28
  class TaskCostClassifier:
29
- """Classifies agent tasks into cost/risk categories."""
 
30
 
31
- # Keywords mapped to task types with base cost estimates
32
- KEYWORD_MAP: Dict[str, Tuple[TaskType, float, int]] = {
33
- # quick_answer: low cost, tier 1-2
34
- "what is": (TaskType.QUICK_ANSWER, 0.001, 1),
35
- "define": (TaskType.QUICK_ANSWER, 0.001, 1),
36
- "explain briefly": (TaskType.QUICK_ANSWER, 0.002, 1),
37
- "summarize": (TaskType.QUICK_ANSWER, 0.005, 2),
38
- "short answer": (TaskType.QUICK_ANSWER, 0.001, 1),
39
- # coding: medium-high cost, tier 3-4
40
- "write code": (TaskType.CODING, 0.05, 3),
41
- "fix bug": (TaskType.CODING, 0.08, 4),
42
- "refactor": (TaskType.CODING, 0.03, 3),
43
- "implement": (TaskType.CODING, 0.05, 3),
44
- "test": (TaskType.CODING, 0.04, 3),
45
- "debug": (TaskType.CODING, 0.06, 4),
46
- "python": (TaskType.CODING, 0.03, 3),
47
- "javascript": (TaskType.CODING, 0.03, 3),
48
- "function": (TaskType.CODING, 0.02, 2),
49
- # research: high cost, tier 3-4
50
- "research": (TaskType.RESEARCH, 0.15, 4),
51
- "find sources": (TaskType.RESEARCH, 0.1, 3),
52
- "literature review": (TaskType.RESEARCH, 0.2, 4),
53
- "compare": (TaskType.RESEARCH, 0.08, 3),
54
- "analyze": (TaskType.RESEARCH, 0.1, 3),
55
- "investigate": (TaskType.RESEARCH, 0.12, 4),
56
- # document_drafting: medium cost, tier 3
57
- "draft": (TaskType.DOCUMENT_DRAFTING, 0.05, 3),
58
- "write a document": (TaskType.DOCUMENT_DRAFTING, 0.06, 3),
59
- "proposal": (TaskType.DOCUMENT_DRAFTING, 0.08, 3),
60
- "report": (TaskType.DOCUMENT_DRAFTING, 0.1, 4),
61
- "email": (TaskType.DOCUMENT_DRAFTING, 0.01, 2),
62
- # legal_regulated: high cost, tier 4-5
63
- "contract": (TaskType.LEGAL_REGULATED, 0.15, 5),
64
- "legal": (TaskType.LEGAL_REGULATED, 0.15, 5),
65
- "compliance": (TaskType.LEGAL_REGULATED, 0.12, 5),
66
- "regulatory": (TaskType.LEGAL_REGULATED, 0.12, 5),
67
- "privacy policy": (TaskType.LEGAL_REGULATED, 0.1, 5),
68
- "terms of service": (TaskType.LEGAL_REGULATED, 0.1, 5),
69
- # tool_heavy
70
- "search for": (TaskType.TOOL_HEAVY, 0.05, 3),
71
- "look up": (TaskType.TOOL_HEAVY, 0.03, 2),
72
- "fetch": (TaskType.TOOL_HEAVY, 0.04, 3),
73
- "api": (TaskType.TOOL_HEAVY, 0.06, 3),
74
- "database": (TaskType.TOOL_HEAVY, 0.05, 3),
75
- "scrape": (TaskType.TOOL_HEAVY, 0.04, 3),
76
- # retrieval_heavy
77
- "based on the document": (TaskType.RETRIEVAL_HEAVY, 0.08, 3),
78
- "from my files": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
79
- "rag": (TaskType.RETRIEVAL_HEAVY, 0.06, 3),
80
- "retrieve": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
81
- # long_horizon
82
- "plan": (TaskType.LONG_HORIZON, 0.1, 4),
83
- "project": (TaskType.LONG_HORIZON, 0.15, 4),
84
- "over the next": (TaskType.LONG_HORIZON, 0.1, 4),
85
- "multi-step": (TaskType.LONG_HORIZON, 0.08, 4),
86
- "orchestrate": (TaskType.LONG_HORIZON, 0.12, 4),
87
- }
88
 
89
- # Complexity multipliers based on length and structure
90
- COMPLEXITY_PATTERNS = [
91
- (r"\b(AND|and)\b.*\b(AND|and)\b.*\b(AND|and)\b", 1.5), # multiple sub-tasks
92
- (r"\bstep\s+\d+\b", 1.3),
93
- (r"\d+\+\s*(pages|files|functions|tests)", 1.4),
94
- (r"\b(entire|whole|all|every)\b", 1.2),
95
- (r"\b(critical|production|live|deployed)\b", 1.5),
96
- ]
 
 
 
 
 
 
 
 
97
 
98
- def __init__(self, config: Optional[ACOConfig] = None):
99
- self.config = config or ACOConfig()
100
- self.history: List[Dict] = []
 
 
 
 
 
101
 
102
- def classify(self, user_request: str) -> TaskPrediction:
103
- """Classify a user request into task type, cost, risk, etc."""
104
- request_lower = user_request.lower()
105
-
106
- # Find best matching keywords
107
- matched_types: Dict[TaskType, List[float]] = {}
108
- for keyword, (task_type, base_cost, tier) in self.KEYWORD_MAP.items():
109
- if keyword in request_lower:
110
- matched_types.setdefault(task_type, []).append(base_cost)
111
-
112
- # Default to unknown if no match
113
- if not matched_types:
114
- task_type = TaskType.UNKNOWN_AMBIGUOUS
115
- base_cost = 0.05
116
- base_tier = 2
117
- else:
118
- # Pick task type with highest cumulative base cost (most specific)
119
- task_type = max(matched_types.keys(), key=lambda t: sum(matched_types[t]))
120
- base_cost = max(matched_types[task_type])
121
- base_tier = self.KEYWORD_MAP[
122
- max(
123
- (k for k, (tt, _, _) in self.KEYWORD_MAP.items() if tt == task_type),
124
- key=lambda k: base_cost if k in request_lower else 0,
125
- )
126
- ][2]
127
-
128
- # Apply complexity multipliers
129
- complexity_mult = 1.0
130
- for pattern, mult in self.COMPLEXITY_PATTERNS:
131
- if re.search(pattern, user_request, re.IGNORECASE):
132
- complexity_mult = max(complexity_mult, mult)
133
-
134
- # Length factor
135
- word_count = len(request_lower.split())
136
- length_mult = 1.0 + min(word_count / 500, 0.5)
137
-
138
- expected_cost = base_cost * complexity_mult * length_mult
139
- expected_tier = min(base_tier + int(complexity_mult > 1.2), 5)
140
-
141
- # Determine tool needs
142
- expected_tools = []
143
- if task_type in (TaskType.RESEARCH, TaskType.TOOL_HEAVY, TaskType.RETRIEVAL_HEAVY):
144
- expected_tools = ["search", "retrieve", "fetch"]
145
- elif task_type == TaskType.CODING:
146
- expected_tools = ["code_execution", "linter", "test_runner"]
147
- elif task_type == TaskType.LEGAL_REGULATED:
148
- expected_tools = ["document_retrieval", "compliance_check"]
149
-
150
- # Risk estimation
151
- risk = 0.3
152
- if task_type == TaskType.LEGAL_REGULATED:
153
- risk = 0.8
154
- elif task_type == TaskType.LONG_HORIZON:
155
- risk = 0.6
156
- elif task_type == TaskType.CODING:
157
- risk = 0.5
158
- elif task_type == TaskType.UNKNOWN_AMBIGUOUS:
159
- risk = 0.7
160
-
161
- # Adjust risk by complexity
162
- risk = min(risk * complexity_mult, 1.0)
163
-
164
- # Verifier required for high-risk or complex tasks
165
- verifier_required = risk > 0.6 or task_type == TaskType.LEGAL_REGULATED
166
-
167
- # Retrieval required for research, document, retrieval-heavy
168
- retrieval_required = task_type in (
169
- TaskType.RESEARCH,
170
- TaskType.RETRIEVAL_HEAVY,
171
- TaskType.DOCUMENT_DRAFTING,
172
- TaskType.LEGAL_REGULATED,
173
- )
174
-
175
- expected_latency = expected_cost * 10000 # rough heuristic: $0.001 ~ 10s
176
-
177
- return TaskPrediction(
178
- task_type=task_type,
179
- expected_cost=expected_cost,
180
- expected_model_tier=expected_tier,
181
- expected_tools_needed=expected_tools,
182
- risk_of_failure=risk,
183
- retrieval_required=retrieval_required,
184
- verifier_required=verifier_required,
185
- expected_latency_ms=expected_latency,
186
- confidence=0.7 if matched_types else 0.4,
187
- )
188
 
189
- def classify_with_history(self, user_request: str, past_traces: List[Dict]) -> TaskPrediction:
190
- """Classify using historical trace data for this user/task pattern."""
191
- base = self.classify(user_request)
192
-
193
- if not past_traces:
194
- return base
195
-
196
- # Find similar past requests
197
- similar = [
198
- t for t in past_traces
199
- if self._similarity(user_request, t.get("user_request", "")) > 0.5
200
- ]
201
-
202
- if len(similar) >= 3:
203
- # Adjust predictions based on history
204
- avg_cost = sum(t.get("total_cost", base.expected_cost) for t in similar) / len(similar)
205
- success_rate = sum(1 for t in similar if t.get("final_outcome") == "success") / len(similar)
206
- avg_retries = sum(t.get("total_retries", 0) for t in similar) / len(similar)
207
-
208
- # If history shows high failure, bump tier and require verifier
209
- if success_rate < 0.5:
210
- base = TaskPrediction(
211
- task_type=base.task_type,
212
- expected_cost=avg_cost * 1.2,
213
- expected_model_tier=min(base.expected_model_tier + 1, 5),
214
- expected_tools_needed=base.expected_tools_needed,
215
- risk_of_failure=min(base.risk_of_failure * 1.3, 1.0),
216
- retrieval_required=True,
217
- verifier_required=True,
218
- expected_latency_ms=base.expected_latency_ms * 1.2,
219
- confidence=min(base.confidence + 0.1, 1.0),
220
- )
221
- else:
222
- base = TaskPrediction(
223
- task_type=base.task_type,
224
- expected_cost=avg_cost * 0.9, # history suggests we can be cheaper
225
- expected_model_tier=max(base.expected_model_tier - 1, 1),
226
- expected_tools_needed=base.expected_tools_needed,
227
- risk_of_failure=base.risk_of_failure * 0.8,
228
- retrieval_required=base.retrieval_required,
229
- verifier_required=base.verifier_required and avg_retries > 1,
230
- expected_latency_ms=base.expected_latency_ms * 0.9,
231
- confidence=min(base.confidence + 0.2, 1.0),
232
- )
233
-
234
- return base
235
 
236
- @staticmethod
237
- def _similarity(a: str, b: str) -> float:
238
- """Simple Jaccard similarity on words."""
239
- words_a = set(a.lower().split())
240
- words_b = set(b.lower().split())
241
- if not words_a or not words_b:
242
- return 0.0
243
- return len(words_a & words_b) / len(words_a | words_b)
 
1
+ """Task Cost Classifier: Predicts task type, difficulty, and cost requirements."""
2
+ from typing import Dict, Tuple, Optional
 
 
 
 
3
  import re
 
 
4
 
5
# Keyword regexes used to score a request. Each constant is a list of pattern
# strings that the classifier applies to the lower-cased request text; every
# pattern keeps a single capturing group so ``re.findall`` yields one entry per
# keyword hit.
#
# Alternatives that are word *stems* end in ``\w*``. Without it the closing
# ``\b`` demands a word boundary immediately after the stem, so e.g.
# ``\b(orchestrat)\b`` can never match "orchestrate" or "orchestration".
CODE_PATTERNS = [r'\b(code|function|bug|debug|refactor|implement|compile|runtime|segfault|thread|async|class|module|python|javascript|typescript|go|rust|java)\b']
LEGAL_PATTERNS = [r'\b(contract|legal|compliance|gdpr|privacy|policy|regulatory|liability|indemnif\w*|clause)\b']
RESEARCH_PATTERNS = [r'\b(research|sources?|literature|investigate|compare|analy[sz]e|survey|paper|arxiv|find)\b']
TOOL_PATTERNS = [r'\b(search|fetch|retrieve|query|api|database|scrape|aggregate|list|download)\b']
LONG_PATTERNS = [r'\b(plan|roadmap|orchestrat\w*|migrate|pipeline|deploy|architecture|multi-step|end.to.end|entire)\b']
# NOTE(review): MATH_PATTERNS is not referenced by the classifier code visible
# in this file -- confirm whether another module consumes it.
MATH_PATTERNS = [r'\b(calculat\w*|comput\w*|solve|equation|formula|optim\w*|probability|integral|derivative)\b']
SIMPLE_PATTERNS = [r'\b(typo|simple|quick|brief|just|minor|small|easy|trivial|clarif\w*|only)\b']
CRITICAL_PATTERNS = [r'\b(critical|production|urgent|now|emergency|live|deployed|safety|security|important)\b']
DOC_PATTERNS = [r'\b(draft|write|compose|email|proposal|report|memo|letter|document|create)\b']
RETRIEVAL_PATTERNS = [r'\b(find all|search.*for|look up|based on|according to|in the document|in the file)\b']
15
 
16
# Every task-type label the classifier can emit.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]

# Starting difficulty (1 = lowest, 5 = highest) for each task type, before
# any per-request adjustment.
TASK_DIFFICULTY_BASE = {
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "long_horizon": 4,
    "legal_regulated": 5,
}

# Risk label attached to each task type.
TASK_RISK = {
    "quick_answer": "low",
    "document_drafting": "low",
    "tool_heavy": "medium",
    "retrieval_heavy": "medium",
    "research": "medium",
    "coding": "medium",
    "unknown_ambiguous": "medium",
    "long_horizon": "high",
    "legal_regulated": "critical",
}
31
 
class TaskCostClassifier:
    """Regex-driven heuristic classifier for task requests.

    ``classify`` assigns a request one of the labels in ``TASK_TYPES`` and
    derives a difficulty level, a risk label, tool / retrieval / verifier
    flags, an expected cost and an expected tier from it.
    """

    def __init__(self):
        # Expose the module-level list of task-type labels on the instance.
        self.task_types = TASK_TYPES

    def classify(self, request: str) -> Dict:
        """Return the complete prediction for ``request`` as a dictionary.

        Keys: ``task_type``, ``difficulty``, ``risk``, ``needs_tools``,
        ``needs_retrieval``, ``needs_verifier``, ``expected_cost`` and
        ``expected_tier`` (difficulty plus one, capped at 5).
        """
        kind = self._classify_type(request)
        level = self._estimate_difficulty(request, kind)
        risk_label = TASK_RISK.get(kind, "medium")
        tools = self._needs_tools(request, kind)
        retrieval = self._needs_retrieval(request, kind)
        verifier = self._needs_verifier(request, kind, risk_label)
        cost = self._estimate_cost(level, tools, retrieval, verifier)

        prediction = {}
        prediction["task_type"] = kind
        prediction["difficulty"] = level
        prediction["risk"] = risk_label
        prediction["needs_tools"] = tools
        prediction["needs_retrieval"] = retrieval
        prediction["needs_verifier"] = verifier
        prediction["expected_cost"] = cost
        prediction["expected_tier"] = min(level + 1, 5)
        return prediction

    def _classify_type(self, request: str) -> str:
        """Pick the task type whose keyword patterns hit most often.

        Returns ``"unknown_ambiguous"`` when no score is above zero. Ties go
        to the label scored first in the table below.
        """
        text = request.lower()
        # Order matters: it is the tie-break order used by ``max``.
        pattern_table = (
            ("legal_regulated", LEGAL_PATTERNS),
            ("coding", CODE_PATTERNS),
            ("research", RESEARCH_PATTERNS),
            ("tool_heavy", TOOL_PATTERNS),
            ("long_horizon", LONG_PATTERNS),
            ("retrieval_heavy", RETRIEVAL_PATTERNS),
            ("document_drafting", DOC_PATTERNS),
        )
        scores = {}
        for label, patterns in pattern_table:
            scores[label] = sum(len(re.findall(rx, text)) for rx in patterns)
        # Requests under ten words get a half-point: enough to win only when
        # no keyword pattern matched at all.
        scores["quick_answer"] = 0.5 if len(text.split()) < 10 else 0

        winner = max(scores, key=scores.get)
        if scores[winner] == 0:
            return "unknown_ambiguous"
        return winner

    def _estimate_difficulty(self, request: str, task_type: str) -> int:
        """Return the difficulty (1-5) for ``task_type``, adjusted by wording.

        Starts from ``TASK_DIFFICULTY_BASE`` (3 for unknown types), adds one
        (capped at 5) if a critical keyword appears, then subtracts one
        (floored at 1) if a simplicity keyword appears.
        """
        text = request.lower()
        level = TASK_DIFFICULTY_BASE.get(task_type, 3)
        if any(re.search(rx, text) for rx in CRITICAL_PATTERNS):
            level = min(level + 1, 5)
        if any(re.search(rx, text) for rx in SIMPLE_PATTERNS):
            level = max(level - 1, 1)
        return level

    def _needs_tools(self, request: str, task_type: str) -> bool:
        """Return True for the task types that are expected to use tools."""
        tool_using_types = ("tool_heavy", "retrieval_heavy", "coding", "research")
        return task_type in tool_using_types

    def _needs_retrieval(self, request: str, task_type: str) -> bool:
        """Return True for the task types that are expected to need retrieval."""
        retrieval_types = ("retrieval_heavy", "research")
        return task_type in retrieval_types

    def _needs_verifier(self, request: str, task_type: str, risk: str) -> bool:
        """Return True when the risk label is ``"high"`` or ``"critical"``."""
        elevated = ("high", "critical")
        return risk in elevated

    def _estimate_cost(self, difficulty: int, tools: bool, retrieval: bool, verifier: bool) -> float:
        """Return the expected cost for a difficulty level and feature flags.

        The per-level base cost (1.0 for an unlisted level) is multiplied by
        1.3 for tools, 1.2 for retrieval and 1.1 for a verifier, in that order.
        """
        cost_by_level = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        cost = cost_by_level.get(difficulty, 1.0)
        for enabled, factor in ((tools, 1.3), (retrieval, 1.2), (verifier, 1.1)):
            if enabled:
                cost *= factor
        return cost