| """Trained Production Router - Replaces heuristic routing. |
| |
| Architecture: difficulty-first + ML confirmation + safety floors. |
| |
| Usage: |
| from aco.learned_router import TrainedRouter |
| |
| router = TrainedRouter.from_pretrained("narcolepticchicken/agent-cost-optimizer") |
| tier, confidence = router.predict("Write a Python function", "coding", difficulty=3) |
| """ |
|
|
| import json |
| import os |
| import pickle |
| from typing import Dict, List, Optional, Tuple |
| from dataclasses import dataclass |
| from collections import defaultdict |
|
|
| try: |
| import numpy as np |
| import xgboost as xgb |
| HAS_ML = True |
| except ImportError: |
| HAS_ML = False |
|
|
|
|
# Canonical task-type labels. Each label's position in this list is its
# integer id (exposed through TT2IDX); reordering would change those ids.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]
# Reverse lookup: task-type label -> its index in TASK_TYPES.
TT2IDX = dict(zip(TASK_TYPES, range(len(TASK_TYPES))))
|
|
# Keyword vocabularies used for request feature extraction. Each entry is
# tested as a plain substring of the lowercased request text, so multi-word
# and hyphenated entries ("find sources", "multi-step") match literally.
CODE_KW = [
    "python", "javascript", "code", "function", "bug",
    "debug", "refactor", "implement", "test", "compile",
    "runtime", "class", "module", "async", "thread",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr",
    "privacy", "policy", "regulatory", "liability",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate",
    "compare", "analyze", "survey",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query",
    "api", "database", "scrape", "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate",
    "multi-step", "migrate", "pipeline", "deploy",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation",
    "formula", "optimize", "probability",
]
|
|
| |
# Default minimum routing tier per task type. TrainedRouter uses this unless
# the bundle's tier_config supplies its own "task_floor"; in predict() the
# difficulty-derived tier is raised to at least this floor, and task types
# not listed fall back to a floor of 2.
TASK_FLOOR = dict(
    legal_regulated=4,
    long_horizon=3,
    research=3,
    coding=3,
    unknown_ambiguous=3,
    quick_answer=1,
    document_drafting=2,
    tool_heavy=2,
    retrieval_heavy=2,
)
|
|
|
|
class TrainedRouter:
    """Production trained router: difficulty-first + ML confirmation + safety floors.

    ``predict`` picks a tier in three stages:

    1. Difficulty-first: the starting tier is ``difficulty + 1``, capped at
       ``_MAX_TIER``.
    2. Safety floor: the tier is raised to the task type's minimum from
       ``task_floor`` (task types not listed there get a floor of 2).
    3. ML confirmation: when numpy/xgboost imported successfully, the
       classifier for the current tier scores the request; while its success
       probability is below the escalation threshold the tier is bumped by
       one (never past ``_MAX_TIER``) and re-scored.

    Instances are normally built with ``from_pretrained`` or ``from_local``.
    """

    # Highest routing tier; both the difficulty-derived tier and ML
    # escalation are capped here.
    _MAX_TIER = 5

    def __init__(self, tier_clfs: Dict, feat_keys: List[str],
                 tier_config: Dict, escalation_threshold: float = 0.55):
        """Build a router from already-loaded components.

        Args:
            tier_clfs: Mapping of tier (int) -> classifier exposing
                ``predict_proba``; column 1 of its output is read as the
                success probability.
            feat_keys: Ordered feature names; defines the column order of the
                vector handed to the classifiers.
            tier_config: Config dict. ``"tier_cost"`` is required (its keys are
                coerced to int and stored as ``self.tier_cost``);
                ``"task_floor"`` is optional and defaults to ``TASK_FLOOR``.
            escalation_threshold: Default minimum success probability a tier
                must reach before ``predict`` stops escalating.

        Raises:
            KeyError: if ``tier_config`` has no ``"tier_cost"`` entry.
        """
        self.tier_clfs = tier_clfs
        self.feat_keys = feat_keys
        self.tier_config = tier_config
        self.tier_cost = {int(k): v for k, v in tier_config["tier_cost"].items()}
        self.task_floor = tier_config.get("task_floor", TASK_FLOOR)
        self.escalation_threshold = escalation_threshold
        self._trained = True

    def extract_features(self, request: str, task_type: str, difficulty: int = 3) -> Dict:
        """Turn a request into the flat feature dict consumed by the classifiers.

        Keyword features are plain substring tests against the lowercased
        request, so e.g. ``"test"`` also fires on ``"latest"``.
        NOTE(review): these definitions presumably must match the ones used
        when the bundled classifiers were trained; confirm against the
        training code before changing any of them.

        Args:
            request: Raw request text.
            task_type: Task-type label; labels not in ``TASK_TYPES`` get the
                ``tt_idx`` of ``"unknown_ambiguous"``.
            difficulty: Passed through unchanged as the ``difficulty`` feature.

        Returns:
            Dict of feature name -> int, including a one-hot ``tt_<label>``
            entry for every label in ``TASK_TYPES``.
        """
        text = request.lower()

        def _hits(keywords):
            # Number of keywords occurring anywhere in the lowercased request.
            return sum(1 for kw in keywords if kw in text)

        n_code = _hits(CODE_KW)
        n_legal = _hits(LEGAL_KW)
        n_research = _hits(RESEARCH_KW)
        n_tool = _hits(TOOL_KW)
        feats = {
            "req_len": len(request),
            "num_words": len(request.split()),
            "has_code": int(n_code > 0),
            "n_code": n_code,
            "has_legal": int(n_legal > 0),
            "n_legal": n_legal,
            "has_research": int(n_research > 0),
            "n_research": n_research,
            "has_tool": int(n_tool > 0),
            "n_tool": n_tool,
            "has_long": int(_hits(LONG_KW) > 0),
            "has_math": int(_hits(MATH_KW) > 0),
            # Unknown labels share the "unknown_ambiguous" index (8).
            "tt_idx": TT2IDX.get(task_type, TT2IDX["unknown_ambiguous"]),
            "difficulty": difficulty,
        }
        for tt in TASK_TYPES:
            feats[f"tt_{tt}"] = int(task_type == tt)
        return feats

    def _feats_to_vec(self, feats: Dict):
        """Order *feats* by ``self.feat_keys`` into a 1-D float32 numpy vector.

        Keys missing from *feats* are filled with 0.0.
        """
        import numpy as np  # local import: numpy is optional at module import time
        return np.array([float(feats.get(k, 0.0)) for k in self.feat_keys],
                        dtype=np.float32)

    def predict(self, request: str, task_type: str, difficulty: int = 3,
                escalation_threshold: Optional[float] = None) -> Tuple[int, float]:
        """Predict the tier to route *request* to.

        Args:
            request: Raw request text (only used when the ML stack is available).
            task_type: Task-type label; selects the safety floor and feeds the
                task-type features.
            difficulty: Caller-estimated difficulty; the starting tier is
                ``difficulty + 1``, capped at ``_MAX_TIER``.
            escalation_threshold: Per-call override of the minimum success
                probability. ``None`` uses the instance default; an explicit
                ``0.0`` is honoured and effectively disables escalation.

        Returns:
            ``(tier, confidence)``. Without the ML stack, confidence is a fixed
            0.6. Otherwise it is the chosen tier's classifier probability,
            which can still be below the threshold if escalation reached
            ``_MAX_TIER``.

        Raises:
            KeyError: if ``tier_clfs`` has no classifier for a visited tier.
        """
        # ``is None`` instead of ``or``: a caller-supplied 0.0 is a real
        # threshold and must not silently fall back to the instance default.
        threshold = (self.escalation_threshold if escalation_threshold is None
                     else escalation_threshold)

        # 1. Difficulty-first starting tier.
        tier = min(difficulty + 1, self._MAX_TIER)

        # 2. Safety floor for this task type.
        tier = max(tier, self.task_floor.get(task_type, 2))

        if not HAS_ML or not self._trained:
            return tier, 0.6

        # 3. ML confirmation: escalate until the tier's classifier is
        # confident enough or the top tier is reached.
        feats = self.extract_features(request, task_type, difficulty)
        x = self._feats_to_vec(feats).reshape(1, -1)

        p_success = self.tier_clfs[tier].predict_proba(x)[0, 1]
        while p_success < threshold and tier < self._MAX_TIER:
            tier += 1
            p_success = self.tier_clfs[tier].predict_proba(x)[0, 1]

        return tier, float(p_success)

    @classmethod
    def _from_bundle_file(cls, bundle_path: str, escalation_threshold: float):
        """Unpickle the bundle at *bundle_path* and construct a router from it.

        The bundle must be a mapping with ``"tier_clfs"`` (tier keys are
        coerced to int), ``"feat_keys"`` and ``"tier_config"`` entries.
        """
        # SECURITY: pickle.load executes code embedded in the file. Only load
        # bundles from sources you trust.
        with open(bundle_path, "rb") as f:
            bundle = pickle.load(f)

        return cls(
            tier_clfs={int(k): v for k, v in bundle["tier_clfs"].items()},
            feat_keys=bundle["feat_keys"],
            tier_config=bundle["tier_config"],
            escalation_threshold=escalation_threshold,
        )

    @classmethod
    def from_pretrained(cls, repo_id: str, escalation_threshold: float = 0.55,
                        cache_dir: Optional[str] = None):
        """Load a trained router from the Hugging Face Hub.

        Downloads ``router_models/router_bundle.pkl`` from *repo_id* (into
        *cache_dir* when given) and builds a router from it.

        SECURITY: the downloaded file is unpickled, which can run arbitrary
        code. Only use repositories you trust.
        """
        from huggingface_hub import hf_hub_download

        bundle_path = hf_hub_download(
            repo_id=repo_id, filename="router_models/router_bundle.pkl",
            cache_dir=cache_dir,
        )
        return cls._from_bundle_file(bundle_path, escalation_threshold)

    @classmethod
    def from_local(cls, model_dir: str, escalation_threshold: float = 0.55):
        """Load a trained router from ``<model_dir>/router_bundle.pkl``.

        SECURITY: the file is unpickled, which can run arbitrary code; only
        load bundles you trust.
        """
        bundle_path = os.path.join(model_dir, "router_bundle.pkl")
        return cls._from_bundle_file(bundle_path, escalation_threshold)
|
|