File size: 7,294 Bytes
e6100b5
 
 
 
2b522a0
 
 
 
 
 
 
 
e6100b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b522a0
 
e6100b5
 
 
 
 
 
 
 
 
 
2b522a0
e6100b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b522a0
e6100b5
2b522a0
e6100b5
 
 
 
 
 
2b522a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Model Cascade Router: Dynamic difficulty + ML confirmation + safety floors."""
import numpy as np
import pickle, os, json
from typing import Dict, Tuple, Optional
from dataclasses import dataclass

@dataclass
class RoutingDecision:
    """Result of routing one request to a model tier.

    Instances are produced by ``ModelCascadeRouter.route``.
    """
    # Identifier of the selected model (the "model_id" entry of the chosen tier in TIER_MODELS).
    model_id: str
    # Final tier number after any escalation or downgrade.
    tier: int
    # Estimated probability of success for the final tier.
    confidence: float
    # "; "-joined trace of the routing steps, e.g. "base_tier=3; escalated(...)".
    reasoning: str
    # Cost figure looked up from the router's tier_costs table for the final tier.
    cost_estimate: float
    # Difficulty value the routing was based on (caller-supplied or estimated).
    dynamic_difficulty: int
    # True when the safety net moved the request one tier up.
    escalated: bool = False
    # True when the cost saver moved the request one tier down.
    downgraded: bool = False

# Keyword lists used by ModelCascadeRouter. Every list is matched against the
# lower-cased request with a plain substring test (`kw in request`), so short
# entries also fire inside longer words (e.g. "now" inside "know").

# Feature keywords: drive the has_*/n_* entries built in _extract_features.
CODE_KW = ["python","javascript","code","function","bug","debug","refactor","implement","test",
           "compile","runtime","segfault","thread","async","class","module"]
LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability","indemnification","clause"]
RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey","paper","arxiv"]
TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy","architecture"]
MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability","integral"]
# Difficulty modifiers: any CRITICAL_KW hit raises the estimated difficulty by
# one (capped at 5); any SIMPLE_KW hit lowers it by one (floored at 1).
CRITICAL_KW = ["critical","production","urgent","now","emergency","live","deployed","safety","security"]
SIMPLE_KW = ["typo","simple","quick","brief","briefly","just","minor","small","easy","trivial","clarification"]
# Task-type name -> integer index. Used for the "tt_idx" feature and to build
# one "tt_<name>" one-hot feature per entry; unknown task types map to 8.
TT2IDX = {"quick_answer":0,"coding":1,"research":2,"document_drafting":3,
           "legal_regulated":4,"tool_heavy":5,"retrieval_heavy":6,"long_horizon":7,"unknown_ambiguous":8}

# Tier number -> descriptor of the model serving that tier.
# ModelCascadeRouter.route reads only "model_id" from these entries; the
# cost_estimate it reports comes from the router's own tier_costs table, not
# from "cost_per_1k" here.
TIER_MODELS = {
    1: {"model_id": "tiny-local-3b", "provider": "local", "cost_per_1k": 0.0},
    2: {"model_id": "cheap-cloud-8b", "provider": "cloud", "cost_per_1k": 0.05},
    3: {"model_id": "medium-70b", "provider": "cloud", "cost_per_1k": 0.30},
    4: {"model_id": "frontier-latest", "provider": "cloud", "cost_per_1k": 1.00},
    5: {"model_id": "specialist-expert", "provider": "cloud", "cost_per_1k": 1.50},
}

class ModelCascadeRouter:
    """Route a request to a model tier (1-5).

    The routing pipeline in :meth:`route` is:

    1. Start from ``difficulty + 1`` (capped at 5), raised to the per-task
       floor from ``task_floor``.
    2. Safety net: if the estimated success probability at that tier is below
       ``safety_threshold``, move one tier up.
    3. Cost saver: otherwise, if the next cheaper tier is still at or above
       the floor and its estimated success probability reaches
       ``downgrade_threshold``, move one tier down.

    Success probabilities come from optional per-tier classifiers and
    calibrators loaded from a pickle bundle; when those are unavailable a
    fixed heuristic is used instead.
    """

    def __init__(self, model_path: Optional[str] = None, safety_threshold: float = 0.30,
                 downgrade_threshold: float = 0.90,
                 task_floor: Optional[Dict[str, int]] = None,
                 tier_costs: Optional[Dict[int, float]] = None):
        """Configure thresholds and tables, then try to load the model bundle.

        Args:
            model_path: Path to a pickle bundle with the per-tier classifiers.
                ``None`` or a non-existent path leaves the router in
                heuristic-only mode.
            safety_threshold: Success probability below which the router
                escalates one tier.
            downgrade_threshold: Success probability the cheaper tier must
                reach before the router downgrades to it.
            task_floor: Minimum tier per task type. A falsy value (``None`` or
                an empty dict) selects the built-in table.
            tier_costs: Cost figure per tier, reported as ``cost_estimate``.
                A falsy value selects the built-in table.
        """
        self.safety_threshold = safety_threshold
        self.downgrade_threshold = downgrade_threshold
        self.task_floor = task_floor or {
            "legal_regulated": 4, "long_horizon": 3, "research": 3, "coding": 3,
            "unknown_ambiguous": 3, "quick_answer": 1, "document_drafting": 2,
            "tool_heavy": 2, "retrieval_heavy": 2,
        }
        self.tier_costs = tier_costs or {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
        self.tier_clfs = None
        self.tier_calibs = None
        self.feat_keys = None
        self._load_model(model_path)

    def _load_model(self, model_path: Optional[str] = None):
        """Load classifiers, calibrators and feature order from a pickle bundle.

        The bundle is expected to be a mapping with the keys ``"tier_clfs"``,
        ``"tier_calibrators"`` and ``"feat_keys"``; tier keys are coerced to
        ``int``. Loading is best-effort: any failure prints a warning and
        leaves the router state untouched, so the router never ends up with a
        partially loaded bundle.

        Args:
            model_path: Path to the bundle; ignored when falsy or missing.
        """
        if not model_path or not os.path.exists(model_path):
            return
        try:
            # SECURITY: unpickling can execute arbitrary code. model_path must
            # only ever point at a trusted, locally produced file.
            with open(model_path, "rb") as fh:
                bundle = pickle.load(fh)
            tier_clfs = {int(k): v for k, v in bundle.get("tier_clfs", {}).items()}
            tier_calibs = {int(k): v for k, v in bundle.get("tier_calibrators", {}).items()}
            feat_keys = bundle.get("feat_keys", None)
        except Exception as e:
            print(f"[ACO] Warning: Could not load router model: {e}")
            return
        # Commit only after the whole bundle parsed successfully.
        self.tier_clfs = tier_clfs
        self.tier_calibs = tier_calibs
        self.feat_keys = feat_keys

    def estimate_difficulty(self, request: str, task_type: str) -> int:
        """Estimate request difficulty on a 1-5 scale.

        Starts from a per-task-type base (3 for unknown task types), adds one
        if any ``CRITICAL_KW`` entry occurs in the request and subtracts one
        if any ``SIMPLE_KW`` entry occurs; both adjustments can apply to the
        same request. The result is clamped to the 1-5 range.

        Args:
            request: Raw request text; matched case-insensitively.
            task_type: Task-type name.

        Returns:
            The estimated difficulty.
        """
        r = request.lower()
        base = {"quick_answer": 1, "document_drafting": 2, "tool_heavy": 2, "retrieval_heavy": 2,
                "research": 3, "coding": 3, "unknown_ambiguous": 3, "long_horizon": 4,
                "legal_regulated": 5}.get(task_type, 3)
        # NOTE(review): these are substring tests, so e.g. "now" also matches
        # inside "know" and "live" inside "deliver".
        if any(k in r for k in CRITICAL_KW):
            base = min(base + 1, 5)
        if any(k in r for k in SIMPLE_KW):
            base = max(base - 1, 1)
        return base

    def _extract_features(self, request: str, task_type: str, difficulty: int) -> np.ndarray:
        """Build the ``(1, n_features)`` float32 feature row for the classifiers.

        Features are length/word counts, keyword presence and hit counts per
        keyword list, the task-type index, the difficulty, and one
        ``tt_<name>`` one-hot flag per ``TT2IDX`` entry. Columns follow
        ``self.feat_keys``; keys that are not computed here are filled with 0.

        Args:
            request: Raw request text.
            task_type: Task-type name; unknown names get ``tt_idx`` 8.
            difficulty: Difficulty value to expose as a feature.

        Returns:
            A ``(1, len(self.feat_keys))`` array, or an all-zero ``(1, 23)``
            array when no feature order has been loaded.
        """
        r = request.lower()

        def hits(keywords):
            # Number of keywords from the list that occur in the request.
            return sum(1 for k in keywords if k in r)

        n_code = hits(CODE_KW)
        n_legal = hits(LEGAL_KW)
        n_research = hits(RESEARCH_KW)
        n_tool = hits(TOOL_KW)
        feats = {
            "req_len": len(request), "num_words": len(request.split()),
            "has_code": int(n_code > 0),
            "n_code": n_code,
            "has_legal": int(n_legal > 0),
            "n_legal": n_legal,
            "has_research": int(n_research > 0),
            "n_research": n_research,
            "has_tool": int(n_tool > 0),
            "n_tool": n_tool,
            "has_long": int(any(k in r for k in LONG_KW)),
            "has_math": int(any(k in r for k in MATH_KW)),
            "tt_idx": TT2IDX.get(task_type, 8),
            "difficulty": difficulty,
        }
        for tt in TT2IDX:
            feats[f"tt_{tt}"] = int(task_type == tt)
        if self.feat_keys:
            row = [float(feats.get(k, 0.0)) for k in self.feat_keys]
            return np.array(row, dtype=np.float32).reshape(1, -1)
        # NOTE(review): without a feature order the computed features are
        # discarded and a zero row is returned.
        return np.zeros((1, 23), dtype=np.float32)

    def _get_psuccess(self, x: np.ndarray, tier: int) -> float:
        """Estimate the probability that ``tier`` handles the request.

        Uses the tier's classifier followed by its calibrator when both are
        loaded. If either is missing, or prediction raises an ordinary
        exception, falls back to ``strength[tier] ** (difficulty * 0.6)``,
        where difficulty is read from the ``"difficulty"`` column of ``x``
        (3 when the feature order is unknown or lacks that column).

        Args:
            x: Feature row as produced by :meth:`_extract_features`.
            tier: Tier number to evaluate.

        Returns:
            The estimated success probability.
        """
        if self.tier_clfs and tier in self.tier_clfs and self.tier_calibs and tier in self.tier_calibs:
            try:
                p_raw = self.tier_clfs[tier].predict_proba(x)[0, 1]
                return float(self.tier_calibs[tier].transform([p_raw])[0])
            except Exception:
                # Best-effort: a failing model must not break routing, so use
                # the heuristic below. Only Exception is caught so that
                # KeyboardInterrupt/SystemExit still propagate.
                pass
        # Fallback heuristic probability
        strengths = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
        if self.feat_keys and "difficulty" in self.feat_keys:
            diff_feat = float(x[0, self.feat_keys.index("difficulty")])
        else:
            diff_feat = 3
        return strengths.get(tier, 0.80) ** (diff_feat * 0.6)

    def route(self, request: str, task_type: str, difficulty: Optional[int] = None,
              prediction: Optional[dict] = None) -> "RoutingDecision":
        """Choose the model tier for a request.

        Args:
            request: Raw request text.
            task_type: Task-type name; unknown names use a floor of 2.
            difficulty: Difficulty to route on; estimated via
                :meth:`estimate_difficulty` when ``None``.
            prediction: Accepted for interface compatibility; not used here.

        Returns:
            A ``RoutingDecision`` describing the selected tier, its estimated
            success probability and the steps taken to reach it.
        """
        if difficulty is None:
            difficulty = self.estimate_difficulty(request, task_type)
        floor = self.task_floor.get(task_type, 2)
        base = max(min(difficulty + 1, 5), floor)
        x = self._extract_features(request, task_type, difficulty)
        tier = base
        ps = self._get_psuccess(x, tier)
        escalated = False
        downgraded = False
        # Safety net: a single one-tier escalation when success looks unlikely.
        if ps < self.safety_threshold and tier < 5:
            tier += 1
            ps = self._get_psuccess(x, tier)
            escalated = True
        # Cost saver: only considered when no escalation happened (so tier is
        # still the base tier) and there is room above the floor.
        if tier > floor and not escalated and tier == base:
            cheaper = tier - 1
            pc = self._get_psuccess(x, cheaper)
            if pc >= self.downgrade_threshold and cheaper >= floor:
                tier = cheaper
                ps = pc
                downgraded = True
        # Tiers without a model entry fall back to the tier-4 model.
        model_info = TIER_MODELS.get(tier, TIER_MODELS[4])
        reasoning_parts = [f"base_tier={base}"]
        if escalated:
            reasoning_parts.append(f"escalated(P(success@{base})<{self.safety_threshold})")
        if downgraded:
            reasoning_parts.append(f"downgraded(P(success@{cheaper})>={self.downgrade_threshold})")
        return RoutingDecision(
            model_id=model_info["model_id"],
            tier=tier,
            confidence=ps,
            reasoning="; ".join(reasoning_parts),
            cost_estimate=self.tier_costs.get(tier, 1.0),
            dynamic_difficulty=difficulty,
            escalated=escalated,
            downgraded=downgraded,
        )