"""Model Cascade Router: Dynamic difficulty + ML confirmation + safety floors."""
import json
import logging
import os
import pickle
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class RoutingDecision:
    """Outcome of a single ``ModelCascadeRouter.route`` call."""

    model_id: str  # model identifier looked up in TIER_MODELS for the final tier
    tier: int  # final tier after safety escalation / cost downgrade
    confidence: float  # estimated P(success) at the final tier
    reasoning: str  # "; "-joined trace of the routing decisions taken
    cost_estimate: float  # router.tier_costs entry for the final tier (1.0 if absent)
    dynamic_difficulty: int  # difficulty used (estimated or caller-supplied)
    escalated: bool = False  # True when the safety net bumped the tier up by one
    downgraded: bool = False  # True when the cost saver dropped the tier by one


# Keyword tables used for both difficulty estimation and ML features.
# NOTE(review): matching is plain substring containment, so e.g. "now" also
# hits "know", "plan" hits "explain" and "test" hits "latest". Left as-is
# because any trained router bundle presumably used the same extraction.
CODE_KW = ["python", "javascript", "code", "function", "bug", "debug", "refactor", "implement", "test",
           "compile", "runtime", "segfault", "thread", "async", "class", "module"]
LEGAL_KW = ["contract", "legal", "compliance", "gdpr", "privacy", "policy", "regulatory", "liability",
            "indemnification", "clause"]
RESEARCH_KW = ["research", "find sources", "literature", "investigate", "compare", "analyze", "survey",
               "paper", "arxiv"]
TOOL_KW = ["search", "fetch", "retrieve", "query", "api", "database", "scrape", "aggregate"]
LONG_KW = ["plan", "project", "roadmap", "orchestrate", "multi-step", "migrate", "pipeline", "deploy",
           "architecture"]
MATH_KW = ["calculate", "compute", "solve", "equation", "formula", "optimize", "probability", "integral"]
CRITICAL_KW = ["critical", "production", "urgent", "now", "emergency", "live", "deployed", "safety",
               "security"]
SIMPLE_KW = ["typo", "simple", "quick", "brief", "briefly", "just", "minor", "small", "easy", "trivial",
             "clarification"]

# Task type -> integer index ("tt_idx" feature); unknown task types map to 8.
TT2IDX = {"quick_answer": 0, "coding": 1, "research": 2, "document_drafting": 3,
          "legal_regulated": 4, "tool_heavy": 5, "retrieval_heavy": 6, "long_horizon": 7,
          "unknown_ambiguous": 8}

# Tier -> model served at that tier (route falls back to tier 4 for unknown tiers).
TIER_MODELS = {
    1: {"model_id": "tiny-local-3b", "provider": "local", "cost_per_1k": 0.0},
    2: {"model_id": "cheap-cloud-8b", "provider": "cloud", "cost_per_1k": 0.05},
    3: {"model_id": "medium-70b", "provider": "cloud", "cost_per_1k": 0.30},
    4: {"model_id": "frontier-latest", "provider": "cloud", "cost_per_1k": 1.00},
    5: {"model_id": "specialist-expert", "provider": "cloud", "cost_per_1k": 1.50},
}

# Natural order of the features built by ``_extract_features``. Used when no
# bundle is loaded or the bundle carries no "feat_keys", so the feature
# vector is never an all-zeros placeholder (23 columns, as before).
DEFAULT_FEAT_KEYS: List[str] = [
    "req_len", "num_words",
    "has_code", "n_code", "has_legal", "n_legal",
    "has_research", "n_research", "has_tool", "n_tool",
    "has_long", "has_math", "tt_idx", "difficulty",
] + [f"tt_{tt}" for tt in TT2IDX]

# Starting difficulty per task type for ``estimate_difficulty`` (unknown -> 3).
_BASE_DIFFICULTY = {"quick_answer": 1, "document_drafting": 2, "tool_heavy": 2, "retrieval_heavy": 2,
                    "research": 3, "coding": 3, "unknown_ambiguous": 3, "long_horizon": 4,
                    "legal_regulated": 5}

# Heuristic per-tier "strength" used when no calibrated classifier answers:
# P(success) = strength ** (0.6 * difficulty). Unknown tiers use 0.80.
_FALLBACK_TIER_STRENGTH = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}


def _count_hits(text: str, keywords: List[str]) -> int:
    """Return how many entries of *keywords* occur as substrings of *text*."""
    return sum(1 for k in keywords if k in text)


class ModelCascadeRouter:
    """Pick a model tier for a request.

    The base tier is derived from a difficulty estimate and raised to the
    task-type floor. A per-tier success probability (from an optional
    pickled classifier + calibrator bundle, else a heuristic) then drives a
    one-step safety escalation or a one-step cost downgrade.
    """

    def __init__(self, model_path: Optional[str] = None, safety_threshold: float = 0.30,
                 downgrade_threshold: float = 0.90, task_floor: Optional[Dict[str, int]] = None,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Configure thresholds/floors/costs and optionally load a router bundle.

        Args:
            model_path: path to a pickled bundle with "tier_clfs",
                "tier_calibrators" and "feat_keys"; ignored when falsy.
            safety_threshold: escalate one tier when P(success) is below this.
            downgrade_threshold: drop one tier when the cheaper tier's
                P(success) is at least this.
            task_floor: minimum tier per task type (unknown types -> 2).
                A falsy value selects the built-in defaults.
            tier_costs: cost estimate per tier; a falsy value selects defaults.
        """
        self.safety_threshold = safety_threshold
        self.downgrade_threshold = downgrade_threshold
        self.task_floor = task_floor or {
            "legal_regulated": 4, "long_horizon": 3, "research": 3, "coding": 3,
            "unknown_ambiguous": 3, "quick_answer": 1, "document_drafting": 2,
            "tool_heavy": 2, "retrieval_heavy": 2,
        }
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.tier_clfs: Optional[Dict[int, object]] = None
        self.tier_calibs: Optional[Dict[int, object]] = None
        self.feat_keys: Optional[List[str]] = None
        self._load_model(model_path)

    def _load_model(self, model_path: Optional[str] = None) -> None:
        """Load the classifier/calibrator bundle; on any failure keep heuristics.

        The bundle is applied only if it is read completely, so a failure never
        leaves the router half-configured. Problems are logged, not raised.
        """
        if not model_path:
            return
        if not os.path.exists(model_path):
            logger.warning("[ACO] Warning: router model not found at %s; using heuristic routing",
                           model_path)
            return
        try:
            # SECURITY: pickle.load can execute arbitrary code -- only load
            # bundles from trusted locations.
            with open(model_path, "rb") as fh:
                bundle = pickle.load(fh)
            tier_clfs = {int(k): v for k, v in bundle.get("tier_clfs", {}).items()}
            tier_calibs = {int(k): v for k, v in bundle.get("tier_calibrators", {}).items()}
            raw_keys = bundle.get("feat_keys", None)
            # Normalize to a list so truthiness and .index() work even if the
            # bundle stored a tuple or a numpy array.
            feat_keys = list(raw_keys) if raw_keys is not None else None
        except Exception as e:  # best-effort load boundary: fall back to heuristics
            logger.warning("[ACO] Warning: Could not load router model: %s", e)
            return
        self.tier_clfs = tier_clfs
        self.tier_calibs = tier_calibs
        self.feat_keys = feat_keys

    def estimate_difficulty(self, request: str, task_type: str) -> int:
        """Return a 1..5 difficulty from the task type plus keyword nudges.

        Starts from the per-task base (unknown types -> 3), adds 1 (capped
        at 5) if any CRITICAL_KW occurs, then subtracts 1 (floored at 1) if
        any SIMPLE_KW occurs.
        """
        r = request.lower()
        base = _BASE_DIFFICULTY.get(task_type, 3)
        if any(k in r for k in CRITICAL_KW):
            base = min(base + 1, 5)
        if any(k in r for k in SIMPLE_KW):
            base = max(base - 1, 1)
        return base

    def _extract_features(self, request: str, task_type: str, difficulty: int) -> np.ndarray:
        """Build a (1, n_features) float32 feature row for the tier classifiers.

        Column order follows ``self.feat_keys`` when the loaded bundle provides
        it (unknown keys -> 0.0), otherwise ``DEFAULT_FEAT_KEYS``.
        """
        r = request.lower()
        n_code = _count_hits(r, CODE_KW)
        n_legal = _count_hits(r, LEGAL_KW)
        n_research = _count_hits(r, RESEARCH_KW)
        n_tool = _count_hits(r, TOOL_KW)
        feats = {
            "req_len": len(request),
            "num_words": len(request.split()),
            "has_code": int(n_code > 0),
            "n_code": n_code,
            "has_legal": int(n_legal > 0),
            "n_legal": n_legal,
            "has_research": int(n_research > 0),
            "n_research": n_research,
            "has_tool": int(n_tool > 0),
            "n_tool": n_tool,
            "has_long": int(_count_hits(r, LONG_KW) > 0),
            "has_math": int(_count_hits(r, MATH_KW) > 0),
            "tt_idx": TT2IDX.get(task_type, 8),
            "difficulty": difficulty,
        }
        for tt in TT2IDX:  # one-hot task type
            feats[f"tt_{tt}"] = int(task_type == tt)
        keys = self.feat_keys or DEFAULT_FEAT_KEYS
        return np.array([float(feats.get(k, 0.0)) for k in keys], dtype=np.float32).reshape(1, -1)

    def _get_psuccess(self, x: np.ndarray, tier: int) -> float:
        """Estimate P(success) of *tier* on the feature row *x*.

        Uses the tier's classifier + calibrator when both are loaded; if they
        are missing or fail, falls back to
        ``strength[tier] ** (0.6 * difficulty)``, where difficulty is read
        from *x* (3 if the active feature order has no "difficulty" column).
        """
        clf = (self.tier_clfs or {}).get(tier)
        calib = (self.tier_calibs or {}).get(tier)
        if clf is not None and calib is not None:
            try:
                p_raw = clf.predict_proba(x)[0, 1]
                return float(calib.transform([p_raw])[0])
            except Exception as e:  # a broken model must not break routing
                logger.warning("[ACO] Warning: tier %s model failed (%s); using heuristic", tier, e)
        keys = self.feat_keys or DEFAULT_FEAT_KEYS
        if "difficulty" in keys and x.shape[-1] == len(keys):
            diff_feat = float(x[0, keys.index("difficulty")])
        else:
            diff_feat = 3.0
        return _FALLBACK_TIER_STRENGTH.get(tier, 0.80) ** (diff_feat * 0.6)

    def route(self, request: str, task_type: str, difficulty: Optional[int] = None,
              prediction: Optional[dict] = None) -> RoutingDecision:
        """Choose the model tier for *request*.

        Steps: base tier = min(difficulty + 1, 5) raised to the task floor;
        if P(success) at base is below ``safety_threshold`` (and base < 5) the
        tier is escalated by one; otherwise, if base is above the floor and
        the next cheaper tier reaches ``downgrade_threshold``, it is used.

        Args:
            request: raw request text.
            task_type: task category (keys of TT2IDX / task_floor).
            difficulty: optional override; estimated from the text when None.
            prediction: accepted for interface compatibility; not used here.
        """
        if difficulty is None:
            difficulty = self.estimate_difficulty(request, task_type)
        floor = self.task_floor.get(task_type, 2)
        base = max(min(difficulty + 1, 5), floor)
        x = self._extract_features(request, task_type, difficulty)

        tier = base
        ps = self._get_psuccess(x, tier)
        escalated = False
        downgraded = False
        reasoning_parts = [f"base_tier={base}"]

        if ps < self.safety_threshold and tier < 5:
            # Safety net: one step up when the base tier looks too weak.
            tier += 1
            ps = self._get_psuccess(x, tier)
            escalated = True
            reasoning_parts.append(f"escalated(P(success@{base})<{self.safety_threshold})")
        elif tier > floor:
            # Cost saver: one step down if the cheaper tier is confident enough
            # and still respects the task floor.
            cheaper = tier - 1
            pc = self._get_psuccess(x, cheaper)
            if pc >= self.downgrade_threshold and cheaper >= floor:
                tier = cheaper
                ps = pc
                downgraded = True
                reasoning_parts.append(f"downgraded(P(success@{cheaper})>={self.downgrade_threshold})")

        model_info = TIER_MODELS.get(tier, TIER_MODELS[4])
        return RoutingDecision(
            model_id=model_info["model_id"],
            tier=tier,
            confidence=ps,
            reasoning="; ".join(reasoning_parts),
            cost_estimate=self.tier_costs.get(tier, 1.0),
            dynamic_difficulty=difficulty,
            escalated=escalated,
            downgraded=downgraded,
        )