"""Trained Production Router - Replaces heuristic routing. Architecture: difficulty-first + ML confirmation + safety floors. Usage: from aco.learned_router import TrainedRouter router = TrainedRouter.from_pretrained("narcolepticchicken/agent-cost-optimizer") tier, confidence = router.predict("Write a Python function", "coding", difficulty=3) """ import json import os import pickle from typing import Dict, List, Optional, Tuple from dataclasses import dataclass from collections import defaultdict try: import numpy as np import xgboost as xgb HAS_ML = True except ImportError: HAS_ML = False TASK_TYPES = ["quick_answer","coding","research","document_drafting", "legal_regulated","tool_heavy","retrieval_heavy", "long_horizon","unknown_ambiguous"] TT2IDX = {t:i for i,t in enumerate(TASK_TYPES)} CODE_KW = ["python","javascript","code","function","bug","debug","refactor", "implement","test","compile","runtime","class","module","async","thread"] LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability"] RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey"] TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"] LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy"] MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability"] # Default safety floors per task type TASK_FLOOR = { "legal_regulated":4,"long_horizon":3,"research":3,"coding":3, "unknown_ambiguous":3,"quick_answer":1,"document_drafting":2, "tool_heavy":2,"retrieval_heavy":2, } class TrainedRouter: """Production trained router: difficulty-first + ML confirmation + safety floors.""" def __init__(self, tier_clfs: Dict, feat_keys: List[str], tier_config: Dict, escalation_threshold: float = 0.55): self.tier_clfs = tier_clfs self.feat_keys = feat_keys self.tier_config = tier_config self.tier_cost = {int(k):v for k,v in tier_config["tier_cost"].items()} self.task_floor = tier_config.get("task_floor", TASK_FLOOR) self.escalation_threshold = escalation_threshold self._trained = True def extract_features(self, request: str, task_type: str, difficulty: int = 3) -> Dict: r = request.lower() f = {"req_len":len(request),"num_words":len(request.split()), "has_code":int(any(k in r for k in CODE_KW)), "n_code":sum(1 for k in CODE_KW if k in r), "has_legal":int(any(k in r for k in LEGAL_KW)), "n_legal":sum(1 for k in LEGAL_KW if k in r), "has_research":int(any(k in r for k in RESEARCH_KW)), "n_research":sum(1 for k in RESEARCH_KW if k in r), "has_tool":int(any(k in r for k in TOOL_KW)), "n_tool":sum(1 for k in TOOL_KW if k in r), "has_long":int(any(k in r for k in LONG_KW)), "has_math":int(any(k in r for k in MATH_KW)), "tt_idx":TT2IDX.get(task_type,8),"difficulty":difficulty} for tt in TASK_TYPES: f[f"tt_{tt}"] = int(task_type == tt) return f def _feats_to_vec(self, feats: Dict): import numpy as np return np.array([float(feats.get(k, 0.0)) for k in self.feat_keys], dtype=np.float32) def predict(self, request: str, task_type: str, difficulty: int = 3, escalation_threshold: Optional[float] = None) -> Tuple[int, float]: """Predict optimal tier using difficulty-first + ML confirmation. Returns: (tier, confidence) """ threshold = escalation_threshold or self.escalation_threshold # Step 1: difficulty -> base_tier base_tier = min(difficulty + 1, 5) # Step 2: apply safety floor floor = self.task_floor.get(task_type, 2) base_tier = max(base_tier, floor) if not HAS_ML or not self._trained: return base_tier, 0.6 # Step 3: ML confirmation feats = self.extract_features(request, task_type, difficulty) x = self._feats_to_vec(feats).reshape(1, -1) p_success = self.tier_clfs[base_tier].predict_proba(x)[0, 1] confidence = p_success # Step 4: escalate if P(success) too low while p_success < threshold and base_tier < 5: base_tier += 1 p_success = self.tier_clfs[base_tier].predict_proba(x)[0, 1] confidence = p_success return base_tier, float(confidence) @classmethod def from_pretrained(cls, repo_id: str, escalation_threshold: float = 0.55, cache_dir: Optional[str] = None): """Load trained router from HuggingFace Hub.""" from huggingface_hub import hf_hub_download bundle_path = hf_hub_download( repo_id=repo_id, filename="router_models/router_bundle.pkl", cache_dir=cache_dir, ) with open(bundle_path, "rb") as f: import pickle bundle = pickle.load(f) return cls( tier_clfs={int(k): v for k, v in bundle["tier_clfs"].items()}, feat_keys=bundle["feat_keys"], tier_config=bundle["tier_config"], escalation_threshold=escalation_threshold, ) @classmethod def from_local(cls, model_dir: str, escalation_threshold: float = 0.55): """Load from local directory.""" bundle_path = os.path.join(model_dir, "router_bundle.pkl") with open(bundle_path, "rb") as f: import pickle bundle = pickle.load(f) return cls( tier_clfs={int(k): v for k, v in bundle["tier_clfs"].items()}, feat_keys=bundle["feat_keys"], tier_config=bundle["tier_config"], escalation_threshold=escalation_threshold, )