narcolepticchicken committed on
Commit
e6100b5
·
verified ·
1 Parent(s): 1b0e9a1

Upload aco/router.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/router.py +132 -231
aco/router.py CHANGED
@@ -1,245 +1,146 @@
1
- """Model Cascade Router - Module 3.
2
-
3
- Routes agent requests through a cascade of models:
4
- tiny local cheap small cloud → medium → frontier → specialist.
5
-
6
- Supports:
7
- A. always frontier
8
- B. static routing
9
- C. prompt-only router
10
- D. trained cost-aware router
11
- E. trained router + verifier fallback
12
- """
13
-
14
- import random
15
- from typing import Dict, List, Optional, Tuple
16
  from dataclasses import dataclass
17
 
18
- from .trace_schema import TaskType, Outcome
19
- from .config import ACOConfig, ModelConfig
20
- from .classifier import TaskPrediction
21
-
22
-
23
  @dataclass
24
  class RoutingDecision:
25
  model_id: str
26
- provider: str
27
  tier: int
28
  confidence: float
29
  reasoning: str
30
- fallback_model_id: Optional[str] = None
31
- use_verifier: bool = False
32
- max_tokens: int = 4096
33
- temperature: float = 0.7
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  class ModelCascadeRouter:
37
- """Routes tasks to the cheapest acceptable model."""
38
-
39
- TIER_ORDER = [1, 2, 3, 4, 5] # tiny → cheap → medium → frontier → specialist
40
-
41
- def __init__(self, config: ACOConfig):
42
- self.config = config
43
- self.models_by_tier: Dict[int, List[ModelConfig]] = {t: [] for t in self.TIER_ORDER}
44
- self._build_tier_index()
45
- self.routing_stats: Dict[str, Dict] = {}
46
- self.decision_history: List[Dict] = []
47
-
48
- def _build_tier_index(self):
49
- for name, mc in self.config.models.items():
50
- self.models_by_tier.setdefault(mc.strength_tier, []).append(mc)
51
-
52
- def route(self, task_prediction: TaskPrediction, routing_mode: str = "cascade") -> RoutingDecision:
53
- """Select model based on task prediction and routing policy."""
54
-
55
- if routing_mode == "always_frontier":
56
- return self._route_always_frontier(task_prediction)
57
- elif routing_mode == "static":
58
- return self._route_static(task_prediction)
59
- elif routing_mode == "prompt_only":
60
- return self._route_prompt_only(task_prediction)
61
- elif routing_mode == "learned":
62
- return self._route_learned(task_prediction)
63
- elif routing_mode == "learned_verifier":
64
- return self._route_learned(task_prediction, verifier_fallback=True)
65
- else:
66
- return self._route_cascade(task_prediction)
67
-
68
- def _route_always_frontier(self, prediction: TaskPrediction) -> RoutingDecision:
69
- frontier = self.models_by_tier.get(4, [])
70
- if not frontier:
71
- frontier = self.models_by_tier.get(5, [])
72
- if not frontier:
73
- frontier = self.models_by_tier.get(3, [])
74
- model = frontier[0] if frontier else list(self.config.models.values())[0]
75
- return RoutingDecision(
76
- model_id=model.model_id,
77
- provider=model.provider,
78
- tier=4,
79
- confidence=1.0,
80
- reasoning="Always frontier policy",
81
- max_tokens=min(prediction.expected_cost * 50000, model.max_context),
82
- )
83
-
84
- def _route_static(self, prediction: TaskPrediction) -> RoutingDecision:
85
- # Static mapping: task type -> tier
86
- static_map = {
87
- TaskType.QUICK_ANSWER: 1,
88
- TaskType.UNKNOWN_AMBIGUOUS: 2,
89
- TaskType.TOOL_HEAVY: 2,
90
- TaskType.RETRIEVAL_HEAVY: 2,
91
- TaskType.DOCUMENT_DRAFTING: 3,
92
- TaskType.CODING: 3,
93
- TaskType.RESEARCH: 4,
94
- TaskType.LONG_HORIZON: 4,
95
- TaskType.LEGAL_REGULATED: 5,
96
  }
97
- tier = static_map.get(prediction.task_type, 3)
98
- models = self.models_by_tier.get(tier, self.models_by_tier[3])
99
- model = models[0] if models else list(self.config.models.values())[0]
100
- return RoutingDecision(
101
- model_id=model.model_id,
102
- provider=model.provider,
103
- tier=tier,
104
- confidence=0.6,
105
- reasoning=f"Static routing: {prediction.task_type.value} -> tier {tier}",
106
- fallback_model_id=self._next_tier_model(tier).model_id if tier < 5 else None,
107
- )
108
-
109
- def _route_prompt_only(self, prediction: TaskPrediction) -> RoutingDecision:
110
- """Use prompt heuristics to decide model tier."""
111
- # Heuristic: if risk > 0.7 or expected tier >= 4, use frontier
112
- if prediction.risk_of_failure > 0.7 or prediction.expected_model_tier >= 4:
113
- tier = 4
114
- elif prediction.expected_model_tier <= 2:
115
- tier = max(prediction.expected_model_tier, 1)
116
- else:
117
- # Start cheap, escalate on low confidence
118
- tier = max(prediction.expected_model_tier - 1, 1)
119
-
120
- models = self.models_by_tier.get(tier, self.models_by_tier[3])
121
- model = models[0] if models else list(self.config.models.values())[0]
122
-
123
- fallback = None
124
- if tier < 5 and prediction.risk_of_failure > 0.5:
125
- fallback = self._next_tier_model(tier)
126
- fallback_id = fallback.model_id if fallback else None
127
- else:
128
- fallback_id = None
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  return RoutingDecision(
131
- model_id=model.model_id,
132
- provider=model.provider,
133
  tier=tier,
134
- confidence=1.0 - prediction.risk_of_failure,
135
- reasoning=f"Prompt heuristic: risk={prediction.risk_of_failure:.2f}, expected_tier={prediction.expected_model_tier}",
136
- fallback_model_id=fallback_id,
137
- use_verifier=prediction.verifier_required,
138
- )
139
-
140
- def _route_learned(self, prediction: TaskPrediction, verifier_fallback: bool = False) -> RoutingDecision:
141
- """Learned router with cost-quality tradeoff.
142
-
143
- In a full implementation, this would load a trained classifier.
144
- Here we use a heuristic calibrated from routing_stats.
145
- """
146
- # Check historical success rate per tier for this task type
147
- task_key = prediction.task_type.value
148
- best_tier = None
149
- best_score = -float("inf")
150
-
151
- for tier in self.TIER_ORDER:
152
- stats = self.routing_stats.get(f"{task_key}_tier_{tier}", {})
153
- success_rate = stats.get("success_rate", 0.5)
154
- avg_cost = stats.get("avg_cost", 0.01 * tier)
155
-
156
- # Score = success_weight * success_rate - cost_weight * cost
157
- score = 10 * success_rate - 100 * avg_cost
158
-
159
- # Penalize tiers below expected if risk is high
160
- if tier < prediction.expected_model_tier and prediction.risk_of_failure > 0.5:
161
- score -= 5
162
-
163
- if score > best_score:
164
- best_score = score
165
- best_tier = tier
166
-
167
- # Default to expected tier if no history
168
- if best_tier is None:
169
- best_tier = prediction.expected_model_tier
170
-
171
- models = self.models_by_tier.get(best_tier, self.models_by_tier[3])
172
- model = models[0] if models else list(self.config.models.values())[0]
173
-
174
- # Verifier fallback on uncertain predictions
175
- use_verifier = verifier_fallback and prediction.risk_of_failure > 0.5
176
-
177
- return RoutingDecision(
178
- model_id=model.model_id,
179
- provider=model.provider,
180
- tier=best_tier,
181
- confidence=min(best_score / 10 + 0.5, 1.0),
182
- reasoning=f"Learned router: tier {best_tier} scored {best_score:.3f} for {task_key}",
183
- fallback_model_id=self._next_tier_model(best_tier).model_id if best_tier < 5 else None,
184
- use_verifier=use_verifier,
185
- )
186
-
187
- def _route_cascade(self, prediction: TaskPrediction) -> RoutingDecision:
188
- """FrugalGPT-style cascade: try cheap first, escalate on low confidence."""
189
- start_tier = max(1, prediction.expected_model_tier - 2)
190
-
191
- # Don't start below tier 2 for risky tasks
192
- if prediction.risk_of_failure > 0.6:
193
- start_tier = max(start_tier, 2)
194
-
195
- models = self.models_by_tier.get(start_tier, [])
196
- if not models:
197
- models = self.models_by_tier.get(1, [])
198
- if not models:
199
- models = list(self.config.models.values())
200
-
201
- model = models[0]
202
-
203
- # Determine if we should pre-escalate (for critical tasks)
204
- pre_escalate = prediction.task_type == TaskType.LEGAL_REGULATED
205
-
206
- fallback = None
207
- if not pre_escalate and start_tier < prediction.expected_model_tier:
208
- fallback = self._next_tier_model(start_tier)
209
-
210
- return RoutingDecision(
211
- model_id=model.model_id,
212
- provider=model.provider,
213
- tier=start_tier,
214
- confidence=1.0 - prediction.risk_of_failure,
215
- reasoning=f"Cascade start at tier {start_tier}, expected tier {prediction.expected_model_tier}, risk={prediction.risk_of_failure:.2f}",
216
- fallback_model_id=fallback.model_id if fallback else None,
217
- use_verifier=prediction.verifier_required,
218
  )
219
-
220
- def _next_tier_model(self, current_tier: int) -> Optional[ModelConfig]:
221
- for tier in range(current_tier + 1, 6):
222
- models = self.models_by_tier.get(tier)
223
- if models:
224
- return models[0]
225
- return None
226
-
227
- def update_stats(self, task_type: TaskType, tier: int, cost: float, success: bool) -> None:
228
- key = f"{task_type.value}_tier_{tier}"
229
- stats = self.routing_stats.setdefault(key, {"count": 0, "successes": 0, "total_cost": 0.0})
230
- stats["count"] += 1
231
- if success:
232
- stats["successes"] += 1
233
- stats["total_cost"] += cost
234
- stats["success_rate"] = stats["successes"] / stats["count"]
235
- stats["avg_cost"] = stats["total_cost"] / stats["count"]
236
-
237
- def should_escalate(self, decision: RoutingDecision, step_outcome: Outcome, confidence: float) -> bool:
238
- """Decide whether to escalate to a stronger model after a step."""
239
- if decision.tier >= 5:
240
- return False
241
- if step_outcome == Outcome.FAILURE and confidence < 0.5:
242
- return True
243
- if step_outcome == Outcome.PARTIAL_SUCCESS and decision.tier < 4:
244
- return True
245
- return False
 
1
+ """Model Cascade Router: Dynamic difficulty + ML confirmation + safety floors."""
2
+ import numpy as np
3
+ import pickle, os, json
4
+ from typing import Dict, Tuple, Optional
 
 
 
 
 
 
 
 
 
 
 
5
  from dataclasses import dataclass
6
 
 
 
 
 
 
7
  @dataclass
8
  class RoutingDecision:
9
  model_id: str
 
10
  tier: int
11
  confidence: float
12
  reasoning: str
13
+ cost_estimate: float
14
+ dynamic_difficulty: int
15
+ escalated: bool = False
16
+ downgraded: bool = False
17
+
18
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor","implement","test",
19
+ "compile","runtime","segfault","thread","async","class","module"]
20
+ LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability","indemnification","clause"]
21
+ RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey","paper","arxiv"]
22
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
23
+ LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy","architecture"]
24
+ MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability","integral"]
25
+ CRITICAL_KW = ["critical","production","urgent","now","emergency","live","deployed","safety","security"]
26
+ SIMPLE_KW = ["typo","simple","quick","brief","briefly","just","minor","small","easy","trivial","clarification"]
27
+ TT2IDX = {"quick_answer":0,"coding":1,"research":2,"document_drafting":3,
28
+ "legal_regulated":4,"tool_heavy":5,"retrieval_heavy":6,"long_horizon":7,"unknown_ambiguous":8}
29
+
30
+ TIER_MODELS = {
31
+ 1: {"model_id": "tiny-local-3b", "provider": "local", "cost_per_1k": 0.0},
32
+ 2: {"model_id": "cheap-cloud-8b", "provider": "cloud", "cost_per_1k": 0.05},
33
+ 3: {"model_id": "medium-70b", "provider": "cloud", "cost_per_1k": 0.30},
34
+ 4: {"model_id": "frontier-latest", "provider": "cloud", "cost_per_1k": 1.00},
35
+ 5: {"model_id": "specialist-expert", "provider": "cloud", "cost_per_1k": 1.50},
36
+ }
37
 
38
  class ModelCascadeRouter:
39
+ def __init__(self, model_path: str = None, safety_threshold: float = 0.30,
40
+ downgrade_threshold: float = 0.90,
41
+ task_floor: Dict[str,int] = None,
42
+ tier_costs: Dict[int,float] = None):
43
+ self.safety_threshold = safety_threshold
44
+ self.downgrade_threshold = downgrade_threshold
45
+ self.task_floor = task_floor or {
46
+ "legal_regulated":4,"long_horizon":3,"research":3,"coding":3,
47
+ "unknown_ambiguous":3,"quick_answer":1,"document_drafting":2,
48
+ "tool_heavy":2,"retrieval_heavy":2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  }
50
+ self.tier_costs = tier_costs or {1:0.05,2:0.15,3:0.75,4:1.0,5:1.5}
51
+ self.tier_clfs = None
52
+ self.tier_calibs = None
53
+ self.feat_keys = None
54
+ self._load_model(model_path)
55
+
56
+ def _load_model(self, model_path: str = None):
57
+ if model_path and os.path.exists(model_path):
58
+ try:
59
+ bundle = pickle.load(open(model_path, "rb"))
60
+ self.tier_clfs = {int(k):v for k,v in bundle.get("tier_clfs",{}).items()}
61
+ self.tier_calibs = {int(k):v for k,v in bundle.get("tier_calibrators",{}).items()}
62
+ self.feat_keys = bundle.get("feat_keys", None)
63
+ except Exception as e:
64
+ print(f"[ACO] Warning: Could not load router model: {e}")
65
+
66
+ def estimate_difficulty(self, request: str, task_type: str) -> int:
67
+ r = request.lower()
68
+ base = {"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
69
+ "research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5}.get(task_type,3)
70
+ if any(k in r for k in CRITICAL_KW): base = min(base + 1, 5)
71
+ if any(k in r for k in SIMPLE_KW): base = max(base - 1, 1)
72
+ return base
73
+
74
+ def _extract_features(self, request: str, task_type: str, difficulty: int) -> np.ndarray:
75
+ r = request.lower()
76
+ feats = {
77
+ "req_len": len(request), "num_words": len(request.split()),
78
+ "has_code": int(any(k in r for k in CODE_KW)),
79
+ "n_code": sum(1 for k in CODE_KW if k in r),
80
+ "has_legal": int(any(k in r for k in LEGAL_KW)),
81
+ "n_legal": sum(1 for k in LEGAL_KW if k in r),
82
+ "has_research": int(any(k in r for k in RESEARCH_KW)),
83
+ "n_research": sum(1 for k in RESEARCH_KW if k in r),
84
+ "has_tool": int(any(k in r for k in TOOL_KW)),
85
+ "n_tool": sum(1 for k in TOOL_KW if k in r),
86
+ "has_long": int(any(k in r for k in LONG_KW)),
87
+ "has_math": int(any(k in r for k in MATH_KW)),
88
+ "tt_idx": TT2IDX.get(task_type, 8),
89
+ "difficulty": difficulty,
90
+ }
91
+ for tt in TT2IDX:
92
+ feats[f"tt_{tt}"] = int(task_type == tt)
93
+ if self.feat_keys:
94
+ return np.array([float(feats.get(k, 0.0)) for k in self.feat_keys], dtype=np.float32).reshape(1, -1)
95
+ return np.zeros((1, 23), dtype=np.float32)
96
+
97
+ def _get_psuccess(self, x: np.ndarray, tier: int) -> float:
98
+ if self.tier_clfs and tier in self.tier_clfs and self.tier_calibs and tier in self.tier_calibs:
99
+ try:
100
+ p_raw = self.tier_clfs[tier].predict_proba(x)[0, 1]
101
+ return float(self.tier_calibs[tier].transform([p_raw])[0])
102
+ except: pass
103
+ # Fallback heuristic probability
104
+ strengths = {1:0.35,2:0.55,3:0.80,4:0.93,5:0.97}
105
+ diff_feat = float(x[0, self.feat_keys.index("difficulty")]) if self.feat_keys and "difficulty" in self.feat_keys else 3
106
+ return strengths.get(tier, 0.80) ** (diff_feat * 0.6)
107
+
108
+ def route(self, request: str, task_type: str, difficulty: int = None,
109
+ prediction: dict = None) -> RoutingDecision:
110
+ if difficulty is None:
111
+ difficulty = self.estimate_difficulty(request, task_type)
112
+ base = min(difficulty + 1, 5)
113
+ floor = self.task_floor.get(task_type, 2)
114
+ base = max(base, floor)
115
+ x = self._extract_features(request, task_type, difficulty)
116
+ tier = base
117
+ ps = self._get_psuccess(x, tier)
118
+ escalated = False
119
+ downgraded = False
120
+ # Safety net
121
+ if ps < self.safety_threshold and tier < 5:
122
+ tier += 1
123
+ ps = self._get_psuccess(x, tier)
124
+ escalated = True
125
+ # Cost saver
126
+ if tier > floor and not escalated and tier == base:
127
+ cheaper = tier - 1
128
+ pc = self._get_psuccess(x, cheaper)
129
+ if pc >= self.downgrade_threshold and cheaper >= floor:
130
+ tier = cheaper
131
+ ps = pc
132
+ downgraded = True
133
+ model_info = TIER_MODELS.get(tier, TIER_MODELS[4])
134
+ reasoning_parts = [f"base_tier={base}"]
135
+ if escalated: reasoning_parts.append(f"escalated(P(success@{base})<{self.safety_threshold})")
136
+ if downgraded: reasoning_parts.append(f"downgraded(P(success@{cheaper})>={self.downgrade_threshold})")
137
  return RoutingDecision(
138
+ model_id=model_info["model_id"],
 
139
  tier=tier,
140
+ confidence=ps,
141
+ reasoning="; ".join(reasoning_parts),
142
+ cost_estimate=self.tier_costs.get(tier, 1.0),
143
+ dynamic_difficulty=difficulty,
144
+ escalated=escalated,
145
+ downgraded=downgraded,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )