narcolepticchicken commited on
Commit
a7c7186
·
verified ·
1 Parent(s): de021eb

Upload aco/classifier.py

Browse files
Files changed (1) hide show
  1. aco/classifier.py +243 -0
aco/classifier.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Task Cost Classifier - Module 2.
2
+
3
+ Classifies incoming tasks by expected cost, risk, model strength needed,
4
+ and predicts whether retrieval/verifier is required.
5
+ """
6
+
7
+ import re
8
+ from typing import Dict, List, Tuple, Optional
9
+ from dataclasses import dataclass
10
+
11
+ from .trace_schema import TaskType
12
+ from .config import ACOConfig
13
+
14
+
15
+ @dataclass
16
+ class TaskPrediction:
17
+ task_type: TaskType
18
+ expected_cost: float
19
+ expected_model_tier: int # 1-5
20
+ expected_tools_needed: List[str]
21
+ risk_of_failure: float # 0-1
22
+ retrieval_required: bool
23
+ verifier_required: bool
24
+ expected_latency_ms: float
25
+ confidence: float
26
+
27
+
28
+ class TaskCostClassifier:
29
+ """Classifies agent tasks into cost/risk categories."""
30
+
31
+ # Keywords mapped to task types with base cost estimates
32
+ KEYWORD_MAP: Dict[str, Tuple[TaskType, float, int]] = {
33
+ # quick_answer: low cost, tier 1-2
34
+ "what is": (TaskType.QUICK_ANSWER, 0.001, 1),
35
+ "define": (TaskType.QUICK_ANSWER, 0.001, 1),
36
+ "explain briefly": (TaskType.QUICK_ANSWER, 0.002, 1),
37
+ "summarize": (TaskType.QUICK_ANSWER, 0.005, 2),
38
+ "short answer": (TaskType.QUICK_ANSWER, 0.001, 1),
39
+ # coding: medium-high cost, tier 3-4
40
+ "write code": (TaskType.CODING, 0.05, 3),
41
+ "fix bug": (TaskType.CODING, 0.08, 4),
42
+ "refactor": (TaskType.CODING, 0.03, 3),
43
+ "implement": (TaskType.CODING, 0.05, 3),
44
+ "test": (TaskType.CODING, 0.04, 3),
45
+ "debug": (TaskType.CODING, 0.06, 4),
46
+ "python": (TaskType.CODING, 0.03, 3),
47
+ "javascript": (TaskType.CODING, 0.03, 3),
48
+ "function": (TaskType.CODING, 0.02, 2),
49
+ # research: high cost, tier 3-4
50
+ "research": (TaskType.RESEARCH, 0.15, 4),
51
+ "find sources": (TaskType.RESEARCH, 0.1, 3),
52
+ "literature review": (TaskType.RESEARCH, 0.2, 4),
53
+ "compare": (TaskType.RESEARCH, 0.08, 3),
54
+ "analyze": (TaskType.RESEARCH, 0.1, 3),
55
+ "investigate": (TaskType.RESEARCH, 0.12, 4),
56
+ # document_drafting: medium cost, tier 3
57
+ "draft": (TaskType.DOCUMENT_DRAFTING, 0.05, 3),
58
+ "write a document": (TaskType.DOCUMENT_DRAFTING, 0.06, 3),
59
+ "proposal": (TaskType.DOCUMENT_DRAFTING, 0.08, 3),
60
+ "report": (TaskType.DOCUMENT_DRAFTING, 0.1, 4),
61
+ "email": (TaskType.DOCUMENT_DRAFTING, 0.01, 2),
62
+ # legal_regulated: high cost, tier 4-5
63
+ "contract": (TaskType.LEGAL_REGULATED, 0.15, 5),
64
+ "legal": (TaskType.LEGAL_REGULATED, 0.15, 5),
65
+ "compliance": (TaskType.LEGAL_REGULATED, 0.12, 5),
66
+ "regulatory": (TaskType.LEGAL_REGULATED, 0.12, 5),
67
+ "privacy policy": (TaskType.LEGAL_REGULATED, 0.1, 5),
68
+ "terms of service": (TaskType.LEGAL_REGULATED, 0.1, 5),
69
+ # tool_heavy
70
+ "search for": (TaskType.TOOL_HEAVY, 0.05, 3),
71
+ "look up": (TaskType.TOOL_HEAVY, 0.03, 2),
72
+ "fetch": (TaskType.TOOL_HEAVY, 0.04, 3),
73
+ "api": (TaskType.TOOL_HEAVY, 0.06, 3),
74
+ "database": (TaskType.TOOL_HEAVY, 0.05, 3),
75
+ "scrape": (TaskType.TOOL_HEAVY, 0.04, 3),
76
+ # retrieval_heavy
77
+ "based on the document": (TaskType.RETRIEVAL_HEAVY, 0.08, 3),
78
+ "from my files": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
79
+ "rag": (TaskType.RETRIEVAL_HEAVY, 0.06, 3),
80
+ "retrieve": (TaskType.RETRIEVAL_HEAVY, 0.05, 3),
81
+ # long_horizon
82
+ "plan": (TaskType.LONG_HORIZON, 0.1, 4),
83
+ "project": (TaskType.LONG_HORIZON, 0.15, 4),
84
+ "over the next": (TaskType.LONG_HORIZON, 0.1, 4),
85
+ "multi-step": (TaskType.LONG_HORIZON, 0.08, 4),
86
+ "orchestrate": (TaskType.LONG_HORIZON, 0.12, 4),
87
+ }
88
+
89
+ # Complexity multipliers based on length and structure
90
+ COMPLEXITY_PATTERNS = [
91
+ (r"\b(AND|and)\b.*\b(AND|and)\b.*\b(AND|and)\b", 1.5), # multiple sub-tasks
92
+ (r"\bstep\s+\d+\b", 1.3),
93
+ (r"\d+\+\s*(pages|files|functions|tests)", 1.4),
94
+ (r"\b(entire|whole|all|every)\b", 1.2),
95
+ (r"\b(critical|production|live|deployed)\b", 1.5),
96
+ ]
97
+
98
+ def __init__(self, config: Optional[ACOConfig] = None):
99
+ self.config = config or ACOConfig()
100
+ self.history: List[Dict] = []
101
+
102
+ def classify(self, user_request: str) -> TaskPrediction:
103
+ """Classify a user request into task type, cost, risk, etc."""
104
+ request_lower = user_request.lower()
105
+
106
+ # Find best matching keywords
107
+ matched_types: Dict[TaskType, List[float]] = {}
108
+ for keyword, (task_type, base_cost, tier) in self.KEYWORD_MAP.items():
109
+ if keyword in request_lower:
110
+ matched_types.setdefault(task_type, []).append(base_cost)
111
+
112
+ # Default to unknown if no match
113
+ if not matched_types:
114
+ task_type = TaskType.UNKNOWN_AMBIGUOUS
115
+ base_cost = 0.05
116
+ base_tier = 2
117
+ else:
118
+ # Pick task type with highest cumulative base cost (most specific)
119
+ task_type = max(matched_types.keys(), key=lambda t: sum(matched_types[t]))
120
+ base_cost = max(matched_types[task_type])
121
+ base_tier = self.KEYWORD_MAP[
122
+ max(
123
+ (k for k, (tt, _, _) in self.KEYWORD_MAP.items() if tt == task_type),
124
+ key=lambda k: base_cost if k in request_lower else 0,
125
+ )
126
+ ][2]
127
+
128
+ # Apply complexity multipliers
129
+ complexity_mult = 1.0
130
+ for pattern, mult in self.COMPLEXITY_PATTERNS:
131
+ if re.search(pattern, user_request, re.IGNORECASE):
132
+ complexity_mult = max(complexity_mult, mult)
133
+
134
+ # Length factor
135
+ word_count = len(request_lower.split())
136
+ length_mult = 1.0 + min(word_count / 500, 0.5)
137
+
138
+ expected_cost = base_cost * complexity_mult * length_mult
139
+ expected_tier = min(base_tier + int(complexity_mult > 1.2), 5)
140
+
141
+ # Determine tool needs
142
+ expected_tools = []
143
+ if task_type in (TaskType.RESEARCH, TaskType.TOOL_HEAVY, TaskType.RETRIEVAL_HEAVY):
144
+ expected_tools = ["search", "retrieve", "fetch"]
145
+ elif task_type == TaskType.CODING:
146
+ expected_tools = ["code_execution", "linter", "test_runner"]
147
+ elif task_type == TaskType.LEGAL_REGULATED:
148
+ expected_tools = ["document_retrieval", "compliance_check"]
149
+
150
+ # Risk estimation
151
+ risk = 0.3
152
+ if task_type == TaskType.LEGAL_REGULATED:
153
+ risk = 0.8
154
+ elif task_type == TaskType.LONG_HORIZON:
155
+ risk = 0.6
156
+ elif task_type == TaskType.CODING:
157
+ risk = 0.5
158
+ elif task_type == TaskType.UNKNOWN_AMBIGUOUS:
159
+ risk = 0.7
160
+
161
+ # Adjust risk by complexity
162
+ risk = min(risk * complexity_mult, 1.0)
163
+
164
+ # Verifier required for high-risk or complex tasks
165
+ verifier_required = risk > 0.6 or task_type == TaskType.LEGAL_REGULATED
166
+
167
+ # Retrieval required for research, document, retrieval-heavy
168
+ retrieval_required = task_type in (
169
+ TaskType.RESEARCH,
170
+ TaskType.RETRIEVAL_HEAVY,
171
+ TaskType.DOCUMENT_DRAFTING,
172
+ TaskType.LEGAL_REGULATED,
173
+ )
174
+
175
+ expected_latency = expected_cost * 10000 # rough heuristic: $0.001 ~ 10s
176
+
177
+ return TaskPrediction(
178
+ task_type=task_type,
179
+ expected_cost=expected_cost,
180
+ expected_model_tier=expected_tier,
181
+ expected_tools_needed=expected_tools,
182
+ risk_of_failure=risk,
183
+ retrieval_required=retrieval_required,
184
+ verifier_required=verifier_required,
185
+ expected_latency_ms=expected_latency,
186
+ confidence=0.7 if matched_types else 0.4,
187
+ )
188
+
189
+ def classify_with_history(self, user_request: str, past_traces: List[Dict]) -> TaskPrediction:
190
+ """Classify using historical trace data for this user/task pattern."""
191
+ base = self.classify(user_request)
192
+
193
+ if not past_traces:
194
+ return base
195
+
196
+ # Find similar past requests
197
+ similar = [
198
+ t for t in past_traces
199
+ if self._similarity(user_request, t.get("user_request", "")) > 0.5
200
+ ]
201
+
202
+ if len(similar) >= 3:
203
+ # Adjust predictions based on history
204
+ avg_cost = sum(t.get("total_cost", base.expected_cost) for t in similar) / len(similar)
205
+ success_rate = sum(1 for t in similar if t.get("final_outcome") == "success") / len(similar)
206
+ avg_retries = sum(t.get("total_retries", 0) for t in similar) / len(similar)
207
+
208
+ # If history shows high failure, bump tier and require verifier
209
+ if success_rate < 0.5:
210
+ base = TaskPrediction(
211
+ task_type=base.task_type,
212
+ expected_cost=avg_cost * 1.2,
213
+ expected_model_tier=min(base.expected_model_tier + 1, 5),
214
+ expected_tools_needed=base.expected_tools_needed,
215
+ risk_of_failure=min(base.risk_of_failure * 1.3, 1.0),
216
+ retrieval_required=True,
217
+ verifier_required=True,
218
+ expected_latency_ms=base.expected_latency_ms * 1.2,
219
+ confidence=min(base.confidence + 0.1, 1.0),
220
+ )
221
+ else:
222
+ base = TaskPrediction(
223
+ task_type=base.task_type,
224
+ expected_cost=avg_cost * 0.9, # history suggests we can be cheaper
225
+ expected_model_tier=max(base.expected_model_tier - 1, 1),
226
+ expected_tools_needed=base.expected_tools_needed,
227
+ risk_of_failure=base.risk_of_failure * 0.8,
228
+ retrieval_required=base.retrieval_required,
229
+ verifier_required=base.verifier_required and avg_retries > 1,
230
+ expected_latency_ms=base.expected_latency_ms * 0.9,
231
+ confidence=min(base.confidence + 0.2, 1.0),
232
+ )
233
+
234
+ return base
235
+
236
+ @staticmethod
237
+ def _similarity(a: str, b: str) -> float:
238
+ """Simple Jaccard similarity on words."""
239
+ words_a = set(a.lower().split())
240
+ words_b = set(b.lower().split())
241
+ if not words_a or not words_b:
242
+ return 0.0
243
+ return len(words_a & words_b) / len(words_a | words_b)