narcolepticchicken commited on
Commit
12ce4ea
Β·
verified Β·
1 Parent(s): e152f8c

Upload training/router_v8_final.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/router_v8_final.py +228 -0
training/router_v8_final.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Final Production Router v8: Dynamic difficulty + ML confirmation + safety floors.
3
+
4
+ This is the production router that replaces the heuristic in ACO.
5
+ """
6
+ import json, os, sys, random, uuid, pickle
7
+ import numpy as np
8
+ from collections import defaultdict
9
+
10
+ print("="*80)
11
+ print("ACO PRODUCTION ROUTER v8: DYNAMIC DIFFICULTY + ML")
12
+ print("="*80)
13
+
14
+ # ─── Load Models ──────────────────────────────────────────────────────
15
+ MODEL_DIR = "/app/router_models"
16
+ bundle = pickle.load(open(f"{MODEL_DIR}/router_bundle_v6.pkl", "rb"))
17
+ feat_keys = bundle["feat_keys"]
18
+ tier_clfs = {int(k):v for k,v in bundle["tier_clfs"].items()}
19
+ tier_calibs = {int(k):v for k,v in bundle["tier_calibrators"].items()}
20
+ TIER_COST = {int(k):v for k,v in bundle["tier_config"]["tier_cost"].items()}
21
+ TIER_STR = {int(k):v for k,v in bundle["tier_config"]["tier_str"].items()}
22
+ TASK_FLOOR = bundle["tier_config"]["task_floor"]
23
+
24
+ # ─── Feature Extraction ────────────────────────────────────────────────
25
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor","implement","test",
26
+ "compile","runtime","segfault","thread","async","class","module"]
27
+ LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability","indemnification","clause"]
28
+ RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey","paper","arxiv"]
29
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
30
+ LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy","architecture"]
31
+ MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability","integral"]
32
+ CRITICAL_KW = ["critical","production","urgent","now","emergency","live","deployed","safety","security"]
33
+ SIMPLE_KW = ["typo","simple","quick","brief","briefly","just","minor","small","easy","trivial","clarification"]
34
+ TT2IDX = {"quick_answer":0,"coding":1,"research":2,"document_drafting":3,
35
+ "legal_regulated":4,"tool_heavy":5,"retrieval_heavy":6,"long_horizon":7,"unknown_ambiguous":8}
36
+
37
+ def estimate_difficulty(request, task_type):
38
+ r = request.lower()
39
+ base = {"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
40
+ "research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5}[task_type]
41
+ if any(k in r for k in CRITICAL_KW): base = min(base + 1, 5)
42
+ if any(k in r for k in SIMPLE_KW): base = max(base - 1, 1)
43
+ return base
44
+
45
+ def extract_features(request, task_type, difficulty=3):
46
+ r = request.lower()
47
+ f = {"req_len":len(request),"num_words":len(request.split()),
48
+ "has_code":int(any(k in r for k in CODE_KW)),"n_code":sum(1 for k in CODE_KW if k in r),
49
+ "has_legal":int(any(k in r for k in LEGAL_KW)),"n_legal":sum(1 for k in LEGAL_KW if k in r),
50
+ "has_research":int(any(k in r for k in RESEARCH_KW)),"n_research":sum(1 for k in RESEARCH_KW if k in r),
51
+ "has_tool":int(any(k in r for k in TOOL_KW)),"n_tool":sum(1 for k in TOOL_KW if k in r),
52
+ "has_long":int(any(k in r for k in LONG_KW)),
53
+ "has_math":int(any(k in r for k in MATH_KW)),
54
+ "tt_idx":TT2IDX.get(task_type,8),"difficulty":difficulty}
55
+ for tt in TT2IDX:
56
+ f[f"tt_{tt}"] = int(task_type == tt)
57
+ return f
58
+
59
+ def f2v(feats):
60
+ return np.array([float(feats.get(k, 0.0)) for k in feat_keys], dtype=np.float32)
61
+
62
+ def get_calibrated_psuccess(x, tier):
63
+ p_raw = tier_clfs[tier].predict_proba(x)[0, 1]
64
+ return float(tier_calibs[tier].transform([p_raw])[0])
65
+
66
+ def route_production_v8(request, task_type, safety=0.30, downgrade=0.90):
67
+ diff = estimate_difficulty(request, task_type)
68
+ base = min(diff + 1, 5)
69
+ floor = TASK_FLOOR.get(task_type, 2)
70
+ base = max(base, floor)
71
+ feats = extract_features(request, task_type, diff)
72
+ x = f2v(feats).reshape(1, -1)
73
+ tier = base
74
+ ps = get_calibrated_psuccess(x, tier)
75
+ # Safety net
76
+ if ps < safety and tier < 5:
77
+ tier += 1
78
+ ps = get_calibrated_psuccess(x, tier)
79
+ # Cost saver
80
+ if tier > floor and tier == base:
81
+ cheaper = tier - 1
82
+ pc = get_calibrated_psuccess(x, cheaper)
83
+ if pc >= downgrade and cheaper >= floor:
84
+ tier = cheaper
85
+ ps = pc
86
+ return tier, ps, diff
87
+
88
+ # ─── Generate Eval Traces ────────────────────────────────────────────
89
+ TASK_TEMPLATES = {
90
+ "quick_answer":["What is the capital of France?","Explain quantum computing briefly.",
91
+ "What is 237*452?","Briefly explain photosynthesis.","Just tell me what 2+2 is.",
92
+ "Small clarification on this formula."],
93
+ "coding":["Write a Python function to reverse a linked list.",
94
+ "Fix the bug in this React component.","Refactor auth module to JWT.",
95
+ "Implement LRU cache in Go.","Debug segfault in C++ thread pool.",
96
+ "Fix a typo in the README.","Debug this critical production segfault NOW.",
97
+ "Just fix the typo in line 42."],
98
+ "research":["Research latest transformer advances.",
99
+ "Find sources comparing LoRA and full FT.",
100
+ "Investigate data center climate impact.",
101
+ "Find sources comparing LoRA and full FT briefly."],
102
+ "document_drafting":["Draft project proposal for ML pipeline.",
103
+ "Write email to team about deployment.","Create technical report on performance."],
104
+ "legal_regulated":["Review this contract for liability clauses.",
105
+ "Check GDPR compliance for data pipeline.","Draft privacy policy section.",
106
+ "Check GDPR compliance urgently."],
107
+ "tool_heavy":["Search open issues and create summary.",
108
+ "Fetch API docs and generate client code.","Query Q3 sales and produce chart."],
109
+ "retrieval_heavy":["Answer based on 50-page document.",
110
+ "Find all payment processing mentions."],
111
+ "long_horizon":["Plan 3-month roadmap.","Orchestrate multi-region deployment.",
112
+ "Redesign data architecture end-to-end.",
113
+ "Orchestrate complete multi-region deployment."],
114
+ "unknown_ambiguous":["Help me with this thing.",
115
+ "I need something about the server."],
116
+ }
117
+
118
+ def tsp(tier, diff):
119
+ return TIER_STR[tier] ** (diff * 0.6)
120
+
121
+ print("\n[1] Generating 2K eval traces...")
122
+ rng = random.Random(999)
123
+ traces = []
124
+ for i in range(2000):
125
+ tt = rng.choice(list(TASK_TEMPLATES.keys()))
126
+ # Use STATIC difficulty for ground truth (same as heuristic)
127
+ static_diff = {"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
128
+ "research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5}[tt]
129
+ req = rng.choice(TASK_TEMPLATES[tt])
130
+ # Dynamic difficulty from request text
131
+ dyn_diff = estimate_difficulty(req, tt)
132
+
133
+ tier_out = {t: rng.random() < tsp(t, dyn_diff) for t in range(1,6)}
134
+ opt = 5
135
+ for t in range(1,6):
136
+ if tier_out[t]: opt = t; break
137
+ traces.append({"tt":tt,"static_diff":static_diff,"dyn_diff":dyn_diff,
138
+ "opt":opt,"tier_out":tier_out,"req":req})
139
+
140
+ print(f" Generated {len(traces)} traces")
141
+
142
+ # ─── Evaluate ──────────────────────────────────────────────────────────
143
+ print("\n[2] Evaluating all routers...")
144
+ n = len(traces)
145
+
146
+ def eval_router(name, route_fn):
147
+ succ=0; cost=0.0; unsafe=0; fd=0; td=defaultdict(int)
148
+ for t in traces:
149
+ pred = route_fn(t)
150
+ td[pred] += 1
151
+ if t["tier_out"].get(pred, False): succ += 1
152
+ elif pred < t["opt"]: unsafe += 1
153
+ else: fd += 1
154
+ cost += TIER_COST[pred]
155
+ return {"success":succ/n,"avg_cost":cost/n,"unsafe_rate":unsafe/n,
156
+ "false_done":fd/n,"tier_dist":dict(td)}
157
+
158
+ results = {}
159
+ results["always_frontier"] = eval_router("always_frontier", lambda t: 4)
160
+ results["always_cheap"] = eval_router("always_cheap", lambda t: 1)
161
+ results["heuristic_static"] = eval_router("heuristic_static",
162
+ lambda t: max(min(t["static_diff"]+1,5), TASK_FLOOR.get(t["tt"],2)))
163
+ results["oracle"] = eval_router("oracle", lambda t: t["opt"])
164
+
165
+ # v8 production router
166
+ results["v8_dynamic+ML"] = eval_router("v8_dynamic+ML",
167
+ lambda t: route_production_v8(t["req"], t["tt"])[0])
168
+
169
+ # v8 without ML (just dynamic difficulty)
170
+ results["v8_dynamic_only"] = eval_router("v8_dynamic_only",
171
+ lambda t: max(min(t["dyn_diff"]+1,5), TASK_FLOOR.get(t["tt"],2)))
172
+
173
+ # Print
174
+ print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
175
+ print("-"*75)
176
+ fc = results["always_frontier"]["avg_cost"]
177
+ for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
178
+ cr = (1-r["avg_cost"]/fc)*100
179
+ print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")
180
+
181
+ # Per-task breakdown
182
+ print(f"\n\n[3] Per-task breakdown...")
183
+ for tt in sorted(set(t["tt"] for t in traces)):
184
+ tt_r = [t for t in traces if t["tt"] == tt]
185
+ n_tt = len(tt_r)
186
+ print(f"\n {tt} (n={n_tt}):")
187
+ for rname, rfn in [("frontier", lambda t:4),
188
+ ("heuristic", lambda t:max(min(t["static_diff"]+1,5),TASK_FLOOR.get(t["tt"],2))),
189
+ ("v8_dynamic", lambda t:max(min(t["dyn_diff"]+1,5),TASK_FLOOR.get(t["tt"],2))),
190
+ ("v8_full", lambda t:route_production_v8(t["req"],t["tt"])[0]),
191
+ ("oracle", lambda t:t["opt"])]:
192
+ succ = sum(1 for t in tt_r if t["tier_out"].get(rfn(t), False))
193
+ cost = sum(TIER_COST[rfn(t)] for t in tt_r)
194
+ sr = succ/n_tt; ac = cost/n_tt
195
+ cr = (1-ac/fc)*100
196
+ print(f" {rname:<14} success={sr:.3f} cost={ac:.4f} costRed={cr:.1f}%")
197
+
198
+ # Save
199
+ with open("/app/router_models/v8_final_results.json","w") as f:
200
+ json.dump(results, f, indent=2, default=str)
201
+
202
+ # Save v8 bundle
203
+ v8_bundle = {
204
+ "tier_clfs": {str(k):v for k,v in tier_clfs.items()},
205
+ "tier_calibrators": {str(k):v for k,v in tier_calibs.items()},
206
+ "feat_keys": feat_keys,
207
+ "tier_config": {str(k):v for k,v in TIER_COST.items()},
208
+ "task_floor": TASK_FLOOR,
209
+ "version": "8.0",
210
+ "description": "ACO Production Router v8: dynamic difficulty + ML confirmation + safety floors",
211
+ "dynamic_difficulty": True,
212
+ "critical_keywords": CRITICAL_KW,
213
+ "simple_keywords": SIMPLE_KW,
214
+ }
215
+ with open("/app/router_models/router_bundle_v8.pkl","wb") as f:
216
+ pickle.dump(v8_bundle, f)
217
+
218
+ print(f"\n\n{'='*80}")
219
+ print("FINAL v8 RESULTS")
220
+ print(f"{'='*80}")
221
+ print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10}")
222
+ print("-"*65)
223
+ for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
224
+ cr = (1-r["avg_cost"]/fc)*100
225
+ print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f}")
226
+
227
+ print(f"\nSaved router_bundle_v8.pkl ({os.path.getsize('/app/router_models/router_bundle_v8.pkl')/1024:.0f} KB)")
228
+ print(f"DONE!")