| """Model Cascade Router: Dynamic difficulty + ML confirmation + safety floors.""" |
| import numpy as np |
| import pickle, os, json |
| from typing import Dict, Tuple, Optional |
| from dataclasses import dataclass |
|
|
@dataclass
class RoutingDecision:
    """Result of one routing call, as assembled by ``ModelCascadeRouter.route``."""

    # Identifier of the model chosen for the final tier.
    model_id: str
    # Tier number that was finally selected.
    tier: int
    # Predicted success probability for the selected tier.
    confidence: float
    # "; "-joined trace of the routing steps (base tier, escalation, downgrade).
    reasoning: str
    # Cost figure looked up for the selected tier.
    cost_estimate: float
    # Difficulty value the routing was based on (supplied or estimated).
    dynamic_difficulty: int
    # True when the router moved one tier up from the base tier.
    escalated: bool = False
    # True when the router moved one tier down from the base tier.
    downgraded: bool = False
|
|
# Keyword vocabularies. Every consumer in this file tests them with a plain
# substring check (`kw in request.lower()`), so entries are lowercase.
CODE_KW = [
    "python", "javascript", "code", "function", "bug", "debug", "refactor",
    "implement", "test", "compile", "runtime", "segfault", "thread", "async",
    "class", "module",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy", "policy",
    "regulatory", "liability", "indemnification", "clause",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "survey", "paper", "arxiv",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query", "api", "database", "scrape",
    "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
    "pipeline", "deploy", "architecture",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation", "formula", "optimize",
    "probability", "integral",
]
# A hit on these raises the estimated difficulty by one step.
CRITICAL_KW = [
    "critical", "production", "urgent", "now", "emergency", "live",
    "deployed", "safety", "security",
]
# A hit on these lowers the estimated difficulty by one step.
SIMPLE_KW = [
    "typo", "simple", "quick", "brief", "briefly", "just", "minor", "small",
    "easy", "trivial", "clarification",
]

# Task-type name -> integer index; the index is simply the position below.
TT2IDX = {
    task_name: position
    for position, task_name in enumerate((
        "quick_answer",
        "coding",
        "research",
        "document_drafting",
        "legal_regulated",
        "tool_heavy",
        "retrieval_heavy",
        "long_horizon",
        "unknown_ambiguous",
    ))
}
|
|
# Tier number -> descriptor of the model serving that tier.
# Row layout of the table below: (tier, model_id, provider, cost_per_1k).
TIER_MODELS = {
    tier_no: {"model_id": model_name, "provider": host, "cost_per_1k": unit_cost}
    for tier_no, model_name, host, unit_cost in (
        (1, "tiny-local-3b", "local", 0.0),
        (2, "cheap-cloud-8b", "cloud", 0.05),
        (3, "medium-70b", "cloud", 0.30),
        (4, "frontier-latest", "cloud", 1.00),
        (5, "specialist-expert", "cloud", 1.50),
    )
}
|
|
class ModelCascadeRouter:
    """Pick a model tier for a request from difficulty, predicted success and floors.

    Routing (see ``route``) starts from ``difficulty + 1`` capped at tier 5,
    raises that to the per-task floor, and then either escalates one tier when
    the predicted success probability is below ``safety_threshold`` or
    downgrades one tier when the cheaper tier's predicted success is at least
    ``downgrade_threshold``.

    Success probabilities come from per-tier classifiers and calibrators
    loaded from a pickle bundle; when those are unavailable a fixed heuristic
    is used instead (see ``_get_psuccess``).
    """

    def __init__(self, model_path: Optional[str] = None, safety_threshold: float = 0.30,
                 downgrade_threshold: float = 0.90,
                 task_floor: Optional[Dict[str, int]] = None,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Configure thresholds, floors and costs, then try to load the model bundle.

        Args:
            model_path: Path to a pickled model bundle. When missing or not an
                existing path, no model is loaded and the heuristic is used.
            safety_threshold: Escalate one tier when predicted success at the
                base tier is below this value.
            downgrade_threshold: Downgrade one tier when predicted success at
                the cheaper tier is at least this value.
            task_floor: Minimum tier per task type. A falsy value (``None`` or
                an empty dict) selects the built-in defaults.
            tier_costs: Cost figure per tier, reported as ``cost_estimate``. A
                falsy value selects the built-in defaults.
        """
        self.safety_threshold = safety_threshold
        self.downgrade_threshold = downgrade_threshold
        self.task_floor = task_floor or {
            "legal_regulated": 4, "long_horizon": 3, "research": 3, "coding": 3,
            "unknown_ambiguous": 3, "quick_answer": 1, "document_drafting": 2,
            "tool_heavy": 2, "retrieval_heavy": 2,
        }
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.tier_clfs = None
        self.tier_calibs = None
        self.feat_keys = None
        self._load_model(model_path)

    def _load_model(self, model_path: Optional[str] = None):
        """Load per-tier classifiers, calibrators and feature keys from a pickle.

        Reads the bundle keys ``"tier_clfs"``, ``"tier_calibrators"`` and
        ``"feat_keys"``. Tier keys are coerced to ``int``. Any failure prints a
        warning and leaves the router without a model; it never raises.
        """
        if not model_path or not os.path.exists(model_path):
            return
        try:
            # SECURITY: unpickling can execute arbitrary code. Only point
            # model_path at bundles produced by a trusted training pipeline.
            with open(model_path, "rb") as fh:  # closes the handle even on error
                bundle = pickle.load(fh)
            tier_clfs = {int(k): v for k, v in bundle.get("tier_clfs", {}).items()}
            tier_calibs = {int(k): v for k, v in bundle.get("tier_calibrators", {}).items()}
            feat_keys = bundle.get("feat_keys", None)
            # Normalise to a list: the rest of the class relies on truthiness
            # and `.index`, which e.g. a numpy array of names does not support.
            if feat_keys is not None:
                feat_keys = list(feat_keys)
        except Exception as e:
            print(f"[ACO] Warning: Could not load router model: {e}")
            return
        # Assign only after everything parsed, so a failure part-way through
        # never leaves a half-loaded model behind.
        self.tier_clfs = tier_clfs
        self.tier_calibs = tier_calibs
        self.feat_keys = feat_keys

    def estimate_difficulty(self, request: str, task_type: str) -> int:
        """Estimate a 1-5 difficulty from the task type and request keywords.

        The task type sets a base value (3 for unknown types). A CRITICAL_KW
        hit adds one (capped at 5), then a SIMPLE_KW hit subtracts one
        (floored at 1).

        NOTE(review): matching is by substring, so e.g. "now" also matches
        "know". Left unchanged because the result is fed to the model as the
        "difficulty" feature; confirm against training before tightening.
        """
        r = request.lower()
        base = {
            "quick_answer": 1, "document_drafting": 2, "tool_heavy": 2,
            "retrieval_heavy": 2, "research": 3, "coding": 3,
            "unknown_ambiguous": 3, "long_horizon": 4, "legal_regulated": 5,
        }.get(task_type, 3)
        if any(k in r for k in CRITICAL_KW):
            base = min(base + 1, 5)
        if any(k in r for k in SIMPLE_KW):
            base = max(base - 1, 1)
        return base

    def _extract_features(self, request: str, task_type: str, difficulty: int) -> np.ndarray:
        """Build the ``(1, n)`` float32 feature row for the tier classifiers.

        Features are length and word counts, keyword presence and hit counts,
        the task-type index, the difficulty, and one-hot task-type flags. They
        are ordered by ``self.feat_keys``; keys not produced here become 0.0.

        Without ``feat_keys`` a ``(1, 23)`` all-zero row is returned.
        NOTE(review): 23 equals the number of features built here, so the
        declaration order may be the intended default layout; confirm before
        changing.
        """
        r = request.lower()

        def hits(keywords) -> int:
            # Number of vocabulary entries occurring as substrings of the request.
            return sum(1 for k in keywords if k in r)

        n_code, n_legal = hits(CODE_KW), hits(LEGAL_KW)
        n_research, n_tool = hits(RESEARCH_KW), hits(TOOL_KW)
        feats = {
            "req_len": len(request), "num_words": len(request.split()),
            "has_code": int(n_code > 0), "n_code": n_code,
            "has_legal": int(n_legal > 0), "n_legal": n_legal,
            "has_research": int(n_research > 0), "n_research": n_research,
            "has_tool": int(n_tool > 0), "n_tool": n_tool,
            "has_long": int(hits(LONG_KW) > 0),
            "has_math": int(hits(MATH_KW) > 0),
            "tt_idx": TT2IDX.get(task_type, 8),
            "difficulty": difficulty,
        }
        for tt in TT2IDX:
            feats[f"tt_{tt}"] = int(task_type == tt)
        if self.feat_keys:
            row = [float(feats.get(k, 0.0)) for k in self.feat_keys]
            return np.array(row, dtype=np.float32).reshape(1, -1)
        return np.zeros((1, 23), dtype=np.float32)

    def _get_psuccess(self, x: np.ndarray, tier: int) -> float:
        """Return the predicted probability that ``tier`` succeeds on ``x``.

        Uses the tier's classifier (``predict_proba``) and calibrator
        (``transform``) when both are loaded. Otherwise, or when the model call
        fails, falls back to ``strength[tier] ** (difficulty * 0.6)``, where
        difficulty is read from ``x`` if ``feat_keys`` names a "difficulty"
        column and is 3 otherwise.
        """
        if self.tier_clfs and tier in self.tier_clfs and self.tier_calibs and tier in self.tier_calibs:
            try:
                p_raw = self.tier_clfs[tier].predict_proba(x)[0, 1]
                return float(self.tier_calibs[tier].transform([p_raw])[0])
            except Exception:
                # Deliberate best-effort: a failing model must not break
                # routing, so fall through to the heuristic. Catching
                # Exception instead of a bare `except` lets KeyboardInterrupt
                # and SystemExit propagate.
                pass

        strengths = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
        if self.feat_keys and "difficulty" in self.feat_keys:
            diff_feat = float(x[0, self.feat_keys.index("difficulty")])
        else:
            diff_feat = 3
        return strengths.get(tier, 0.80) ** (diff_feat * 0.6)

    def route(self, request: str, task_type: str, difficulty: int = None,
              prediction: dict = None) -> "RoutingDecision":
        """Choose a tier for ``request`` and describe the decision.

        Args:
            request: Raw request text.
            task_type: Task category; selects the floor tier (2 if unknown).
            difficulty: Difficulty to route on; estimated from the request
                when ``None``.
            prediction: Accepted for interface compatibility; not used here.

        Returns:
            A ``RoutingDecision`` with the chosen model, tier, predicted
            success, reasoning trace and cost estimate.
        """
        if difficulty is None:
            difficulty = self.estimate_difficulty(request, task_type)
        floor = self.task_floor.get(task_type, 2)
        base = max(min(difficulty + 1, 5), floor)
        x = self._extract_features(request, task_type, difficulty)

        tier = base
        ps = self._get_psuccess(x, tier)
        escalated = False
        downgraded = False

        # Safety: step up one tier when the base tier looks too likely to fail.
        if ps < self.safety_threshold and tier < 5:
            tier += 1
            ps = self._get_psuccess(x, tier)
            escalated = True

        # Savings: step down one tier when the cheaper one is confidently
        # sufficient. Never combined with an escalation, never below the floor.
        cheaper = tier - 1
        if not escalated and cheaper >= floor:
            pc = self._get_psuccess(x, cheaper)
            if pc >= self.downgrade_threshold:
                tier = cheaper
                ps = pc
                downgraded = True

        # Tiers outside TIER_MODELS fall back to the tier-4 model descriptor.
        model_info = TIER_MODELS.get(tier, TIER_MODELS[4])
        reasoning_parts = [f"base_tier={base}"]
        if escalated:
            reasoning_parts.append(f"escalated(P(success@{base})<{self.safety_threshold})")
        if downgraded:
            reasoning_parts.append(f"downgraded(P(success@{cheaper})>={self.downgrade_threshold})")
        return RoutingDecision(
            model_id=model_info["model_id"],
            tier=tier,
            confidence=ps,
            reasoning="; ".join(reasoning_parts),
            cost_estimate=self.tier_costs.get(tier, 1.0),
            dynamic_difficulty=difficulty,
            escalated=escalated,
            downgraded=downgraded,
        )
|
|