narcolepticchicken committed on
Commit
2b522a0
·
verified ·
1 Parent(s): a7c7186

Upload aco/router.py

Browse files
Files changed (1) hide show
  1. aco/router.py +245 -0
aco/router.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model Cascade Router - Module 3.
2
+
3
+ Routes agent requests through a cascade of models:
4
+ tiny local → cheap small cloud → medium → frontier → specialist.
5
+
6
+ Supports:
7
+ A. always frontier
8
+ B. static routing
9
+ C. prompt-only router
10
+ D. trained cost-aware router
11
+ E. trained router + verifier fallback
12
+ """
13
+
14
+ import random
15
+ from typing import Dict, List, Optional, Tuple
16
+ from dataclasses import dataclass
17
+
18
+ from .trace_schema import TaskType, Outcome
19
+ from .config import ACOConfig, ModelConfig
20
+ from .classifier import TaskPrediction
21
+
22
+
23
@dataclass
class RoutingDecision:
    """Result of a routing call: the chosen model plus how it should be run."""

    # Identity of the selected model.
    model_id: str
    provider: str
    # Strength tier associated with the decision (1-5 in the router below).
    tier: int
    # Router's confidence in the choice, and a human-readable explanation.
    confidence: float
    reasoning: str
    # Model to escalate to if the selected one is not good enough, if any.
    fallback_model_id: Optional[str] = None
    # Whether the caller should run a verifier on the model output.
    use_verifier: bool = False
    # Generation parameters handed to the model call.
    max_tokens: int = 4096
    temperature: float = 0.7
34
+
35
+
36
class ModelCascadeRouter:
    """Routes tasks to the cheapest acceptable model.

    Models from ``config.models`` are indexed by their ``strength_tier``.
    ``route`` picks one of them according to the requested policy and returns
    a ``RoutingDecision``. The ``tier`` reported in a decision is always the
    ``strength_tier`` of the model that was actually selected, so that
    ``should_escalate`` and ``update_stats`` see consistent data even when the
    preferred tier has no model configured.
    """

    TIER_ORDER = [1, 2, 3, 4, 5]  # tiny → cheap → medium → frontier → specialist

    def __init__(self, config: ACOConfig):
        """Index the configured models by tier and start with empty statistics."""
        self.config = config
        self.models_by_tier: Dict[int, List[ModelConfig]] = {t: [] for t in self.TIER_ORDER}
        self._build_tier_index()
        # Keyed "<task_type.value>_tier_<tier>"; maintained by update_stats().
        self.routing_stats: Dict[str, Dict] = {}
        # No method of this class appends to this list.
        self.decision_history: List[Dict] = []

    def _build_tier_index(self) -> None:
        """Group every configured model under its ``strength_tier``."""
        for mc in self.config.models.values():
            # setdefault keeps models whose tier is outside TIER_ORDER.
            self.models_by_tier.setdefault(mc.strength_tier, []).append(mc)

    def route(self, task_prediction: TaskPrediction, routing_mode: str = "cascade") -> RoutingDecision:
        """Select model based on task prediction and routing policy.

        Recognised modes are ``always_frontier``, ``static``, ``prompt_only``,
        ``learned`` and ``learned_verifier``; any other value (including the
        default ``cascade``) uses the cascade policy.

        Raises:
            IndexError: if no models are configured at all.
        """
        if routing_mode == "always_frontier":
            return self._route_always_frontier(task_prediction)
        if routing_mode == "static":
            return self._route_static(task_prediction)
        if routing_mode == "prompt_only":
            return self._route_prompt_only(task_prediction)
        if routing_mode == "learned":
            return self._route_learned(task_prediction)
        if routing_mode == "learned_verifier":
            return self._route_learned(task_prediction, verifier_fallback=True)
        return self._route_cascade(task_prediction)

    def _first_model(self, *tiers: int) -> ModelConfig:
        """Return the first model of the first non-empty tier in ``tiers``.

        Falls back to the first configured model of any tier.

        Raises:
            IndexError: if no models are configured at all.
        """
        for tier in tiers:
            models = self.models_by_tier.get(tier)
            if models:
                return models[0]
        for model in self.config.models.values():
            return model
        raise IndexError("ModelCascadeRouter has no models configured")

    def _fallback_model_id(self, tier: int) -> Optional[str]:
        """Return the id of the next stronger model above ``tier``, or None."""
        stronger = self._next_tier_model(tier)
        return stronger.model_id if stronger is not None else None

    @staticmethod
    def _confidence_from_score(score: float) -> float:
        """Map a learned-router score to a confidence clamped to [0, 1]."""
        return max(0.0, min(score / 10 + 0.5, 1.0))

    def _route_always_frontier(self, prediction: TaskPrediction) -> RoutingDecision:
        """Always pick the strongest general model: tier 4, else 5, else 3."""
        model = self._first_model(4, 5, 3)
        budget = min(prediction.expected_cost * 50000, model.max_context)
        return RoutingDecision(
            model_id=model.model_id,
            provider=model.provider,
            tier=model.strength_tier,
            confidence=1.0,
            reasoning="Always frontier policy",
            # max_tokens is an int field; keep it a positive integer.
            max_tokens=max(1, int(budget)),
        )

    def _route_static(self, prediction: TaskPrediction) -> RoutingDecision:
        """Route by a fixed task-type -> tier table (tier 3 for unmapped types)."""
        static_map = {
            TaskType.QUICK_ANSWER: 1,
            TaskType.UNKNOWN_AMBIGUOUS: 2,
            TaskType.TOOL_HEAVY: 2,
            TaskType.RETRIEVAL_HEAVY: 2,
            TaskType.DOCUMENT_DRAFTING: 3,
            TaskType.CODING: 3,
            TaskType.RESEARCH: 4,
            TaskType.LONG_HORIZON: 4,
            TaskType.LEGAL_REGULATED: 5,
        }
        mapped_tier = static_map.get(prediction.task_type, 3)
        # Use tier 3 when the mapped tier has no model configured.
        model = self._first_model(mapped_tier, 3)
        tier = model.strength_tier
        return RoutingDecision(
            model_id=model.model_id,
            provider=model.provider,
            tier=tier,
            confidence=0.6,
            reasoning=f"Static routing: {prediction.task_type.value} -> tier {mapped_tier}",
            fallback_model_id=self._fallback_model_id(tier),
        )

    def _route_prompt_only(self, prediction: TaskPrediction) -> RoutingDecision:
        """Use prompt heuristics to decide model tier."""
        risk = prediction.risk_of_failure
        expected = prediction.expected_model_tier

        if risk > 0.7 or expected >= 4:
            # High risk or a demanding task: go straight to frontier.
            wanted_tier = 4
        elif expected <= 2:
            wanted_tier = max(expected, 1)
        else:
            # Start one tier cheaper and rely on escalation.
            wanted_tier = max(expected - 1, 1)

        model = self._first_model(wanted_tier, 3)
        tier = model.strength_tier

        # Only arrange a fallback for tasks that are more likely than not to fail.
        fallback_id = self._fallback_model_id(tier) if risk > 0.5 else None

        return RoutingDecision(
            model_id=model.model_id,
            provider=model.provider,
            tier=tier,
            confidence=1.0 - risk,
            reasoning=f"Prompt heuristic: risk={prediction.risk_of_failure:.2f}, expected_tier={prediction.expected_model_tier}",
            fallback_model_id=fallback_id,
            use_verifier=prediction.verifier_required,
        )

    def _best_learned_tier(self, prediction: TaskPrediction) -> Tuple[Optional[int], float]:
        """Score the usable tiers for this task type and return (tier, score).

        Only tiers that have at least one model are scored; if none has, every
        tier in TIER_ORDER is scored. Missing statistics default to a 0.5
        success rate and a cost of ``0.01 * tier``, so without history the
        cheapest usable tier wins unless the risk penalty applies.
        Returns ``(None, -inf)`` only when there is no tier to score.
        """
        task_key = prediction.task_type.value
        candidates = [t for t in self.TIER_ORDER if self.models_by_tier.get(t)]
        if not candidates:
            candidates = list(self.TIER_ORDER)

        best_tier: Optional[int] = None
        best_score = -float("inf")
        for tier in candidates:
            stats = self.routing_stats.get(f"{task_key}_tier_{tier}", {})
            success_rate = stats.get("success_rate", 0.5)
            avg_cost = stats.get("avg_cost", 0.01 * tier)

            # Score = success_weight * success_rate - cost_weight * cost
            score = 10 * success_rate - 100 * avg_cost

            # Penalize tiers below expected if risk is high
            if tier < prediction.expected_model_tier and prediction.risk_of_failure > 0.5:
                score -= 5

            if score > best_score:
                best_score = score
                best_tier = tier
        return best_tier, best_score

    def _route_learned(self, prediction: TaskPrediction, verifier_fallback: bool = False) -> RoutingDecision:
        """Learned router with cost-quality tradeoff.

        In a full implementation, this would load a trained classifier.
        Here we use a heuristic calibrated from routing_stats.
        """
        task_key = prediction.task_type.value
        best_tier, best_score = self._best_learned_tier(prediction)

        # Only reachable when TIER_ORDER is empty.
        if best_tier is None:
            best_tier = prediction.expected_model_tier

        model = self._first_model(best_tier, 3)
        tier = model.strength_tier

        # Verifier fallback on uncertain predictions
        use_verifier = verifier_fallback and prediction.risk_of_failure > 0.5

        return RoutingDecision(
            model_id=model.model_id,
            provider=model.provider,
            tier=tier,
            confidence=self._confidence_from_score(best_score),
            reasoning=f"Learned router: tier {best_tier} scored {best_score:.3f} for {task_key}",
            fallback_model_id=self._fallback_model_id(tier),
            use_verifier=use_verifier,
        )

    def _route_cascade(self, prediction: TaskPrediction) -> RoutingDecision:
        """FrugalGPT-style cascade: try cheap first, escalate on low confidence."""
        start_tier = max(1, prediction.expected_model_tier - 2)

        # Don't start below tier 2 for risky tasks
        if prediction.risk_of_failure > 0.6:
            start_tier = max(start_tier, 2)

        # Preferred start tier, then tier 1, then any configured model.
        model = self._first_model(start_tier, 1)
        tier = model.strength_tier

        # NOTE(review): for LEGAL_REGULATED this flag only suppresses the
        # fallback; it does not raise the starting tier. Confirm intent.
        pre_escalate = prediction.task_type == TaskType.LEGAL_REGULATED

        fallback_id = None
        if not pre_escalate and tier < prediction.expected_model_tier:
            fallback_id = self._fallback_model_id(tier)

        return RoutingDecision(
            model_id=model.model_id,
            provider=model.provider,
            tier=tier,
            confidence=1.0 - prediction.risk_of_failure,
            reasoning=f"Cascade start at tier {start_tier}, expected tier {prediction.expected_model_tier}, risk={prediction.risk_of_failure:.2f}",
            fallback_model_id=fallback_id,
            use_verifier=prediction.verifier_required,
        )

    def _next_tier_model(self, current_tier: int) -> Optional[ModelConfig]:
        """Return the first model of the nearest populated tier above ``current_tier``."""
        for tier in self.TIER_ORDER:
            if tier <= current_tier:
                continue
            models = self.models_by_tier.get(tier)
            if models:
                return models[0]
        return None

    def update_stats(self, task_type: TaskType, tier: int, cost: float, success: bool) -> None:
        """Record one outcome and refresh the running success rate and mean cost."""
        key = f"{task_type.value}_tier_{tier}"
        stats = self.routing_stats.setdefault(key, {"count": 0, "successes": 0, "total_cost": 0.0})
        stats["count"] += 1
        if success:
            stats["successes"] += 1
        stats["total_cost"] += cost
        stats["success_rate"] = stats["successes"] / stats["count"]
        stats["avg_cost"] = stats["total_cost"] / stats["count"]

    def should_escalate(self, decision: RoutingDecision, step_outcome: Outcome, confidence: float) -> bool:
        """Decide whether to escalate to a stronger model after a step.

        Never escalates from tier 5 or above. Escalates on a failure with
        confidence below 0.5, or on a partial success below tier 4.
        """
        if decision.tier >= 5:
            return False
        if step_outcome == Outcome.FAILURE and confidence < 0.5:
            return True
        if step_outcome == Outcome.PARTIAL_SUCCESS and decision.tier < 4:
            return True
        return False