narcolepticchicken committed on
Commit
284d6c8
·
verified ·
1 Parent(s): c727785

Upload aco/learned_router.py

Browse files
Files changed (1) hide show
  1. aco/learned_router.py +228 -0
aco/learned_router.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Learned Router - Trainable cost-aware model router.
2
+
3
+ Trains a simple classifier on synthetic traces to predict optimal model tier.
4
+ Can be swapped into ModelCascadeRouter._route_learned().
5
+ """
6
+
7
+ import json
8
+ import pickle
9
+ from typing import Dict, List, Optional, Tuple
10
+ from dataclasses import dataclass
11
+ from collections import defaultdict
12
+
13
+
14
@dataclass
class RouterFeatures:
    """Typed record of per-request routing signals.

    The field names suggest a numeric encoding of the task type, size
    measures of the request text, keyword-presence flags, a difficulty
    estimate and history-derived statistics.  NOTE(review): the router code
    visible in this module works on plain dicts rather than on this class,
    so the exact meaning and value ranges of the fields should be confirmed
    against whatever code constructs instances.
    """

    # Request identity and size.
    task_type_encoded: int
    request_length: int
    num_words: int

    # Keyword-presence flags.
    contains_code_keywords: bool
    contains_legal_keywords: bool
    contains_research_keywords: bool
    contains_tool_keywords: bool

    # Difficulty estimate and history-derived signals.
    expected_difficulty: int
    has_prior_failures: bool
    prior_success_rate: float
26
+
27
+
28
class LearnedRouter:
    """Simple learned router that can be trained on trace data.

    Before :meth:`train` has run on at least one trace, tiers are scored
    with a fixed heuristic.  After training, each tier's score is a per-tier
    bias plus the learned weights of every feature that is present (truthy)
    for the request.

    Attributes:
        weights: Maps ``"<feature>_tier_<n>"`` to a learned weight.
        task_type_bias: Maps ``str(tier)`` to the fraction of training
            traces whose optimal tier was that tier.
        trained: Whether learned parameters are in use.
    """

    # Candidate tiers in ascending order; score ties resolve to the lowest tier
    # because predict_tier only replaces the best on a strictly higher score.
    _TIERS: Tuple[int, ...] = (1, 2, 3, 4, 5)

    # Boolean keyword features: feature name -> substrings searched for in the
    # lower-cased request text.
    _KEYWORD_FEATURES: Dict[str, Tuple[str, ...]] = {
        "has_code": ("python", "javascript", "code", "function", "bug", "debug", "refactor", "implement", "test"),
        "has_legal": ("contract", "legal", "compliance", "gdpr", "privacy", "policy", "regulatory"),
        "has_research": ("research", "find sources", "literature", "investigate", "compare", "analyze"),
        "has_tools": ("search", "fetch", "retrieve", "query", "api", "database", "scrape"),
        "has_long_horizon": ("plan", "project", "roadmap", "orchestrate", "multi-step"),
    }

    def __init__(self):
        self.weights: Dict[str, float] = {}
        self.task_type_bias: Dict[str, float] = {}
        self.trained = False

    def _extract_features(self, user_request: str, task_type: str, history: List[Dict]) -> Dict[str, float]:
        """Extract features from request and history.

        Args:
            user_request: Raw request text.
            task_type: Task type label, matched against ``"task_type"`` in
                each history entry.
            history: Earlier trace dicts; only ``"task_type"`` and
                ``"final_outcome"`` are read.

        Returns:
            A dict holding ``request_length``, ``num_words``, one boolean per
            keyword group, ``prior_success_rate`` and ``has_prior_failures``.
        """
        req_lower = user_request.lower()

        features = {
            "request_length": len(user_request),
            "num_words": len(user_request.split()),
        }
        # Plain substring matching: a keyword also matches inside longer words.
        for feat_name, keywords in self._KEYWORD_FEATURES.items():
            features[feat_name] = any(kw in req_lower for kw in keywords)

        # Historical success rate for this task type
        similar = [t for t in history if t.get("task_type") == task_type]
        if similar:
            success_count = sum(1 for t in similar if t.get("final_outcome") == "success")
            features["prior_success_rate"] = success_count / len(similar)
            # Only the five most recent matching entries count as "prior failures".
            features["has_prior_failures"] = any(t.get("final_outcome") == "failure" for t in similar[-5:])
        else:
            # No matching history: neutral defaults.
            features["prior_success_rate"] = 0.5
            features["has_prior_failures"] = False

        return features

    def _score_tier(self, features: Dict[str, float], tier: int) -> float:
        """Score a tier given features. Higher is better.

        Args:
            features: Output of :meth:`_extract_features`.
            tier: Tier to score.

        Returns:
            The heuristic score while untrained, otherwise the tier bias plus
            the weights of all truthy features.
        """
        if not self.trained:
            # Heuristic scoring before training
            # NOTE(review): base scores rise with tier and no penalty below
            # applies to tiers 4-5, so this heuristic always ranks tier 5
            # highest. Confirm whether a cost term was intended.
            base_score = {1: 0.3, 2: 0.5, 3: 0.7, 4: 0.85, 5: 0.9}.get(tier, 0.5)

            # Adjust by task complexity signals
            if features["has_legal"] and tier < 4:
                base_score -= 0.4
            if features["has_code"] and tier < 3:
                base_score -= 0.2
            if features["has_research"] and tier < 3:
                base_score -= 0.15
            if features["has_long_horizon"] and tier < 3:
                base_score -= 0.2

            # Adjust by history
            if features["has_prior_failures"] and tier < 3:
                base_score -= 0.3
            if features["prior_success_rate"] > 0.8 and tier > 2:
                base_score += 0.1

            return base_score

        # Trained scoring: every feature is treated as present/absent.
        score = self.task_type_bias.get(str(tier), 0.5)
        for feat_name, feat_val in features.items():
            weight_key = f"{feat_name}_tier_{tier}"
            score += self.weights.get(weight_key, 0.0) * (1.0 if feat_val else 0.0)
        return score

    def predict_tier(self, user_request: str, task_type: str, history: Optional[List[Dict]] = None) -> Tuple[int, float]:
        """Predict optimal model tier and confidence.

        Args:
            user_request: Raw request text.
            task_type: Task type label.
            history: Optional earlier traces; ``None`` means no history.

        Returns:
            ``(tier, confidence)`` where ``tier`` is the highest-scoring tier
            (lowest tier wins ties) and ``confidence`` is that score clamped
            to ``[0.0, 1.0]``.
        """
        history = history or []
        features = self._extract_features(user_request, task_type, history)

        best_tier = 3
        best_score = -float("inf")

        for tier in self._TIERS:
            score = self._score_tier(features, tier)
            if score > best_score:
                best_score = score
                best_tier = tier

        # Clamp on both sides: learned or loaded parameters can produce scores
        # below zero as well as above one.
        confidence = max(0.0, min(best_score, 1.0))
        return best_tier, confidence

    def train(self, traces: List[Dict]) -> None:
        """Train router on trace data using simple frequency-based weight learning.

        A trace's target tier is ``metadata["actual_tier"]`` (default 3) when
        ``final_outcome`` is ``"success"``; otherwise one tier higher, capped
        at 5.  The bias of a tier is the share of traces targeting it; the
        weight of a boolean feature for a tier is the share of traces *with
        that feature* targeting the tier, minus the tier's bias.

        Each call replaces previously learned parameters, so weights from an
        earlier training run never leak into a later one.  An empty ``traces``
        list leaves the router unchanged (and untrained if it was untrained).

        Args:
            traces: Trace dicts; the keys read are ``task_type``,
                ``final_outcome``, ``user_request`` and
                ``metadata["actual_tier"]``.
        """
        if not traces:
            print("No traces provided; router left unchanged.")
            return

        # feature name (or "_overall") -> target tier -> count
        tier_counts = defaultdict(lambda: defaultdict(int))

        for trace in traces:
            task_type = trace.get("task_type", "unknown")
            outcome = trace.get("final_outcome", "failure")
            # Tolerate an explicit None for metadata / user_request.
            metadata = trace.get("metadata") or {}
            actual_tier = metadata.get("actual_tier", 3)

            # Optimal tier is the minimum tier that would succeed
            if outcome == "success":
                optimal = actual_tier  # This tier succeeded
            else:
                optimal = min(actual_tier + 1, 5)  # Need higher tier

            req = trace.get("user_request") or ""
            # History is deliberately empty, so only request-derived boolean
            # features can be counted.
            features = self._extract_features(req, task_type, [])

            # Count target tiers per present boolean feature
            for feat_name, feat_val in features.items():
                if isinstance(feat_val, bool) and feat_val:
                    tier_counts[feat_name][optimal] += 1

            tier_counts["_overall"][optimal] += 1

        # Compute bias per tier (total > 0 because traces is non-empty).
        overall = tier_counts.pop("_overall")
        total = sum(overall.values())
        bias = {str(tier): overall.get(tier, 0) / total for tier in self._TIERS}

        # Compute feature weights into a fresh dict.
        weights: Dict[str, float] = {}
        for feat_name, tier_dist in tier_counts.items():
            total_feat = sum(tier_dist.values())
            if total_feat > 0:
                for tier in self._TIERS:
                    # Positive if this tier is more common when the feature is present
                    weights[f"{feat_name}_tier_{tier}"] = (tier_dist.get(tier, 0) / total_feat) - bias[str(tier)]

        self.task_type_bias = bias
        self.weights = weights
        self.trained = True
        print(f"Trained router on {len(traces)} traces. Weights: {len(self.weights)}.")

    def save(self, path: str) -> None:
        """Pickle the weights, biases and trained flag to ``path``."""
        with open(path, "wb") as f:
            pickle.dump({"weights": self.weights, "bias": self.task_type_bias, "trained": self.trained}, f)

    def load(self, path: str) -> None:
        """Restore state written by :meth:`save`.

        SECURITY: ``pickle.load`` can execute arbitrary code while
        deserializing.  Only load files from a trusted source.

        Raises:
            KeyError: If the file lacks ``"weights"``, ``"bias"`` or ``"trained"``.
        """
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.weights = data["weights"]
        self.task_type_bias = data["bias"]
        self.trained = data["trained"]
169
+
170
+
171
class RouterDatasetBuilder:
    """Builds training datasets from traces for learned router training."""

    # Relative cost multiplier per model tier, used to rescale observed costs.
    _TIER_COST_MULTIPLIER: Dict[int, float] = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
    # Multiplier used when a tier is missing from the table above.
    _DEFAULT_COST_MULTIPLIER: float = 0.75

    @staticmethod
    def _optimal_tier(outcome: str, actual_tier: int) -> int:
        """Return the tier that ran on success, else one tier higher capped at 5."""
        if outcome == "success":
            return actual_tier
        return min(actual_tier + 1, 5)

    @staticmethod
    def from_traces(traces: List[Dict]) -> List[Dict]:
        """Convert traces to (features, optimal_tier) training examples.

        Args:
            traces: Trace dicts; the keys read are ``user_request``,
                ``task_type``, ``final_outcome``, ``metadata["difficulty"]``
                and ``metadata["actual_tier"]``.  Missing values default to
                ``""``, ``"unknown"``, ``"failure"``, 3 and 3.

        Returns:
            One dict per trace with keys ``features``, ``optimal_tier`` and
            ``outcome``.
        """
        examples = []
        for trace in traces:
            # Tolerate an explicit None for metadata / user_request.
            metadata = trace.get("metadata") or {}
            difficulty = metadata.get("difficulty", 3)
            actual_tier = metadata.get("actual_tier", 3)
            outcome = trace.get("final_outcome", "failure")

            optimal = RouterDatasetBuilder._optimal_tier(outcome, actual_tier)

            # Simple feature extraction.
            # NOTE(review): these keyword lists are shorter than the ones in
            # LearnedRouter._extract_features in this file; confirm whether
            # the difference is intentional.
            req = (trace.get("user_request") or "").lower()
            features = {
                "length": len(req),
                "has_code": any(kw in req for kw in ["python", "code", "function", "bug", "debug"]),
                "has_legal": any(kw in req for kw in ["contract", "legal", "compliance", "gdpr"]),
                "has_research": any(kw in req for kw in ["research", "find sources", "literature"]),
                "task_type": trace.get("task_type", "unknown"),
                "difficulty": difficulty,
            }

            examples.append({"features": features, "optimal_tier": optimal, "outcome": outcome})

        return examples

    @staticmethod
    def compute_oracle_savings(traces: List[Dict]) -> Dict[str, float]:
        """Compute what an oracle router (perfect tier selection) would save.

        The oracle tier for a trace is ``max(1, difficulty - 1)``.  Each
        trace's observed ``total_cost`` is rescaled by the ratio of the
        oracle tier's multiplier to the actual tier's multiplier.

        Args:
            traces: Trace dicts; the keys read are ``total_cost``,
                ``metadata["difficulty"]`` and ``metadata["actual_tier"]``.

        Returns:
            A dict with ``total_cost``, ``oracle_cost``, ``savings`` (the
            fraction of total cost saved, 0.0 when total cost is zero) and
            ``per_trace_savings`` (absolute saving per trace, 0 when there
            are no traces).
        """
        multipliers = RouterDatasetBuilder._TIER_COST_MULTIPLIER
        fallback = RouterDatasetBuilder._DEFAULT_COST_MULTIPLIER

        total_cost = 0.0
        oracle_cost = 0.0

        for trace in traces:
            cost = trace.get("total_cost", 0.0)
            metadata = trace.get("metadata") or {}
            difficulty = metadata.get("difficulty", 3)

            # Oracle uses cheapest tier that can handle difficulty
            oracle_tier = max(1, difficulty - 1)  # Slightly optimistic
            oracle_mult = multipliers.get(oracle_tier, fallback)
            actual_mult = multipliers.get(metadata.get("actual_tier", 3), fallback)

            # Every multiplier is positive, so the ratio is always defined.
            oracle_cost += cost * (oracle_mult / actual_mult)
            total_cost += cost

        saved = total_cost - oracle_cost
        return {
            "total_cost": total_cost,
            "oracle_cost": oracle_cost,
            # Divide by the real total: clamping the denominator to 1 would
            # understate the fraction whenever total cost is below 1.
            "savings": saved / total_cost if total_cost > 0 else 0.0,
            "per_trace_savings": saved / len(traces) if traces else 0,
        }