File size: 6,031 Bytes
17a2ae0
284d6c8
17a2ae0
 
 
 
 
 
 
284d6c8
 
 
17a2ae0
284d6c8
 
 
 
 
17a2ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284d6c8
17a2ae0
 
284d6c8
17a2ae0
 
 
284d6c8
17a2ae0
 
284d6c8
17a2ae0
 
 
284d6c8
17a2ae0
 
284d6c8
17a2ae0
 
 
 
 
284d6c8
17a2ae0
 
 
 
 
 
 
284d6c8
17a2ae0
 
 
 
284d6c8
17a2ae0
 
 
284d6c8
17a2ae0
 
 
 
 
 
 
 
 
 
 
 
 
 
284d6c8
17a2ae0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Trained Production Router - Replaces heuristic routing.

Architecture: difficulty-first + ML confirmation + safety floors.

Usage:
    from aco.learned_router import TrainedRouter
    
    router = TrainedRouter.from_pretrained("narcolepticchicken/agent-cost-optimizer")
    tier, confidence = router.predict("Write a Python function", "coding", difficulty=3)
"""

import json
import os
import pickle
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict

try:
    import numpy as np
    import xgboost as xgb
    HAS_ML = True
except ImportError:
    HAS_ML = False


# Task categories understood by the router. The list order is significant:
# it defines the integer index stored in TT2IDX.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]

# Task type name -> position of that name in TASK_TYPES.
TT2IDX = dict(zip(TASK_TYPES, range(len(TASK_TYPES))))

# Keyword groups, one list per signal family.
CODE_KW = [
    "python", "javascript", "code", "function", "bug",
    "debug", "refactor", "implement", "test", "compile",
    "runtime", "class", "module", "async", "thread",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr",
    "privacy", "policy", "regulatory", "liability",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate",
    "compare", "analyze", "survey",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query",
    "api", "database", "scrape", "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate",
    "multi-step", "migrate", "pipeline", "deploy",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation",
    "formula", "optimize", "probability",
]

# Default safety floors per task type (minimum tier for each category).
TASK_FLOOR = {
    "legal_regulated": 4,
    "long_horizon": 3,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
}


class TrainedRouter:
    """Production trained router: difficulty-first + ML confirmation + safety floors.

    ``predict`` derives a base tier from the caller-supplied difficulty,
    raises it to the per-task-type safety floor, and then asks the per-tier
    classifiers whether that tier is likely to succeed, escalating one tier
    at a time while the predicted probability is below a threshold.

    Attributes:
        tier_clfs: Mapping of tier number to a classifier exposing
            ``predict_proba`` (presumably XGBoost models - confirm against
            the training pipeline).
        feat_keys: Ordered feature names; fixes the layout of the vector
            passed to the classifiers.
        tier_config: Raw configuration dict the router was built from.
        tier_cost: ``tier_config["tier_cost"]`` with keys coerced to ``int``.
        task_floor: Mapping of task type to minimum tier.
        escalation_threshold: Default probability below which ``predict``
            escalates to the next tier.
    """

    # Highest tier the router will select or escalate to.
    MAX_TIER = 5
    # Floor applied to task types missing from ``task_floor``.
    DEFAULT_FLOOR = 2
    # Confidence reported when the tier is returned without ML confirmation.
    FALLBACK_CONFIDENCE = 0.6

    def __init__(self, tier_clfs: Dict, feat_keys: List[str],
                 tier_config: Dict, escalation_threshold: float = 0.55):
        """Build a router from already-loaded classifiers and configuration.

        Args:
            tier_clfs: Mapping of tier number to classifier. May be empty, in
                which case ``predict`` uses the heuristic tier only.
            feat_keys: Ordered feature names used to build the input vector.
            tier_config: Configuration dict. Must contain ``"tier_cost"``;
                may contain ``"task_floor"`` to override the module default.
            escalation_threshold: Default escalation threshold for ``predict``.

        Raises:
            KeyError: If ``tier_config`` has no ``"tier_cost"`` entry.
        """
        self.tier_clfs = tier_clfs
        self.feat_keys = feat_keys
        self.tier_config = tier_config
        self.tier_cost = {int(k): v for k, v in tier_config["tier_cost"].items()}
        # Resolve the module-level default only when the config has no
        # usable override.
        configured_floor = tier_config.get("task_floor")
        self.task_floor = TASK_FLOOR if configured_floor is None else configured_floor
        self.escalation_threshold = escalation_threshold
        # Without classifiers there is nothing to confirm against, so
        # ``predict`` must stay on the heuristic path instead of raising.
        self._trained = bool(tier_clfs)

    def extract_features(self, request: str, task_type: str, difficulty: int = 3) -> Dict:
        """Turn a request into the named feature dict the classifiers consume.

        Keyword features are case-insensitive substring matches against the
        module-level keyword lists.

        Args:
            request: Raw request text.
            task_type: Task category; unknown values get index 8 and all-zero
                one-hot flags.
            difficulty: Caller-supplied difficulty, passed through unchanged.

        Returns:
            Dict of feature name to numeric value.
        """
        lowered = request.lower()

        def count_hits(keywords) -> int:
            # Number of keywords from the list that occur in the request.
            return sum(1 for kw in keywords if kw in lowered)

        # Count once per list; the has_* flags follow from the counts.
        n_code = count_hits(CODE_KW)
        n_legal = count_hits(LEGAL_KW)
        n_research = count_hits(RESEARCH_KW)
        n_tool = count_hits(TOOL_KW)

        feats = {
            "req_len": len(request),
            "num_words": len(request.split()),
            "has_code": int(n_code > 0),
            "n_code": n_code,
            "has_legal": int(n_legal > 0),
            "n_legal": n_legal,
            "has_research": int(n_research > 0),
            "n_research": n_research,
            "has_tool": int(n_tool > 0),
            "n_tool": n_tool,
            "has_long": int(any(kw in lowered for kw in LONG_KW)),
            "has_math": int(any(kw in lowered for kw in MATH_KW)),
            # 8 is the index of "unknown_ambiguous" in TASK_TYPES.
            "tt_idx": TT2IDX.get(task_type, 8),
            "difficulty": difficulty,
        }
        # One-hot encoding of the task type.
        for tt in TASK_TYPES:
            feats[f"tt_{tt}"] = int(task_type == tt)
        return feats

    def _feats_to_vec(self, feats: Dict):
        """Project a feature dict onto ``feat_keys`` as a float32 array.

        Features missing from ``feats`` are filled with 0.0.
        """
        # Imported here so the module stays importable without numpy.
        import numpy as np
        return np.array([float(feats.get(k, 0.0)) for k in self.feat_keys], dtype=np.float32)

    def _success_probability(self, tier: int, x) -> float:
        """Return column 1 of ``predict_proba`` for the given tier's classifier."""
        return self.tier_clfs[tier].predict_proba(x)[0, 1]

    def predict(self, request: str, task_type: str, difficulty: int = 3,
                escalation_threshold: Optional[float] = None) -> Tuple[int, float]:
        """Predict optimal tier using difficulty-first + ML confirmation.

        Args:
            request: Raw request text.
            task_type: Task category, used for the safety floor and features.
            difficulty: Caller-supplied difficulty; the base tier is
                ``difficulty + 1`` capped at ``MAX_TIER``.
            escalation_threshold: Per-call override of the instance
                threshold. ``None`` means "use the instance default"; an
                explicit ``0.0`` is honoured and disables escalation.

        Returns:
            ``(tier, confidence)``. Confidence is the classifier probability
            for the returned tier, or ``FALLBACK_CONFIDENCE`` when no ML
            confirmation was possible.

        Raises:
            KeyError: If classifiers are loaded but none exists for a tier
                that needs to be evaluated.
        """
        # `x or default` would silently discard a legitimate 0.0 override.
        if escalation_threshold is None:
            threshold = self.escalation_threshold
        else:
            threshold = escalation_threshold

        # Steps 1-2: difficulty -> base tier, then raise to the safety floor.
        floor = self.task_floor.get(task_type, self.DEFAULT_FLOOR)
        tier = max(min(difficulty + 1, self.MAX_TIER), floor)

        # No classifiers, or no ML stack installed: heuristic tier only.
        if not self._trained or not HAS_ML:
            return tier, self.FALLBACK_CONFIDENCE

        # Step 3: ML confirmation of the heuristic tier.
        feats = self.extract_features(request, task_type, difficulty)
        x = self._feats_to_vec(feats).reshape(1, -1)
        p_success = self._success_probability(tier, x)

        # Step 4: escalate while P(success) is too low and a higher tier exists.
        while p_success < threshold and tier < self.MAX_TIER:
            tier += 1
            p_success = self._success_probability(tier, x)

        return tier, float(p_success)

    @classmethod
    def _from_bundle_file(cls, bundle_path: str, escalation_threshold: float):
        """Unpickle a router bundle from ``bundle_path`` and build a router.

        The bundle must be a mapping with ``"tier_clfs"``, ``"feat_keys"``
        and ``"tier_config"`` entries.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load bundles from repositories / directories you trust.
        with open(bundle_path, "rb") as fh:
            bundle = pickle.load(fh)

        return cls(
            # Normalise tier keys to int so lookups by tier number work.
            tier_clfs={int(k): v for k, v in bundle["tier_clfs"].items()},
            feat_keys=bundle["feat_keys"],
            tier_config=bundle["tier_config"],
            escalation_threshold=escalation_threshold,
        )

    @classmethod
    def from_pretrained(cls, repo_id: str, escalation_threshold: float = 0.55,
                        cache_dir: Optional[str] = None):
        """Load trained router from HuggingFace Hub.

        Downloads ``router_models/router_bundle.pkl`` from ``repo_id`` and
        unpickles it; see the security note in ``_from_bundle_file``.

        Args:
            repo_id: Hub repository identifier.
            escalation_threshold: Default escalation threshold for the router.
            cache_dir: Optional download cache directory, forwarded to
                ``hf_hub_download``.
        """
        # Imported lazily so huggingface_hub is only needed for this loader.
        from huggingface_hub import hf_hub_download

        bundle_path = hf_hub_download(
            repo_id=repo_id, filename="router_models/router_bundle.pkl",
            cache_dir=cache_dir,
        )
        return cls._from_bundle_file(bundle_path, escalation_threshold)

    @classmethod
    def from_local(cls, model_dir: str, escalation_threshold: float = 0.55):
        """Load from local directory.

        Args:
            model_dir: Directory containing ``router_bundle.pkl``.
            escalation_threshold: Default escalation threshold for the router.
        """
        bundle_path = os.path.join(model_dir, "router_bundle.pkl")
        return cls._from_bundle_file(bundle_path, escalation_threshold)