"""Model Cascade Router: Dynamic difficulty + ML confirmation + safety floors."""
import numpy as np
import pickle, os, json
from typing import Dict, Tuple, Optional
from dataclasses import dataclass
@dataclass
class RoutingDecision:
    """Outcome of one routing call: which model tier was picked and why."""

    # Identifier of the model chosen for the request.
    model_id: str
    # Numeric tier the request was finally routed to.
    tier: int
    # Estimated probability of success at the chosen tier.
    confidence: float
    # Human-readable trace of the routing steps, joined with "; ".
    reasoning: str
    # Cost figure looked up for the chosen tier.
    cost_estimate: float
    # Difficulty value that the routing decision was based on.
    dynamic_difficulty: int
    # True when the router moved up a tier because confidence was too low.
    escalated: bool = False
    # True when the router moved down a tier because the cheaper one sufficed.
    downgraded: bool = False
# Keyword vocabularies. The router lower-cases the request and tests each
# keyword with ``in``, so these are substring matches, not whole-word matches.
CODE_KW = [
    "python", "javascript", "code", "function",
    "bug", "debug", "refactor", "implement",
    "test", "compile", "runtime", "segfault",
    "thread", "async", "class", "module",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy",
    "policy", "regulatory", "liability", "indemnification", "clause",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "survey", "paper", "arxiv",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query",
    "api", "database", "scrape", "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step",
    "migrate", "pipeline", "deploy", "architecture",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation",
    "formula", "optimize", "probability", "integral",
]
# Keywords that raise the estimated difficulty by one step.
CRITICAL_KW = [
    "critical", "production", "urgent", "now", "emergency",
    "live", "deployed", "safety", "security",
]
# Keywords that lower the estimated difficulty by one step.
SIMPLE_KW = [
    "typo", "simple", "quick", "brief", "briefly", "just",
    "minor", "small", "easy", "trivial", "clarification",
]

# Task type -> integer index used as the "tt_idx" feature; unrecognised task
# types are mapped to 8 by the feature extractor.
TT2IDX = {
    "quick_answer": 0,
    "coding": 1,
    "research": 2,
    "document_drafting": 3,
    "legal_regulated": 4,
    "tool_heavy": 5,
    "retrieval_heavy": 6,
    "long_horizon": 7,
    "unknown_ambiguous": 8,
}

# Tier number -> descriptor of the model served at that tier.
# NOTE(review): the router's default ``tier_costs`` table uses different
# numbers from ``cost_per_1k`` here — confirm which one is authoritative.
TIER_MODELS = {
    1: {"model_id": "tiny-local-3b", "provider": "local", "cost_per_1k": 0.0},
    2: {"model_id": "cheap-cloud-8b", "provider": "cloud", "cost_per_1k": 0.05},
    3: {"model_id": "medium-70b", "provider": "cloud", "cost_per_1k": 0.30},
    4: {"model_id": "frontier-latest", "provider": "cloud", "cost_per_1k": 1.00},
    5: {"model_id": "specialist-expert", "provider": "cloud", "cost_per_1k": 1.50},
}
class ModelCascadeRouter:
    """Route a request to a model tier from difficulty, task floors and P(success).

    ``route`` starts one tier above the estimated difficulty, raises that to
    the task type's floor, escalates one tier when the predicted success
    probability is below ``safety_threshold``, and otherwise drops one tier
    when the cheaper tier's predicted success probability reaches
    ``downgrade_threshold``.

    Success probabilities come from per-tier classifiers and calibrators
    loaded from a pickle bundle when one is available, and from a fixed
    heuristic otherwise.
    """

    # Width of the all-zero feature row returned when no feature-key list is
    # loaded (the feature dict built in ``_extract_features`` has 23 entries).
    _FALLBACK_FEATURE_WIDTH = 23

    def __init__(self, model_path: Optional[str] = None, safety_threshold: float = 0.30,
                 downgrade_threshold: float = 0.90,
                 task_floor: Optional[Dict[str, int]] = None,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Configure thresholds and tables, then try to load the model bundle.

        Args:
            model_path: Path to a pickled model bundle; ignored when falsy or
                when the path does not exist.
            safety_threshold: Escalate one tier when P(success) is below this.
            downgrade_threshold: Downgrade one tier when the cheaper tier's
                P(success) is at least this.
            task_floor: Minimum tier per task type. A falsy value (``None`` or
                an empty dict) selects the built-in table.
            tier_costs: Cost estimate per tier. A falsy value selects the
                built-in table.
        """
        self.safety_threshold = safety_threshold
        self.downgrade_threshold = downgrade_threshold
        self.task_floor = task_floor or {
            "legal_regulated": 4, "long_horizon": 3, "research": 3, "coding": 3,
            "unknown_ambiguous": 3, "quick_answer": 1, "document_drafting": 2,
            "tool_heavy": 2, "retrieval_heavy": 2,
        }
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.tier_clfs = None
        self.tier_calibs = None
        self.feat_keys = None
        self._load_model(model_path)

    def _load_model(self, model_path: Optional[str] = None):
        """Load classifiers, calibrators and feature keys from a pickle bundle.

        Best-effort: a missing path is ignored silently, and any failure while
        reading or unpacking the bundle prints a warning and leaves all three
        model attributes untouched, so the router falls back to its heuristic.
        """
        if not (model_path and os.path.exists(model_path)):
            return
        try:
            # SECURITY: unpickling executes arbitrary code from the file. Only
            # point ``model_path`` at bundles from a trusted source.
            with open(model_path, "rb") as fh:  # closes the handle even on error
                bundle = pickle.load(fh)
            # Unpack into locals first so a malformed bundle cannot leave the
            # router half-configured (classifiers set, calibrators missing).
            tier_clfs = {int(k): v for k, v in bundle.get("tier_clfs", {}).items()}
            tier_calibs = {int(k): v for k, v in bundle.get("tier_calibrators", {}).items()}
            feat_keys = bundle.get("feat_keys", None)
        except Exception as e:
            print(f"[ACO] Warning: Could not load router model: {e}")
            return
        self.tier_clfs = tier_clfs
        self.tier_calibs = tier_calibs
        self.feat_keys = feat_keys

    def estimate_difficulty(self, request: str, task_type: str) -> int:
        """Return a difficulty in 1..5 from the task type and request keywords.

        The task type gives a base value (3 for unrecognised types). Any
        ``CRITICAL_KW`` hit adds one (capped at 5), then any ``SIMPLE_KW`` hit
        subtracts one (floored at 1).
        """
        text = request.lower()
        base = {
            "quick_answer": 1, "document_drafting": 2, "tool_heavy": 2, "retrieval_heavy": 2,
            "research": 3, "coding": 3, "unknown_ambiguous": 3, "long_horizon": 4,
            "legal_regulated": 5,
        }.get(task_type, 3)
        # NOTE(review): these are substring tests, so e.g. "now" also fires on
        # "know" and "just" on "adjust". Left as-is because the result feeds
        # the "difficulty" model feature — confirm before tightening.
        if any(k in text for k in CRITICAL_KW):
            base = min(base + 1, 5)
        if any(k in text for k in SIMPLE_KW):
            base = max(base - 1, 1)
        return base

    def _extract_features(self, request: str, task_type: str, difficulty: int) -> np.ndarray:
        """Build the (1, n) float32 feature row for the per-tier classifiers.

        When no feature-key list is loaded, an all-zero row of width 23 is
        returned. Otherwise the row holds one value per entry of
        ``self.feat_keys``, in that order, with unknown keys filled with 0.0.
        """
        if not self.feat_keys:
            # No key order to project onto; skip building the feature dict.
            return np.zeros((1, self._FALLBACK_FEATURE_WIDTH), dtype=np.float32)

        text = request.lower()
        feats = {
            "req_len": len(request),
            "num_words": len(request.split()),
            "has_code": int(any(k in text for k in CODE_KW)),
            "n_code": sum(1 for k in CODE_KW if k in text),
            "has_legal": int(any(k in text for k in LEGAL_KW)),
            "n_legal": sum(1 for k in LEGAL_KW if k in text),
            "has_research": int(any(k in text for k in RESEARCH_KW)),
            "n_research": sum(1 for k in RESEARCH_KW if k in text),
            "has_tool": int(any(k in text for k in TOOL_KW)),
            "n_tool": sum(1 for k in TOOL_KW if k in text),
            "has_long": int(any(k in text for k in LONG_KW)),
            "has_math": int(any(k in text for k in MATH_KW)),
            "tt_idx": TT2IDX.get(task_type, 8),
            "difficulty": difficulty,
        }
        # One-hot encoding of the task type.
        for tt in TT2IDX:
            feats[f"tt_{tt}"] = int(task_type == tt)
        row = [float(feats.get(k, 0.0)) for k in self.feat_keys]
        return np.array(row, dtype=np.float32).reshape(1, -1)

    def _get_psuccess(self, x: np.ndarray, tier: int) -> float:
        """Return the estimated probability that ``tier`` succeeds on ``x``.

        Uses the tier's classifier and calibrator when both are loaded. If they
        are missing, or the prediction raises, a heuristic is used instead: a
        fixed per-tier strength raised to ``difficulty * 0.6``.
        """
        has_model = (
            self.tier_clfs and tier in self.tier_clfs
            and self.tier_calibs and tier in self.tier_calibs
        )
        if has_model:
            try:
                p_raw = self.tier_clfs[tier].predict_proba(x)[0, 1]
                return float(self.tier_calibs[tier].transform([p_raw])[0])
            except Exception:
                # Deliberate best-effort: a failing model falls through to the
                # heuristic. KeyboardInterrupt/SystemExit are not swallowed.
                pass
        # Fallback heuristic probability
        strengths = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
        if self.feat_keys and "difficulty" in self.feat_keys:
            diff_feat = float(x[0, self.feat_keys.index("difficulty")])
        else:
            # Without a key list the difficulty column cannot be located in x,
            # so a mid-range difficulty of 3 is assumed.
            diff_feat = 3
        return strengths.get(tier, 0.80) ** (diff_feat * 0.6)

    def route(self, request: str, task_type: str, difficulty: int = None,
              prediction: dict = None) -> "RoutingDecision":
        """Choose a model tier for ``request`` and explain the choice.

        Args:
            request: Raw request text.
            task_type: Task category; selects the tier floor (2 if unknown).
            difficulty: Difficulty to use; estimated from the request when None.
            prediction: Accepted for interface compatibility; not read here.

        Returns:
            A ``RoutingDecision`` describing the chosen tier, its estimated
            success probability, cost estimate and the reasoning trace.
        """
        if difficulty is None:
            difficulty = self.estimate_difficulty(request, task_type)
        floor = self.task_floor.get(task_type, 2)
        # Start one tier above the difficulty, but never below the task floor.
        base = max(min(difficulty + 1, 5), floor)
        x = self._extract_features(request, task_type, difficulty)
        tier = base
        ps = self._get_psuccess(x, tier)
        escalated = False
        downgraded = False
        # Safety net
        if ps < self.safety_threshold and tier < 5:
            tier += 1
            ps = self._get_psuccess(x, tier)
            escalated = True
        # Cost saver (tier still equals base here whenever escalated is False)
        if not escalated and tier > floor:
            cheaper = tier - 1
            pc = self._get_psuccess(x, cheaper)
            if pc >= self.downgrade_threshold and cheaper >= floor:
                tier = cheaper
                ps = pc
                downgraded = True
        # Tiers without a descriptor fall back to the tier-4 model entry.
        model_info = TIER_MODELS.get(tier, TIER_MODELS[4])
        reasoning_parts = [f"base_tier={base}"]
        if escalated:
            reasoning_parts.append(f"escalated(P(success@{base})<{self.safety_threshold})")
        if downgraded:
            reasoning_parts.append(f"downgraded(P(success@{cheaper})>={self.downgrade_threshold})")
        return RoutingDecision(
            model_id=model_info["model_id"],
            tier=tier,
            confidence=ps,
            reasoning="; ".join(reasoning_parts),
            cost_estimate=self.tier_costs.get(tier, 1.0),
            dynamic_difficulty=difficulty,
            escalated=escalated,
            downgraded=downgraded,
        )
|