# Uploaded by narcolepticchicken as aco/router.py via huggingface_hub
# (commit e6100b5, verified).
"""Model Cascade Router: Dynamic difficulty + ML confirmation + safety floors."""
import pickle, os, json
import re
from dataclasses import dataclass
from typing import Dict, Tuple, Optional

import numpy as np
@dataclass(eq=True)
class RoutingDecision:
    """Routing outcome produced by ``ModelCascadeRouter.route``.

    Only ``escalated`` and ``downgraded`` have defaults; every other field
    must be supplied by the caller.
    """

    model_id: str             # identifier of the model chosen for the request
    tier: int                 # chosen tier (1-5; configured cost rises with tier)
    confidence: float         # estimated P(success) at the chosen tier
    reasoning: str            # "; "-joined trace of the routing steps taken
    cost_estimate: float      # cost figure reported for the chosen tier
    dynamic_difficulty: int   # difficulty value the routing was based on
    escalated: bool = False   # True when the safety net bumped the tier up
    downgraded: bool = False  # True when the cost saver moved the tier down
# Keyword vocabularies.  The router lower-cases the request text before
# looking for these terms in it.
CODE_KW = [
    "python", "javascript", "code", "function", "bug", "debug",
    "refactor", "implement", "test", "compile", "runtime", "segfault",
    "thread", "async", "class", "module",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy", "policy",
    "regulatory", "liability", "indemnification", "clause",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "survey", "paper", "arxiv",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query", "api", "database", "scrape",
    "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
    "pipeline", "deploy", "architecture",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation", "formula", "optimize",
    "probability", "integral",
]
# Cues that push the estimated difficulty up (CRITICAL_KW) or down (SIMPLE_KW).
CRITICAL_KW = [
    "critical", "production", "urgent", "now", "emergency", "live",
    "deployed", "safety", "security",
]
SIMPLE_KW = [
    "typo", "simple", "quick", "brief", "briefly", "just", "minor",
    "small", "easy", "trivial", "clarification",
]

# Task-type name -> integer index.  Order matters: the per-task one-hot
# features are generated by iterating over this mapping.
TT2IDX = {
    task_type: position
    for position, task_type in enumerate([
        "quick_answer", "coding", "research", "document_drafting",
        "legal_regulated", "tool_heavy", "retrieval_heavy", "long_horizon",
        "unknown_ambiguous",
    ])
}

# Model served at each routing tier, with provider and cost_per_1k metadata.
TIER_MODELS = {
    1: dict(model_id="tiny-local-3b", provider="local", cost_per_1k=0.0),
    2: dict(model_id="cheap-cloud-8b", provider="cloud", cost_per_1k=0.05),
    3: dict(model_id="medium-70b", provider="cloud", cost_per_1k=0.30),
    4: dict(model_id="frontier-latest", provider="cloud", cost_per_1k=1.00),
    5: dict(model_id="specialist-expert", provider="cloud", cost_per_1k=1.50),
}
class ModelCascadeRouter:
    """Route a request to one of five model tiers.

    The starting ("base") tier is one above the estimated difficulty, raised
    to a per-task-type floor and kept within tiers 1-5.  An estimated
    probability of success then moves it by at most one tier:

    * safety net -- if P(success) at the base tier is below
      ``safety_threshold`` (and the tier is below 5), escalate one tier;
    * cost saver -- otherwise, if the next cheaper tier is still at or above
      the floor and its P(success) is at least ``downgrade_threshold``,
      downgrade one tier.

    P(success) comes from per-tier classifiers plus calibrators loaded from a
    pickled bundle when available, and from a fixed heuristic otherwise.
    """

    # Default column order of the feature row when the loaded bundle provides
    # no ``feat_keys``; one ``tt_<task_type>`` one-hot column per TT2IDX entry
    # follows.  The heuristic also uses it to locate the "difficulty" column.
    _BASE_FEAT_KEYS = (
        "req_len", "num_words",
        "has_code", "n_code",
        "has_legal", "n_legal",
        "has_research", "n_research",
        "has_tool", "n_tool",
        "has_long", "has_math",
        "tt_idx", "difficulty",
    )

    def __init__(self, model_path: Optional[str] = None, safety_threshold: float = 0.30,
                 downgrade_threshold: float = 0.90,
                 task_floor: Optional[Dict[str, int]] = None,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Configure thresholds and floors, then try to load ``model_path``.

        Args:
            model_path: Optional path to a pickled model bundle (see
                ``_load_model``); ignored when falsy or nonexistent.
            safety_threshold: Escalate when P(success) at the base tier is
                below this value.
            downgrade_threshold: Downgrade when P(success) one tier cheaper
                is at least this value.
            task_floor: Minimum tier per task type; task types not listed get
                a floor of 2 in ``route``.  Empty/None selects the built-in table.
            tier_costs: Cost reported per tier as
                ``RoutingDecision.cost_estimate``; empty/None selects the
                built-in table.
        """
        self.safety_threshold = safety_threshold
        self.downgrade_threshold = downgrade_threshold
        self.task_floor = task_floor or {
            "legal_regulated": 4, "long_horizon": 3, "research": 3, "coding": 3,
            "unknown_ambiguous": 3, "quick_answer": 1, "document_drafting": 2,
            "tool_heavy": 2, "retrieval_heavy": 2,
        }
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.tier_clfs = None
        self.tier_calibs = None
        self.feat_keys = None
        self._load_model(model_path)

    def _load_model(self, model_path: Optional[str] = None) -> None:
        """Load per-tier classifiers/calibrators from a pickled bundle, if present.

        The bundle is read as a dict with optional keys ``tier_clfs`` and
        ``tier_calibrators`` (mappings whose keys are coerced to ``int`` tiers)
        and ``feat_keys`` (feature column order).  On any failure a warning is
        printed and the current state is left untouched, so routing falls back
        to the heuristic probabilities.

        SECURITY: ``pickle.load`` can execute arbitrary code -- only point
        ``model_path`` at trusted files.
        """
        if not model_path or not os.path.exists(model_path):
            return
        try:
            with open(model_path, "rb") as fh:  # closed even if unpickling fails
                bundle = pickle.load(fh)
            tier_clfs = {int(k): v for k, v in bundle.get("tier_clfs", {}).items()}
            tier_calibs = {int(k): v for k, v in bundle.get("tier_calibrators", {}).items()}
            feat_keys = bundle.get("feat_keys", None)
        except Exception as e:  # best-effort: a bad bundle must not break routing
            print(f"[ACO] Warning: Could not load router model: {e}")
            return
        # Commit only once the whole bundle parsed, so a partial failure cannot
        # leave classifiers installed without their calibrators / feature order.
        self.tier_clfs = tier_clfs
        self.tier_calibs = tier_calibs
        self.feat_keys = feat_keys

    @staticmethod
    def _mentions_any(text: str, keywords) -> bool:
        """Return True if any keyword occurs in ``text`` starting at a word boundary.

        Anchoring only the start stops short cues from firing inside other
        words ("now" in "know", "live" in "deliver", "just" in "adjust") while
        still accepting inflections such as "quickly" or "urgently".
        """
        return any(re.search(r"\b" + re.escape(kw), text) for kw in keywords)

    def estimate_difficulty(self, request: str, task_type: str) -> int:
        """Heuristic difficulty in 1..5 for ``request`` of the given task type.

        Starts from a per-task-type baseline (3 for unknown types), adds one
        if a CRITICAL_KW cue is present and subtracts one if a SIMPLE_KW cue
        is present, clamping to 1..5.
        """
        text = request.lower()
        base = {"quick_answer": 1, "document_drafting": 2, "tool_heavy": 2, "retrieval_heavy": 2,
                "research": 3, "coding": 3, "unknown_ambiguous": 3, "long_horizon": 4,
                "legal_regulated": 5}.get(task_type, 3)
        if self._mentions_any(text, CRITICAL_KW):
            base = min(base + 1, 5)
        if self._mentions_any(text, SIMPLE_KW):
            base = max(base - 1, 1)
        return base

    def _extract_features(self, request: str, task_type: str, difficulty: int) -> np.ndarray:
        """Build the (1, n_features) float32 feature row for the tier models.

        Columns follow the loaded ``feat_keys`` order when available (names
        not computed here become 0.0); otherwise ``_BASE_FEAT_KEYS`` followed
        by one ``tt_<task_type>`` one-hot column per TT2IDX entry.
        """
        text = request.lower()
        # NOTE(review): plain substring matching is kept here deliberately --
        # these counts are classifier inputs and presumably must match how the
        # bundled models were trained.  Confirm before changing.
        feats = {
            "req_len": len(request), "num_words": len(request.split()),
            "has_code": int(any(k in text for k in CODE_KW)),
            "n_code": sum(1 for k in CODE_KW if k in text),
            "has_legal": int(any(k in text for k in LEGAL_KW)),
            "n_legal": sum(1 for k in LEGAL_KW if k in text),
            "has_research": int(any(k in text for k in RESEARCH_KW)),
            "n_research": sum(1 for k in RESEARCH_KW if k in text),
            "has_tool": int(any(k in text for k in TOOL_KW)),
            "n_tool": sum(1 for k in TOOL_KW if k in text),
            "has_long": int(any(k in text for k in LONG_KW)),
            "has_math": int(any(k in text for k in MATH_KW)),
            "tt_idx": TT2IDX.get(task_type, 8),
            "difficulty": difficulty,
        }
        for tt in TT2IDX:
            feats[f"tt_{tt}"] = int(task_type == tt)
        # Without a trained ordering, emit the real features in the default
        # layout instead of an all-zero row, so the heuristic (and any model
        # loaded without feat_keys) still sees the request's difficulty.
        keys = self.feat_keys or [*self._BASE_FEAT_KEYS, *(f"tt_{tt}" for tt in TT2IDX)]
        return np.array([float(feats.get(k, 0.0)) for k in keys], dtype=np.float32).reshape(1, -1)

    def _get_psuccess(self, x: np.ndarray, tier: int) -> float:
        """Estimated probability that ``tier`` succeeds on feature row ``x``.

        Uses the classifier and calibrator for ``tier`` when both are loaded.
        Otherwise returns ``strength[tier] ** (0.6 * difficulty)``, reading
        ``difficulty`` from the matching column of ``x`` (3 if unavailable).
        """
        clfs, calibs = self.tier_clfs, self.tier_calibs
        if clfs and tier in clfs and calibs and tier in calibs:
            try:
                p_raw = clfs[tier].predict_proba(x)[0, 1]
                return float(calibs[tier].transform([p_raw])[0])
            except Exception:
                # Best-effort: a misbehaving model falls through to the
                # heuristic rather than failing the routing call.
                pass
        # Fallback heuristic probability
        strengths = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
        keys = list(self.feat_keys or self._BASE_FEAT_KEYS)
        difficulty = 3.0
        if "difficulty" in keys:
            col = keys.index("difficulty")
            if x.ndim == 2 and col < x.shape[1]:
                difficulty = float(x[0, col])
        return strengths.get(tier, 0.80) ** (difficulty * 0.6)

    def route(self, request: str, task_type: str, difficulty: Optional[int] = None,
              prediction: Optional[dict] = None) -> "RoutingDecision":
        """Choose a model tier for ``request``.

        Args:
            request: Raw request text.
            task_type: Task-type label (see TT2IDX); selects the tier floor.
            difficulty: Optional precomputed difficulty; estimated with
                ``estimate_difficulty`` when None.
            prediction: Currently unused by this method.

        Returns:
            A ``RoutingDecision`` whose ``confidence`` is P(success) at the
            final tier and whose ``reasoning`` traces base/escalate/downgrade.
        """
        if difficulty is None:
            difficulty = self.estimate_difficulty(request, task_type)
        floor = self.task_floor.get(task_type, 2)
        # One tier above the difficulty, never below the task-type floor, and
        # kept within the tiers that exist (1..5).
        base = min(max(difficulty + 1, floor, 1), 5)
        x = self._extract_features(request, task_type, difficulty)

        tier = base
        ps = self._get_psuccess(x, tier)
        escalated = downgraded = False
        cheaper = None
        if ps < self.safety_threshold and tier < 5:
            # Safety net: too risky at the base tier, move up one.
            tier += 1
            ps = self._get_psuccess(x, tier)
            escalated = True
        elif tier > floor:
            # Cost saver: try one tier cheaper (still >= floor) if it is
            # confidently good enough.
            cheaper = tier - 1
            pc = self._get_psuccess(x, cheaper)
            if pc >= self.downgrade_threshold:
                tier, ps, downgraded = cheaper, pc, True

        model_info = TIER_MODELS.get(tier, TIER_MODELS[4])
        reasoning_parts = [f"base_tier={base}"]
        if escalated:
            reasoning_parts.append(f"escalated(P(success@{base})<{self.safety_threshold})")
        if downgraded:
            reasoning_parts.append(f"downgraded(P(success@{cheaper})>={self.downgrade_threshold})")
        return RoutingDecision(
            model_id=model_info["model_id"],
            tier=tier,
            confidence=ps,
            reasoning="; ".join(reasoning_parts),
            cost_estimate=self.tier_costs.get(tier, 1.0),
            dynamic_difficulty=difficulty,
            escalated=escalated,
            downgraded=downgraded,
        )