narcolepticchicken commited on
Commit
b562f4c
Β·
verified Β·
1 Parent(s): 2912ad5

Upload training/train_router_v3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_router_v3.py +262 -0
training/train_router_v3.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Trained Router v3: Combined best approach.
3
+ Uses per-tier P(success) classifiers + safety floors +
4
+ cost-aware routing with ASYMMETRIC penalties (underkill penalized 5x harder than overkill).
5
+ """
6
+ import json, os, sys, random, pickle, uuid
7
+ import numpy as np
8
+ from datetime import datetime
9
+ from collections import defaultdict
10
+ from typing import Dict, List, Tuple, Any, Optional
11
+
12
+ TASK_TYPES = ["quick_answer","coding","research","document_drafting",
13
+ "legal_regulated","tool_heavy","retrieval_heavy",
14
+ "long_horizon","unknown_ambiguous"]
15
+ TT2IDX = {t:i for i,t in enumerate(TASK_TYPES)}
16
+
17
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor",
18
+ "implement","test","compile","runtime","class","module","async","thread"]
19
+ LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability"]
20
+ RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey"]
21
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
22
+ LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy"]
23
+ MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability"]
24
+
25
+ TIER_STR = {1:0.35,2:0.55,3:0.80,4:0.93,5:0.97}
26
+ TIER_COST = {1:0.05,2:0.15,3:0.75,4:1.0,5:1.5}
27
+ TASK_FLOOR = {"legal_regulated":4,"long_horizon":3,"research":3,"coding":3,
28
+ "unknown_ambiguous":3,"quick_answer":1,"document_drafting":2,
29
+ "tool_heavy":2,"retrieval_heavy":2}
30
+
31
+ TASK_TEMPLATES = {
32
+ "quick_answer":["What is the capital of France?","Explain quantum computing briefly.",
33
+ "What is 237*452?","Define photosynthesis.","Who wrote Hamlet?",
34
+ "What is the speed of light?","List the primary colors.","What is GDP?"],
35
+ "coding":["Write a Python function to reverse a linked list.",
36
+ "Fix the bug in this React component.","Refactor auth module to JWT.",
37
+ "Implement LRU cache in Go.","Debug segfault in C++ thread pool.",
38
+ "Add unit tests for the payment module.","Optimize this SQL query.",
39
+ "Create a REST API for user management.","Implement binary search in Rust."],
40
+ "research":["Research latest transformer advances.",
41
+ "Find sources comparing LoRA and full FT.",
42
+ "Investigate data center climate impact.",
43
+ "Survey privacy-preserving ML techniques.",
44
+ "Compare reinforcement learning algorithms for robotics."],
45
+ "document_drafting":["Draft project proposal for ML pipeline.",
46
+ "Write email to team about deployment.","Create technical report on performance."],
47
+ "legal_regulated":["Review this contract for liability clauses.",
48
+ "Check GDPR compliance for data pipeline.","Draft privacy policy section.",
49
+ "Verify regulatory compliance for medical device software."],
50
+ "tool_heavy":["Search open issues and create summary.",
51
+ "Fetch API docs and generate client code.","Query Q3 sales and produce chart."],
52
+ "retrieval_heavy":["Answer based on 50-page document.",
53
+ "Find all payment processing mentions.","Retrieve relevant cases for legal query."],
54
+ "long_horizon":["Plan 3-month roadmap.","Orchestrate multi-region deployment.",
55
+ "Redesign data architecture end-to-end.","Migrate monolith to microservices."],
56
+ "unknown_ambiguous":["Help me with this thing.",
57
+ "I need something about the server.","Can you look into that issue?"],
58
+ }
59
+
60
+ def tsp(tier, diff):
61
+ return TIER_STR[tier] ** (diff * 0.6)
62
+
63
+ def extract_features(request, task_type, difficulty=3):
64
+ r = request.lower()
65
+ f = {
66
+ "req_len": len(request), "num_words": len(request.split()),
67
+ "has_code": int(any(k in r for k in CODE_KW)),
68
+ "n_code": sum(1 for k in CODE_KW if k in r),
69
+ "has_legal": int(any(k in r for k in LEGAL_KW)),
70
+ "n_legal": sum(1 for k in LEGAL_KW if k in r),
71
+ "has_research": int(any(k in r for k in RESEARCH_KW)),
72
+ "n_research": sum(1 for k in RESEARCH_KW if k in r),
73
+ "has_tool": int(any(k in r for k in TOOL_KW)),
74
+ "n_tool": sum(1 for k in TOOL_KW if k in r),
75
+ "has_long": int(any(k in r for k in LONG_KW)),
76
+ "has_math": int(any(k in r for k in MATH_KW)),
77
+ "tt_idx": TT2IDX.get(task_type, 8), "difficulty": difficulty,
78
+ }
79
+ for tt in TASK_TYPES:
80
+ f[f"tt_{tt}"] = int(task_type == tt)
81
+ return f
82
+
83
+ def gen_trace(idx, rng):
84
+ tt = rng.choice(list(TASK_TEMPLATES.keys()))
85
+ diff = {"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
86
+ "research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5}[tt]
87
+ tier_out = {}
88
+ for t in range(1,6):
89
+ tier_out[t] = rng.random() < tsp(t, diff)
90
+ opt = 5
91
+ for t in range(1,6):
92
+ if tier_out[t]: opt = t; break
93
+ if diff <= 2: actual = rng.choices([1,2,3,4,5],weights=[3,4,2,1,0.5])[0]
94
+ elif diff == 3: actual = rng.choices([1,2,3,4,5],weights=[1,2,4,2,1])[0]
95
+ elif diff == 4: actual = rng.choices([1,2,3,4,5],weights=[0.5,1,2,4,2])[0]
96
+ else: actual = rng.choices([1,2,3,4,5],weights=[0.2,0.5,1,3,4])[0]
97
+ outcome = "success" if tier_out[actual] else "failure"
98
+ req = rng.choice(TASK_TEMPLATES[tt])
99
+ feats = extract_features(req, tt, diff)
100
+ return {"feats":feats,"opt":opt,"actual":actual,"outcome":outcome,
101
+ "tier_out":tier_out,"tt":tt,"diff":diff,"req":req}
102
+
103
+ print("="*80)
104
+ print("AGENT COST OPTIMIZER - TRAINED ROUTER v3")
105
+ print("Asymmetric cost scoring: underkill 5x penalty")
106
+ print("="*80)
107
+
108
+ # ─── Generate ────────────────────────────────────────────────────────
109
+ print("\n[1] Generating 50K training traces...")
110
+ rng = random.Random(42)
111
+ traces = [gen_trace(i, rng) for i in range(50000)]
112
+ FEAT_KEYS = sorted(traces[0]["feats"].keys())
113
+ def f2v(feats):
114
+ return np.array([float(feats.get(k, 0.0)) for k in FEAT_KEYS], dtype=np.float32)
115
+
116
+ X_all = np.array([f2v(t["feats"]) for t in traces])
117
+ y_opt = np.array([t["opt"] for t in traces])
118
+
119
+ per_tier_labels = {}
120
+ for tier in range(1,6):
121
+ per_tier_labels[tier] = np.array([1 if t["tier_out"].get(tier,False) else 0 for t in traces])
122
+
123
+ from sklearn.model_selection import train_test_split
124
+ from sklearn.metrics import accuracy_score, f1_score
125
+ import xgboost as xgb
126
+
127
+ X_train, X_test, idx_train, idx_test = train_test_split(X_all, range(len(traces)), test_size=0.2, random_state=42, stratify=y_opt)
128
+ print(f" Train: {len(X_train)}, Test: {len(X_test)}")
129
+
130
+ # ─── Train Per-Tier Classifiers ─────────────────────────────────────
131
+ print("\n[2] Training per-tier P(success) classifiers...")
132
+ tier_clfs = {}
133
+ for tier in range(1,6):
134
+ y_tr = per_tier_labels[tier][idx_train]
135
+ neg = (y_tr==0).sum(); pos = (y_tr==1).sum()
136
+ spw = neg / max(pos,1)
137
+ clf = xgb.XGBClassifier(n_estimators=150, max_depth=5, learning_rate=0.1,
138
+ subsample=0.8, colsample_bytree=0.8, scale_pos_weight=min(spw,5.0),
139
+ objective="binary:logistic", eval_metric="logloss", random_state=42, verbosity=0)
140
+ clf.fit(X_train, y_tr)
141
+ y_pred = clf.predict(X_test)
142
+ acc = accuracy_score(per_tier_labels[tier][idx_test], y_pred)
143
+ f1 = f1_score(per_tier_labels[tier][idx_test], y_pred, zero_division=0)
144
+ tier_clfs[tier] = clf
145
+ print(f" Tier {tier}: acc={acc:.3f}, f1={f1:.3f}")
146
+
147
+ # ─── Asymmetric Cost Router ──────────────────────────────────────────
148
+ print("\n[3] Building asymmetric cost router...")
149
+
150
+ def route_asymmetric(x, task_type, tier_clfs, underkill_penalty=5.0, overkill_penalty=1.0):
151
+ """Score each tier with asymmetric penalties.
152
+
153
+ score(tier) = P(failure@tier) * underkill_penalty * cost_of_failure
154
+ + cost(tier) * overkill_penalty
155
+
156
+ Underkill (routing too low) is penalized 5x more than overkill.
157
+ """
158
+ if x.ndim == 1:
159
+ x = x.reshape(1, -1)
160
+ floor = TASK_FLOOR.get(task_type, 2)
161
+
162
+ best_tier = floor
163
+ best_score = float("inf")
164
+
165
+ for tier in range(floor, 6):
166
+ p_fail = 1.0 - tier_clfs[tier].predict_proba(x)[0, 1]
167
+ cost_norm = TIER_COST[tier] / TIER_COST[5] # [0.03, 1.0]
168
+
169
+ # Expected cost of failure (cheap model on hard task)
170
+ failure_cost = p_fail * underkill_penalty
171
+
172
+ # Cost of using this tier (overkill penalty)
173
+ tier_cost = cost_norm * overkill_penalty
174
+
175
+ score = failure_cost + tier_cost
176
+
177
+ if score < best_score:
178
+ best_score = score
179
+ best_tier = tier
180
+
181
+ return best_tier
182
+
183
+ # ─── Evaluate ─────────────────────────────────────────────────────────
184
+ print("\n[4] Evaluating all routers...")
185
+ n_test = len(idx_test)
186
+ results = {}
187
+
188
+ def eval_router(name, route_fn):
189
+ succ = 0; cost = 0.0; unsafe = 0; fd = 0
190
+ td = defaultdict(int)
191
+ for i in idx_test:
192
+ t = traces[i]
193
+ x = f2v(t["feats"]).reshape(1,-1)
194
+ pred = route_fn(x, t)
195
+ td[pred] += 1
196
+ if t["tier_out"].get(pred, False): succ += 1
197
+ elif pred < t["opt"]: unsafe += 1
198
+ elif pred >= t["opt"]: fd += 1
199
+ cost += TIER_COST[pred]
200
+ results[name] = {"success":succ/n_test, "avg_cost":cost/n_test,
201
+ "unsafe_rate":unsafe/n_test, "false_done":fd/n_test,
202
+ "tier_dist":dict(td)}
203
+
204
+ eval_router("always_frontier", lambda x,t: 4)
205
+ eval_router("always_cheap", lambda x,t: 1)
206
+ eval_router("heuristic_diff+1", lambda x,t: min(t["diff"]+1,5))
207
+ eval_router("heuristic_floor", lambda x,t: TASK_FLOOR.get(t["tt"],2))
208
+
209
+ for ukp in [3.0, 5.0, 8.0, 10.0, 15.0]:
210
+ eval_router(f"asymmetric_uk{ukp:.0f}", lambda x,t,uk=ukp: route_asymmetric(x, t["tt"], tier_clfs, underkill_penalty=uk))
211
+
212
+ eval_router("oracle", lambda x,t: t["opt"])
213
+
214
+ # Print comparison
215
+ print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
216
+ print("-"*75)
217
+ fc = results["always_frontier"]["avg_cost"]
218
+ for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
219
+ cr = (1 - r["avg_cost"]/fc)*100
220
+ print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")
221
+
222
+ # Pareto frontier
223
+ print("\nPARETO FRONTIER:")
224
+ pareto = []
225
+ for name, r in results.items():
226
+ if name == "always_cheap": continue
227
+ dominated = False
228
+ for name2, r2 in results.items():
229
+ if name == name2: continue
230
+ if r2["success"] >= r["success"] and r2["avg_cost"] <= r["avg_cost"]:
231
+ if r2["success"] > r["success"] or r2["avg_cost"] < r["avg_cost"]:
232
+ dominated = True; break
233
+ if not dominated:
234
+ pareto.append((name, r))
235
+ cr = (1 - r["avg_cost"]/fc)*100
236
+ print(f" {name:<25} success={r['success']:.3f} cost={r['avg_cost']:.4f} costRed={cr:.1f}% unsafe={r['unsafe_rate']:.3f}")
237
+
238
+ # ─── Save Best Model ─────────────────────────────────────────────────
239
+ print("\n[5] Saving models...")
240
+ os.makedirs("/app/router_models", exist_ok=True)
241
+ for tier, clf in tier_clfs.items():
242
+ clf.save_model(f"/app/router_models/tier_{tier}_success.json")
243
+ with open("/app/router_models/feat_keys.json","w") as f:
244
+ json.dump(FEAT_KEYS, f)
245
+ with open("/app/router_models/tier_config.json","w") as f:
246
+ json.dump({"tier_cost":TIER_COST,"tier_str":TIER_STR,"task_floor":TASK_FLOOR}, f, indent=2)
247
+ with open("/app/router_models/eval_results_v3.json","w") as f:
248
+ json.dump(results, f, indent=2, default=str)
249
+ print(f" Saved all models to /app/router_models/")
250
+
251
+ # Find best config
252
+ best_name = ""
253
+ best_score = -float("inf")
254
+ for name, r in results.items():
255
+ if name in ("oracle","always_cheap"): continue
256
+ # Composite: success*20 - cost*50 - unsafe*100
257
+ score = r["success"]*20 - r["avg_cost"]*50 - r["unsafe_rate"]*100
258
+ if score > best_score:
259
+ best_score = score
260
+ best_name = name
261
+ print(f"\n BEST CONFIG: {best_name} (composite score: {best_score:.2f})")
262
+ print(f"\nDONE!")