narcolepticchicken commited on
Commit
8191bbb
·
verified ·
1 Parent(s): b23607e

Upload aco/router_v10.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/router_v10.py +151 -0
aco/router_v10.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ACO v10 Router: Trained on REAL SWE-Router execution data.
2
+
3
+ Key difference from v8: Uses XGBoost models trained on 500 real
4
+ SWE-bench tasks across 8 models, not synthetic data.
5
+
6
+ Routes based on problem-statement features → per-tier P(success) →
7
+ optimal tier selection. Supports cascade + feedback escalation.
8
+ """
9
+ import numpy as np
10
+ import pickle, os, json
11
+ from typing import Dict, Optional, Tuple
12
+ from dataclasses import dataclass
13
+
14
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor",
15
+ "implement","test","compile","runtime","segfault","thread","async","class",
16
+ "module","import","error","traceback"]
17
+ CRITICAL_KW = ["critical","production","urgent","emergency","live","deployed",
18
+ "safety","security"]
19
+ SIMPLE_KW = ["typo","simple","quick","brief","minor","small","easy","trivial","just"]
20
+ RESEARCH_KW = ["research","investigate","compare","analyze","survey","paper"]
21
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
22
+ LONG_KW = ["plan","project","roadmap","orchestrate","migrate","pipeline","deploy","architecture"]
23
+
24
+ FEAT_KEYS = sorted([
25
+ 'req_len','num_words','has_code','n_code','has_legal','has_research',
26
+ 'has_tool','has_critical','has_simple','has_long','has_math',
27
+ 'has_error_msg','has_file_path','n_lines','has_version','has_add',
28
+ 'has_fix','has_change','has_remove','has_test','has_doc',
29
+ 'has_see_also','has_steps_to_reproduce',
30
+ ])
31
+
32
+ TIER_TO_MODEL = {
33
+ 1: 'deepseek-v4-flash', 2: 'gpt-5-mini',
34
+ 3: 'gemini-2.5-pro', 4: 'claude-opus-4.7', 5: 'gemini-3-pro',
35
+ }
36
+
37
+ TIER_COST = {1:0.01, 2:0.05, 3:0.15, 4:0.30, 5:0.50}
38
+
39
+ @dataclass
40
+ class V10RoutingDecision:
41
+ tier: int
42
+ model: str
43
+ confidence: float
44
+ cost_estimate: float
45
+ tier_probs: Dict[int, float]
46
+ escalated: bool = False
47
+
48
+ class V10Router:
49
+ def __init__(self, model_path: str = None, success_threshold: float = 0.5):
50
+ self.success_threshold = success_threshold
51
+ self.tier_clfs = None
52
+ self.tier_calibs = None
53
+ self.opt_clf = None
54
+ self.feat_keys = FEAT_KEYS
55
+ if model_path and os.path.exists(model_path):
56
+ self._load(model_path)
57
+
58
+ def _load(self, path):
59
+ bundle = pickle.load(open(path, 'rb'))
60
+ self.tier_clfs = {int(k):v for k,v in bundle.get('tier_clfs',{}).items()}
61
+ self.tier_calibs = {int(k):v for k,v in bundle.get('tier_calibrators',{}).items()}
62
+ self.opt_clf = bundle.get('opt_clf', None)
63
+ self.feat_keys = bundle.get('feat_keys', FEAT_KEYS)
64
+
65
+ def _extract(self, text: str) -> np.ndarray:
66
+ r = text.lower()
67
+ feats = {
68
+ 'req_len': len(text), 'num_words': len(text.split()),
69
+ 'has_code': int(any(k in r for k in CODE_KW)),
70
+ 'n_code': sum(1 for k in CODE_KW if k in r),
71
+ 'has_legal': int(any(k in r for k in ["contract","legal","compliance"])),
72
+ 'has_research': int(any(k in r for k in RESEARCH_KW)),
73
+ 'has_tool': int(any(k in r for k in TOOL_KW)),
74
+ 'has_critical': int(any(k in r for k in CRITICAL_KW)),
75
+ 'has_simple': int(any(k in r for k in SIMPLE_KW)),
76
+ 'has_long': int(any(k in r for k in LONG_KW)),
77
+ 'has_math': int(any(k in r for k in ["calculate","compute","solve","equation"])),
78
+ 'has_error_msg': int('error' in r or 'traceback' in r or 'exception' in r),
79
+ 'has_file_path': int('/' in r),
80
+ 'n_lines': text.count('\n') + 1,
81
+ 'has_version': int('version' in r or 'update' in r),
82
+ 'has_add': int('add' in r or 'new' in r or 'create' in r),
83
+ 'has_fix': int('fix' in r or 'bug' in r or 'issue' in r),
84
+ 'has_change': int('change' in r or 'modify' in r),
85
+ 'has_remove': int('remove' in r or 'delete' in r),
86
+ 'has_test': int('test' in r or 'spec' in r),
87
+ 'has_doc': int('doc' in r or 'readme' in r),
88
+ 'has_see_also': int('see also' in r or 'related' in r),
89
+ 'has_steps_to_reproduce': int('reproduce' in r or 'steps' in r),
90
+ }
91
+ return np.array([float(feats.get(k,0.0)) for k in self.feat_keys], dtype=np.float32).reshape(1,-1)
92
+
93
+ def route_cascade(self, text: str) -> V10RoutingDecision:
94
+ """Route to cheapest tier with P(success) >= threshold."""
95
+ x = self._extract(text)
96
+ tier_probs = {}
97
+ if self.tier_clfs:
98
+ for t in range(1, 6):
99
+ if t in self.tier_clfs:
100
+ p_raw = self.tier_clfs[t].predict_proba(x)[0,1]
101
+ p_cal = float(self.tier_calibs[t].transform([p_raw])[0])
102
+ tier_probs[t] = p_cal
103
+ else:
104
+ tier_probs[t] = 0.5
105
+ else:
106
+ tier_probs = {1:0.67,2:0.72,3:0.50,4:0.84,5:0.70}
107
+
108
+ # Find cheapest tier above threshold
109
+ selected_tier = 5
110
+ for t in range(1, 6):
111
+ if tier_probs.get(t, 0) >= self.success_threshold:
112
+ selected_tier = t
113
+ break
114
+
115
+ model = TIER_TO_MODEL.get(selected_tier, 'claude-opus-4.7')
116
+ return V10RoutingDecision(
117
+ tier=selected_tier, model=model,
118
+ confidence=tier_probs.get(selected_tier, 0.5),
119
+ cost_estimate=TIER_COST.get(selected_tier, 0.30),
120
+ tier_probs=tier_probs,
121
+ )
122
+
123
+ def route_direct(self, text: str) -> V10RoutingDecision:
124
+ """Predict optimal tier directly."""
125
+ x = self._extract(text)
126
+ if self.opt_clf:
127
+ tier = int(self.opt_clf.predict(x)[0]) + 1
128
+ else:
129
+ tier = 4 # fallback
130
+ model = TIER_TO_MODEL.get(tier, 'claude-opus-4.7')
131
+ return V10RoutingDecision(
132
+ tier=tier, model=model,
133
+ confidence=0.8, cost_estimate=TIER_COST.get(tier, 0.30),
134
+ tier_probs={},
135
+ )
136
+
137
+ def route_with_feedback(self, text: str, initial_success: bool = True) -> V10RoutingDecision:
138
+ """Route with feedback: start cheap, escalate on failure."""
139
+ initial = self.route_cascade(text)
140
+ if initial_success:
141
+ return initial
142
+ # Escalate
143
+ escalated_tier = min(initial.tier + 1, 5)
144
+ model = TIER_TO_MODEL.get(escalated_tier, 'claude-opus-4.7')
145
+ return V10RoutingDecision(
146
+ tier=escalated_tier, model=model,
147
+ confidence=initial.tier_probs.get(escalated_tier, 0.8),
148
+ cost_estimate=TIER_COST.get(initial.tier, 0.01) + TIER_COST.get(escalated_tier, 0.30),
149
+ tier_probs=initial.tier_probs,
150
+ escalated=True,
151
+ )