narcolepticchicken commited on
Commit
f654d83
Β·
verified Β·
1 Parent(s): b562f4c

Upload training/train_router_v4_production.py with huggingface_hub

Browse files
training/train_router_v4_production.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Production Trained Router v4: Difficulty-first + ML refinement.
3
+
4
+ Architecture:
5
+ 1. Predict difficulty (task_type -> difficulty mapping + ML difficulty classifier)
6
+ 2. Convert difficulty to base_tier (difficulty + 1, capped at 5)
7
+ 3. Apply safety floor per task_type
8
+ 4. Use per-tier P(success) classifiers to CONFIRM or ESCALATE
9
+ 5. If P(success@base_tier) < 0.65, escalate to next tier
10
+
11
+ This combines the best of both worlds:
12
+ - Difficulty mapping (heuristic) is reliable and interpretable
13
+ - ML classifiers add a safety net: they catch cases where difficulty is underestimated
14
+ - Safety floors prevent dangerous under-routing on legal/critical tasks
15
+ """
16
+ import json, os, sys, random, uuid
17
+ import numpy as np
18
+ from datetime import datetime
19
+ from collections import defaultdict
20
+ from typing import Dict, List, Optional, Any
21
+
22
+ import xgboost as xgb
23
+
24
+ # ─── Load Models ──────────────────────────────────────────────────────
25
+ MODEL_DIR = "/app/router_models"
26
+ feat_keys = json.load(open(f"{MODEL_DIR}/feat_keys.json"))
27
+ tier_config = json.load(open(f"{MODEL_DIR}/tier_config.json"))
28
+ TIER_COST = {int(k):v for k,v in tier_config["tier_cost"].items()}
29
+ TIER_STR = {int(k):v for k,v in tier_config["tier_str"].items()}
30
+ TASK_FLOOR = tier_config["task_floor"]
31
+
32
+ tier_clfs = {}
33
+ for tier in range(1, 6):
34
+ clf = xgb.XGBClassifier()
35
+ clf.load_model(f"{MODEL_DIR}/tier_{tier}_success.json")
36
+ tier_clfs[tier] = clf
37
+
38
+ # ─── Feature Extraction ────────────────────────────────────────────────
39
+ TASK_TYPES = ["quick_answer","coding","research","document_drafting",
40
+ "legal_regulated","tool_heavy","retrieval_heavy",
41
+ "long_horizon","unknown_ambiguous"]
42
+ TT2IDX = {t:i for i,t in enumerate(TASK_TYPES)}
43
+
44
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor",
45
+ "implement","test","compile","runtime","class","module","async","thread"]
46
+ LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability"]
47
+ RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey"]
48
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
49
+ LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy"]
50
+ MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability"]
51
+
52
+ def extract_features(request, task_type, difficulty=3):
53
+ r = request.lower()
54
+ f = {"req_len":len(request),"num_words":len(request.split()),
55
+ "has_code":int(any(k in r for k in CODE_KW)),
56
+ "n_code":sum(1 for k in CODE_KW if k in r),
57
+ "has_legal":int(any(k in r for k in LEGAL_KW)),
58
+ "n_legal":sum(1 for k in LEGAL_KW if k in r),
59
+ "has_research":int(any(k in r for k in RESEARCH_KW)),
60
+ "n_research":sum(1 for k in RESEARCH_KW if k in r),
61
+ "has_tool":int(any(k in r for k in TOOL_KW)),
62
+ "n_tool":sum(1 for k in TOOL_KW if k in r),
63
+ "has_long":int(any(k in r for k in LONG_KW)),
64
+ "has_math":int(any(k in r for k in MATH_KW)),
65
+ "tt_idx":TT2IDX.get(task_type,8),"difficulty":difficulty}
66
+ for tt in TASK_TYPES:
67
+ f[f"tt_{tt}"] = int(task_type == tt)
68
+ return f
69
+
70
+ def f2v(feats):
71
+ return np.array([float(feats.get(k,0.0)) for k in feat_keys], dtype=np.float32)
72
+
73
+ # ─── Production Router ─────────────────────────────────────────────────
74
+ def route_production(request, task_type, difficulty, escalation_threshold=0.65):
75
+ """Production router: difficulty-first + ML confirmation.
76
+
77
+ Step 1: difficulty -> base_tier (difficulty + 1, capped at 5)
78
+ Step 2: Apply safety floor (task_type -> minimum tier)
79
+ Step 3: base_tier = max(base_tier, safety_floor)
80
+ Step 4: If P(success@base_tier) < escalation_threshold, escalate one tier
81
+ Step 5: Never exceed tier 5
82
+ """
83
+ base_tier = min(difficulty + 1, 5)
84
+ floor = TASK_FLOOR.get(task_type, 2)
85
+ base_tier = max(base_tier, floor)
86
+
87
+ # ML confirmation: check if base_tier is likely to succeed
88
+ feats = extract_features(request, task_type, difficulty)
89
+ x = f2v(feats).reshape(1, -1)
90
+
91
+ p_success = tier_clfs[base_tier].predict_proba(x)[0, 1]
92
+
93
+ # If P(success) at base_tier is too low, escalate
94
+ while p_success < escalation_threshold and base_tier < 5:
95
+ base_tier += 1
96
+ p_success = tier_clfs[base_tier].predict_proba(x)[0, 1]
97
+
98
+ return base_tier
99
+
100
+ # ─── Generate Evaluation ────────────────────────────────────────────────
101
+ TASK_TEMPLATES_EVAL = {
102
+ "quick_answer":["What is the capital of France?","Explain quantum computing briefly.",
103
+ "What is 237*452?","Define photosynthesis.","Who wrote Hamlet?"],
104
+ "coding":["Write a Python function to reverse a linked list.",
105
+ "Fix the bug in this React component.","Refactor auth module to JWT.",
106
+ "Implement LRU cache in Go.","Debug segfault in C++ thread pool."],
107
+ "research":["Research latest transformer advances.",
108
+ "Find sources comparing LoRA and full FT.",
109
+ "Investigate data center climate impact."],
110
+ "document_drafting":["Draft project proposal for ML pipeline.",
111
+ "Write email to team about deployment.","Create technical report on performance."],
112
+ "legal_regulated":["Review this contract for liability clauses.",
113
+ "Check GDPR compliance for data pipeline.","Draft privacy policy section."],
114
+ "tool_heavy":["Search open issues and create summary.",
115
+ "Fetch API docs and generate client code.","Query Q3 sales and produce chart."],
116
+ "retrieval_heavy":["Answer based on 50-page document.",
117
+ "Find all payment processing mentions.","Retrieve relevant cases for legal query."],
118
+ "long_horizon":["Plan 3-month roadmap.","Orchestrate multi-region deployment.",
119
+ "Redesign data architecture end-to-end."],
120
+ "unknown_ambiguous":["Help me with this thing.",
121
+ "I need something about the server.","Can you look into that issue?"],
122
+ }
123
+
124
+ def tsp(tier, diff):
125
+ s = {1:0.35,2:0.55,3:0.80,4:0.93,5:0.97}[tier]
126
+ return s ** (diff * 0.6)
127
+
128
+ print("="*80)
129
+ print("PRODUCTION TRAINED ROUTER v4 BENCHMARK")
130
+ print("="*80)
131
+ print("\nArchitecture: difficulty-first + ML confirmation + safety floors")
132
+
133
+ print("\n[1] Generating 2K eval traces (seed=999)...")
134
+ eval_rng = random.Random(999)
135
+ eval_traces = []
136
+ for i in range(2000):
137
+ tt = eval_rng.choice(list(TASK_TEMPLATES_EVAL.keys()))
138
+ diff = {"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
139
+ "research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5}[tt]
140
+ tier_out = {t: eval_rng.random() < tsp(t, diff) for t in range(1,6)}
141
+ opt = 5
142
+ for t in range(1,6):
143
+ if tier_out[t]: opt = t; break
144
+ req = eval_rng.choice(TASK_TEMPLATES_EVAL[tt])
145
+ eval_traces.append({"tt":tt,"diff":diff,"opt":opt,"tier_out":tier_out,"req":req})
146
+ print(f" Generated {len(eval_traces)} traces")
147
+
148
+ # ─── Evaluate ──────────────────────────────────────────────────────────
149
+ print("\n[2] Evaluating all routers...")
150
+
151
+ def eval_router(name, route_fn):
152
+ succ=0; cost=0.0; unsafe=0; fd=0; td=defaultdict(int)
153
+ for t in eval_traces:
154
+ pred = route_fn(t)
155
+ td[pred] += 1
156
+ if t["tier_out"].get(pred, False): succ += 1
157
+ elif pred < t["opt"]: unsafe += 1
158
+ else: fd += 1
159
+ cost += TIER_COST[pred]
160
+ n = len(eval_traces)
161
+ return {"success":succ/n, "avg_cost":cost/n, "unsafe_rate":unsafe/n,
162
+ "false_done":fd/n, "tier_dist":dict(td)}
163
+
164
+ # Baselines
165
+ results = {}
166
+ results["always_frontier"] = eval_router("always_frontier", lambda t: 4)
167
+ results["always_cheap"] = eval_router("always_cheap", lambda t: 1)
168
+ results["heuristic_diff+1"] = eval_router("heuristic_diff+1", lambda t: min(t["diff"]+1, 5))
169
+ results["heuristic_floor"] = eval_router("heuristic_floor", lambda t: TASK_FLOOR.get(t["tt"], 2))
170
+ results["oracle"] = eval_router("oracle", lambda t: t["opt"])
171
+
172
+ # Production router at different escalation thresholds
173
+ for threshold in [0.50, 0.55, 0.60, 0.65, 0.70, 0.75]:
174
+ name = f"prod_t{threshold:.2f}"
175
+ results[name] = eval_router(name,
176
+ lambda t, th=threshold: route_production(t["req"], t["tt"], t["diff"], escalation_threshold=th))
177
+
178
+ # Print comparison
179
+ print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
180
+ print("-"*75)
181
+ fc = results["always_frontier"]["avg_cost"]
182
+ for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
183
+ cr = (1 - r["avg_cost"]/fc)*100
184
+ print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")
185
+
186
+ # Find best production config
187
+ print("\n\n[3] Finding best production config...")
188
+ best_name = ""
189
+ best_score = -float("inf")
190
+ for name, r in results.items():
191
+ if name in ("oracle","always_cheap"): continue
192
+ # Composite: maximize success, minimize cost and unsafe
193
+ score = r["success"]*20 - r["avg_cost"]*30 - r["unsafe_rate"]*100
194
+ if score > best_score:
195
+ best_score = score
196
+ best_name = name
197
+
198
+ print(f" Best: {best_name} (composite: {best_score:.2f})")
199
+ print(f" Success: {results[best_name]['success']:.3f}")
200
+ print(f" Avg cost: {results[best_name]['avg_cost']:.4f}")
201
+ fc_val = results["always_frontier"]["avg_cost"]
202
+ cr = (1 - results[best_name]["avg_cost"]/fc_val)*100
203
+ print(f" Cost reduction vs frontier: {cr:.1f}%")
204
+ print(f" Unsafe rate: {results[best_name]['unsafe_rate']:.3f}")
205
+ print(f" False-DONE rate: {results[best_name]['false_done']:.3f}")
206
+
207
+ # Per-task breakdown for best
208
+ print(f"\n\n[4] Per-task breakdown for {best_name}...")
209
+ for tt in sorted(set(t["tt"] for t in eval_traces)):
210
+ tt_traces = [t for t in eval_traces if t["tt"] == tt]
211
+ n_tt = len(tt_traces)
212
+ if n_tt == 0: continue
213
+
214
+ for rname, rfn in [("frontier", lambda t:4),
215
+ ("heuristic", lambda t:min(t["diff"]+1,5)),
216
+ ("trained", lambda t:route_production(t["req"],t["tt"],t["diff"],
217
+ escalation_threshold=float(best_name.split("t")[1]))),
218
+ ("oracle", lambda t:t["opt"])]:
219
+ succ = sum(1 for t in tt_traces if t["tier_out"].get(rfn(t), False))
220
+ cost = sum(TIER_COST[rfn(t)] for t in tt_traces)
221
+ sr = succ/n_tt; ac = cost/n_tt
222
+ if rname == "frontier":
223
+ print(f"\n {tt} (n={n_tt}):")
224
+ cr = (1 - ac/fc_val)*100
225
+ print(f" {rname:<12} success={sr:.3f} cost={ac:.4f} costRed={cr:.1f}%")
226
+
227
+ # Pareto
228
+ print("\n\nPARETO FRONTIER:")
229
+ for name, r in results.items():
230
+ if name == "always_cheap": continue
231
+ dominated = False
232
+ for name2, r2 in results.items():
233
+ if name == name2: continue
234
+ if r2["success"] >= r["success"] and r2["avg_cost"] <= r["avg_cost"]:
235
+ if r2["success"] > r["success"] or r2["avg_cost"] < r["avg_cost"]:
236
+ dominated = True; break
237
+ if not dominated:
238
+ cr = (1 - r["avg_cost"]/fc_val)*100
239
+ print(f" {name:<25} success={r['success']:.3f} cost={r['avg_cost']:.4f} costRed={cr:.1f}% unsafe={r['unsafe_rate']:.3f}")
240
+
241
+ print(f"\nDONE!")