narcolepticchicken commited on
Commit
0ac9bb7
Β·
verified Β·
1 Parent(s): 1ba1443

Upload training/train_router_v1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_router_v1.py +433 -0
training/train_router_v1.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Train a learned model router for Agent Cost Optimizer."""
3
+ import json, os, sys, random, pickle, uuid
4
+ import numpy as np
5
+ from datetime import datetime
6
+ from collections import defaultdict
7
+ from typing import Dict, List, Tuple, Any, Optional
8
+
9
+ TASK_TYPES = ["quick_answer","coding","research","document_drafting",
10
+ "legal_regulated","tool_heavy","retrieval_heavy",
11
+ "long_horizon","unknown_ambiguous"]
12
+ TT2IDX = {t:i for i,t in enumerate(TASK_TYPES)}
13
+
14
+ CODE_KW = ["python","javascript","code","function","bug","debug","refactor",
15
+ "implement","test","compile","runtime","class","module","async","thread"]
16
+ LEGAL_KW = ["contract","legal","compliance","gdpr","privacy","policy","regulatory","liability"]
17
+ RESEARCH_KW = ["research","find sources","literature","investigate","compare","analyze","survey"]
18
+ TOOL_KW = ["search","fetch","retrieve","query","api","database","scrape","aggregate"]
19
+ LONG_KW = ["plan","project","roadmap","orchestrate","multi-step","migrate","pipeline","deploy"]
20
+ MATH_KW = ["calculate","compute","solve","equation","formula","optimize","probability"]
21
+
22
+ TIER_STR = {1:0.35,2:0.55,3:0.80,4:0.93,5:0.97}
23
+ TIER_COST = {1:0.05,2:0.15,3:0.75,4:1.0,5:1.5}
24
+
25
+ TASK_TEMPLATES = {
26
+ "quick_answer":["What is the capital of France?","Explain quantum computing briefly.",
27
+ "What is 237*452?","Define photosynthesis.","Who wrote Hamlet?",
28
+ "What is the speed of light?","List the primary colors.","What is GDP?"],
29
+ "coding":["Write a Python function to reverse a linked list.",
30
+ "Fix the bug in this React component.","Refactor auth module to JWT.",
31
+ "Implement LRU cache in Go.","Debug segfault in C++ thread pool.",
32
+ "Add unit tests for the payment module.","Optimize this SQL query.",
33
+ "Create a REST API for user management.","Implement binary search in Rust."],
34
+ "research":["Research latest transformer advances.",
35
+ "Find sources comparing LoRA and full FT.",
36
+ "Investigate data center climate impact.",
37
+ "Survey privacy-preserving ML techniques.",
38
+ "Compare reinforcement learning algorithms for robotics."],
39
+ "document_drafting":["Draft project proposal for ML pipeline.",
40
+ "Write email to team about deployment.","Create technical report on performance."],
41
+ "legal_regulated":["Review this contract for liability clauses.",
42
+ "Check GDPR compliance for data pipeline.","Draft privacy policy section.",
43
+ "Verify regulatory compliance for medical device software."],
44
+ "tool_heavy":["Search open issues and create summary.",
45
+ "Fetch API docs and generate client code.","Query Q3 sales and produce chart."],
46
+ "retrieval_heavy":["Answer based on 50-page document.",
47
+ "Find all payment processing mentions.","Retrieve relevant cases for legal query."],
48
+ "long_horizon":["Plan 3-month roadmap.","Orchestrate multi-region deployment.",
49
+ "Redesign data architecture end-to-end.","Migrate monolith to microservices."],
50
+ "unknown_ambiguous":["Help me with this thing.",
51
+ "I need something about the server.","Can you look into that issue?"],
52
+ }
53
+
54
+ def tsp(tier, diff):
55
+ return TIER_STR[tier] ** (diff * 0.6)
56
+
57
+ def extract_features(request, task_type, difficulty=3):
58
+ r = request.lower()
59
+ f = {
60
+ "req_len": len(request),
61
+ "num_words": len(request.split()),
62
+ "has_code": int(any(k in r for k in CODE_KW)),
63
+ "n_code": sum(1 for k in CODE_KW if k in r),
64
+ "has_legal": int(any(k in r for k in LEGAL_KW)),
65
+ "n_legal": sum(1 for k in LEGAL_KW if k in r),
66
+ "has_research": int(any(k in r for k in RESEARCH_KW)),
67
+ "n_research": sum(1 for k in RESEARCH_KW if k in r),
68
+ "has_tool": int(any(k in r for k in TOOL_KW)),
69
+ "n_tool": sum(1 for k in TOOL_KW if k in r),
70
+ "has_long": int(any(k in r for k in LONG_KW)),
71
+ "has_math": int(any(k in r for k in MATH_KW)),
72
+ "tt_idx": TT2IDX.get(task_type, 8),
73
+ "difficulty": difficulty,
74
+ }
75
+ for tt in TASK_TYPES:
76
+ f[f"tt_{tt}"] = int(task_type == tt)
77
+ return f
78
+
79
+ def gen_trace(idx, rng):
80
+ tt = rng.choice(list(TASK_TEMPLATES.keys()))
81
+ diff = {"quick_answer":1,"document_drafting":2,"tool_heavy":2,"retrieval_heavy":2,
82
+ "research":3,"coding":3,"unknown_ambiguous":3,"long_horizon":4,"legal_regulated":5}[tt]
83
+ tier_out = {}
84
+ for t in range(1,6):
85
+ tier_out[t] = rng.random() < tsp(t, diff)
86
+ opt = 5
87
+ for t in range(1,6):
88
+ if tier_out[t]:
89
+ opt = t
90
+ break
91
+ if diff <= 2:
92
+ actual = rng.choices([1,2,3,4,5],weights=[3,4,2,1,0.5])[0]
93
+ elif diff == 3:
94
+ actual = rng.choices([1,2,3,4,5],weights=[1,2,4,2,1])[0]
95
+ elif diff == 4:
96
+ actual = rng.choices([1,2,3,4,5],weights=[0.5,1,2,4,2])[0]
97
+ else:
98
+ actual = rng.choices([1,2,3,4,5],weights=[0.2,0.5,1,3,4])[0]
99
+ outcome = "success" if tier_out[actual] else "failure"
100
+ req = rng.choice(TASK_TEMPLATES[tt])
101
+ feats = extract_features(req, tt, diff)
102
+ return {"feats":feats,"opt":opt,"actual":actual,"outcome":outcome,
103
+ "tier_out":tier_out,"tt":tt,"diff":diff,"req":req}
104
+
105
+ print("="*80)
106
+ print("AGENT COST OPTIMIZER - TRAINED ROUTER TRAINING")
107
+ print("="*80)
108
+
109
+ # ─── Generate Training Data ────────────────────────────────────────
110
+ print("\n[1] Generating 50K training traces...")
111
+ rng = random.Random(42)
112
+ traces = [gen_trace(i, rng) for i in range(50000)]
113
+ print(f" Generated {len(traces)} traces")
114
+
115
+ opt_dist = defaultdict(int)
116
+ for t in traces:
117
+ opt_dist[t["opt"]] += 1
118
+ for k in sorted(opt_dist):
119
+ print(f" opt_tier={k}: {opt_dist[k]} ({opt_dist[k]/len(traces)*100:.1f}%)")
120
+
121
+ # ─── Build Feature Matrix ──────────────────────────────────────────
122
+ print("\n[2] Building feature matrix...")
123
+
124
+ def feats_to_vec(feats):
125
+ """Convert feature dict to fixed-order numpy array."""
126
+ keys = sorted(feats.keys())
127
+ return np.array([float(feats[k]) for k in keys], dtype=np.float32)
128
+
129
+ # Get feature key order from first trace
130
+ FEAT_KEYS = sorted(traces[0]["feats"].keys())
131
+ NUM_FEATURES = len(FEAT_KEYS)
132
+ print(f" Feature count: {NUM_FEATURES}")
133
+ print(f" Features: {FEAT_KEYS}")
134
+
135
+ def feats_to_vec_safe(feats):
136
+ return np.array([float(feats.get(k, 0.0)) for k in FEAT_KEYS], dtype=np.float32)
137
+
138
+ X_all = np.array([feats_to_vec_safe(t["feats"]) for t in traces])
139
+ y_opt = np.array([t["opt"] for t in traces])
140
+ y_actual = np.array([t["actual"] for t in traces])
141
+ y_outcome = np.array([1 if t["outcome"]=="success" else 0 for t in traces])
142
+
143
+ # ─── Per-Tier Success Classifiers ───────────────────────────────────
144
+ print("\n[3] Training per-tier P(success|query) classifiers...")
145
+ from sklearn.model_selection import train_test_split
146
+ from sklearn.metrics import classification_report, accuracy_score, f1_score
147
+
148
+ # For each tier, create binary label: did this tier succeed?
149
+ per_tier_labels = {}
150
+ for tier in range(1, 6):
151
+ labels = []
152
+ for t in traces:
153
+ labels.append(1 if t["tier_out"].get(tier, False) else 0)
154
+ per_tier_labels[tier] = np.array(labels)
155
+ succ_rate = per_tier_labels[tier].mean()
156
+ print(f" Tier {tier}: success rate = {succ_rate:.3f}")
157
+
158
+ # Split train/test
159
+ X_train, X_test, idx_train, idx_test = train_test_split(
160
+ X_all, range(len(traces)), test_size=0.2, random_state=42, stratify=y_opt
161
+ )
162
+
163
+ print(f"\n Train: {len(X_train)}, Test: {len(X_test)}")
164
+
165
+ # ─── XGBoost Per-Tier Classifiers ───────────────────────────────────
166
+ print("\n[4] Training XGBoost per-tier classifiers...")
167
+ import xgboost as xgb
168
+
169
+ tier_classifiers = {}
170
+ for tier in range(1, 6):
171
+ y_train_tier = per_tier_labels[tier][idx_train]
172
+ y_test_tier = per_tier_labels[tier][idx_test]
173
+
174
+ clf = xgb.XGBClassifier(
175
+ n_estimators=100,
176
+ max_depth=4,
177
+ learning_rate=0.1,
178
+ subsample=0.8,
179
+ colsample_bytree=0.8,
180
+ objective="binary:logistic",
181
+ eval_metric="logloss",
182
+ random_state=42,
183
+ verbosity=0,
184
+ )
185
+ clf.fit(X_train, y_train_tier)
186
+
187
+ y_pred = clf.predict(X_test)
188
+ y_prob = clf.predict_proba(X_test)[:, 1]
189
+
190
+ acc = accuracy_score(y_test_tier, y_pred)
191
+ f1 = f1_score(y_test_tier, y_pred, zero_division=0)
192
+
193
+ tier_classifiers[tier] = clf
194
+ print(f" Tier {tier}: accuracy={acc:.3f}, f1={f1:.3f}")
195
+
196
+ # ─── CARROT-Style Router Decision ────────────────────────────────────
197
+ print("\n[5] Building CARROT-style router...")
198
+
199
+ def route_carrot(features_vec, tier_clfs, mu=0.7):
200
+ """Route to argmin_tier [mu*(1-P_success_tier) + (1-mu)*cost_tier].
201
+
202
+ mu controls quality-vs-cost tradeoff:
203
+ mu=1.0: maximize quality only (always frontier)
204
+ mu=0.0: minimize cost only (always cheapest)
205
+ mu=0.7: 70% quality, 30% cost (our default)
206
+ """
207
+ if features_vec.ndim == 1:
208
+ features_vec = features_vec.reshape(1, -1)
209
+
210
+ best_tier = 3
211
+ best_score = float("inf")
212
+
213
+ for tier in range(1, 6):
214
+ p_success = tier_clfs[tier].predict_proba(features_vec)[0, 1]
215
+ cost_norm = TIER_COST[tier] / TIER_COST[5] # normalize to [0,1]
216
+ score = mu * (1.0 - p_success) + (1.0 - mu) * cost_norm
217
+ if score < best_score:
218
+ best_score = score
219
+ best_tier = tier
220
+
221
+ return best_tier
222
+
223
+ # Evaluate on test set
224
+ print("\n[6] Evaluating CARROT router on test set...")
225
+
226
+ mu_values = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
227
+
228
+ for mu in mu_values:
229
+ correct = 0
230
+ total_cost = 0.0
231
+ unsafe_misses = 0
232
+ for i in idx_test:
233
+ t = traces[i]
234
+ x = feats_to_vec_safe(t["feats"]).reshape(1, -1)
235
+ pred_tier = route_carrot(x, tier_classifiers, mu=mu)
236
+ opt_tier = t["opt"]
237
+
238
+ # Check if predicted tier would succeed
239
+ would_succeed = t["tier_out"].get(pred_tier, False)
240
+ if would_succeed:
241
+ correct += 1
242
+
243
+ # Cost of predicted tier
244
+ total_cost += TIER_COST[pred_tier]
245
+
246
+ # Unsafe miss: predicted cheap tier for hard task
247
+ if pred_tier < opt_tier and not would_succeed:
248
+ unsafe_misses += 1
249
+
250
+ n_test = len(idx_test)
251
+ success_rate = correct / n_test
252
+ avg_cost = total_cost / n_test
253
+ unsafe_rate = unsafe_misses / n_test
254
+
255
+ # Compare to heuristic (task-type based)
256
+ heuristic_correct = 0
257
+ heuristic_cost = 0.0
258
+ for i in idx_test:
259
+ t = traces[i]
260
+ # Heuristic: route by task type (from classifier.py)
261
+ tt = t["tt"]
262
+ diff = t["diff"]
263
+ h_tier = min(diff + 1, 5) # simple: difficulty + 1
264
+ if t["tier_out"].get(h_tier, False):
265
+ heuristic_correct += 1
266
+ heuristic_cost += TIER_COST[h_tier]
267
+
268
+ h_success = heuristic_correct / n_test
269
+ h_avg_cost = heuristic_cost / n_test
270
+
271
+ # Frontier baseline
272
+ frontier_correct = sum(1 for i in idx_test if traces[i]["tier_out"].get(4, False) or traces[i]["tier_out"].get(5, False))
273
+ frontier_rate = frontier_correct / n_test
274
+ frontier_avg_cost = TIER_COST[4] # always tier 4
275
+
276
+ print(f"\n mu={mu:.1f}:")
277
+ print(f" CARROT: success={success_rate:.3f}, avg_cost={avg_cost:.4f}, unsafe_miss={unsafe_rate:.3f}")
278
+ print(f" Heuristic: success={h_success:.3f}, avg_cost={h_avg_cost:.4f}")
279
+ print(f" Frontier: success={frontier_rate:.3f}, avg_cost={frontier_avg_cost:.4f}")
280
+ print(f" Cost reduction vs frontier: {(1-avg_cost/frontier_avg_cost)*100:.1f}%")
281
+ print(f" Cost reduction vs heuristic: {(1-avg_cost/h_avg_cost)*100:.1f}%")
282
+
283
+ # ─── XGBoost Direct Optimal-Tier Classifier ─────────────────────────
284
+ print("\n\n[7] Training XGBoost direct optimal-tier classifier...")
285
+
286
+ y_train_opt = y_opt[idx_train] - 1 # XGB needs 0-indexed
287
+ y_test_opt = y_opt[idx_test] - 1
288
+
289
+ direct_clf = xgb.XGBClassifier(
290
+ n_estimators=200,
291
+ max_depth=5,
292
+ learning_rate=0.1,
293
+ subsample=0.8,
294
+ colsample_bytree=0.8,
295
+ objective="multi:softmax",
296
+ num_class=6,
297
+ eval_metric="mlogloss",
298
+ random_state=42,
299
+ verbosity=0,
300
+ )
301
+ direct_clf.fit(X_train, y_train_opt)
302
+
303
+ y_pred_opt = direct_clf.predict(X_test)
304
+ acc_direct = accuracy_score(y_test_opt, y_pred_opt)
305
+ print(f" Direct classifier accuracy: {acc_direct:.3f}")
306
+
307
+ # Detailed classification report
308
+ from sklearn.metrics import confusion_matrix
309
+ cm = confusion_matrix(y_test_opt, y_pred_opt, labels=[1,2,3,4,5])
310
+ print(f"\n Confusion Matrix (rows=true, cols=predicted):")
311
+ print(f" {'':>10} {'T1':>6} {'T2':>6} {'T3':>6} {'T4':>6} {'T5':>6}")
312
+ for i, tier in enumerate([1,2,3,4,5]):
313
+ row = f" True T{tier:>1}:"
314
+ for j in range(5):
315
+ row += f" {cm[i][j]:>6}"
316
+ print(row)
317
+
318
+ # Evaluate direct classifier on test set
319
+ print("\n[8] Evaluating direct optimal-tier classifier...")
320
+ direct_correct = 0
321
+ direct_cost = 0.0
322
+ direct_unsafe = 0
323
+ direct_underkill = 0
324
+ direct_overkill = 0
325
+
326
+ for i, test_idx in enumerate(idx_test):
327
+ t = traces[test_idx]
328
+ x = feats_to_vec_safe(t["feats"]).reshape(1, -1)
329
+ pred_tier = int(direct_clf.predict(x)[0]) + 1 # back to 1-indexed
330
+ opt_tier = t["opt"]
331
+
332
+ would_succeed = t["tier_out"].get(pred_tier, False)
333
+ if would_succeed:
334
+ direct_correct += 1
335
+ direct_cost += TIER_COST[pred_tier]
336
+
337
+ if pred_tier < opt_tier:
338
+ direct_underkill += 1
339
+ if not would_succeed:
340
+ direct_unsafe += 1
341
+ elif pred_tier > opt_tier:
342
+ direct_overkill += 1
343
+
344
+ n_test = len(idx_test)
345
+ print(f" Success rate: {direct_correct/n_test:.3f}")
346
+ print(f" Avg cost: {direct_cost/n_test:.4f}")
347
+ print(f" Underkill (pred < optimal): {direct_underkill/n_test:.3f}")
348
+ print(f" Overkill (pred > optimal): {direct_overkill/n_test:.3f}")
349
+ print(f" Unsafe misses: {direct_unsafe/n_test:.3f}")
350
+
351
+ # ─── Comparison: All Routers ────────────────────────────────────────
352
+ print("\n\n" + "="*80)
353
+ print("FINAL COMPARISON: ALL ROUTERS ON TEST SET")
354
+ print("="*80)
355
+
356
+ # 1. Always frontier
357
+ f_succ = sum(1 for i in idx_test if traces[i]["tier_out"].get(4,False) or traces[i]["tier_out"].get(5,False))
358
+ f_cost = TIER_COST[4] * n_test
359
+
360
+ # 2. Always cheapest
361
+ c_succ = sum(1 for i in idx_test if traces[i]["tier_out"].get(1,False) or traces[i]["tier_out"].get(2,False))
362
+ c_cost = TIER_COST[1] * n_test
363
+
364
+ # 3. Heuristic (difficulty + 1)
365
+ h_succ = 0; h_cost = 0.0
366
+ for i in idx_test:
367
+ t = traces[i]
368
+ h_tier = min(t["diff"] + 1, 5)
369
+ if t["tier_out"].get(h_tier, False): h_succ += 1
370
+ h_cost += TIER_COST[h_tier]
371
+
372
+ # 4. CARROT (best mu)
373
+ best_mu = 0.7
374
+ carrot_succ = 0; carrot_cost = 0.0; carrot_unsafe = 0
375
+ for i in idx_test:
376
+ t = traces[i]
377
+ x = feats_to_vec_safe(t["feats"]).reshape(1, -1)
378
+ pred = route_carrot(x, tier_classifiers, mu=best_mu)
379
+ if t["tier_out"].get(pred, False): carrot_succ += 1
380
+ carrot_cost += TIER_COST[pred]
381
+ if pred < t["opt"] and not t["tier_out"].get(pred, False):
382
+ carrot_unsafe += 1
383
+
384
+ # 5. Direct XGB
385
+ d_succ = direct_correct
386
+ d_cost = direct_cost
387
+ d_unsafe = direct_unsafe
388
+
389
+ # 6. Oracle (always picks optimal)
390
+ o_succ = sum(1 for i in idx_test if traces[i]["tier_out"].get(traces[i]["opt"], False))
391
+ o_cost = sum(TIER_COST[traces[i]["opt"]] for i in idx_test)
392
+
393
+ print(f"\n{'Router':<20} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10}")
394
+ print("-"*60)
395
+ for name, succ, cost, unsafe in [
396
+ ("always_frontier", f_succ, f_cost, 0),
397
+ ("always_cheap", c_succ, c_cost, 0),
398
+ ("heuristic", h_succ, h_cost, 0),
399
+ (f"CARROT(mu={best_mu})", carrot_succ, carrot_cost, carrot_unsafe),
400
+ ("direct_xgb", d_succ, d_cost, d_unsafe),
401
+ ("oracle", o_succ, o_cost, 0),
402
+ ]:
403
+ sr = succ/n_test
404
+ ac = cost/n_test
405
+ cr = (1 - cost/f_cost)*100
406
+ um = unsafe/n_test
407
+ print(f"{name:<20} {sr:>10.3f} {ac:>10.4f} {cr:>9.1f}% {um:>10.3f}")
408
+
409
+ # ─── Save Models ────────────────────────────────────────────────────
410
+ print("\n\n[9] Saving models...")
411
+ os.makedirs("/app/router_models", exist_ok=True)
412
+
413
+ # Save per-tier classifiers
414
+ for tier, clf in tier_classifiers.items():
415
+ clf.save_model(f"/app/router_models/tier_{tier}_success.json")
416
+ print(f" Saved tier_{tier}_success.json")
417
+
418
+ # Save direct classifier
419
+ direct_clf.save_model("/app/router_models/direct_optimal_tier.json")
420
+ print(f" Saved direct_optimal_tier.json")
421
+
422
+ # Save feature keys
423
+ with open("/app/router_models/feat_keys.json", "w") as f:
424
+ json.dump(FEAT_KEYS, f)
425
+ print(f" Saved feat_keys.json ({len(FEAT_KEYS)} features)")
426
+
427
+ # Save tier config
428
+ with open("/app/router_models/tier_config.json", "w") as f:
429
+ json.dump({"tier_cost": TIER_COST, "tier_str": TIER_STR}, f)
430
+ print(f" Saved tier_config.json")
431
+
432
+ print("\n\nDONE! Models saved to /app/router_models/")
433
+ print("Next step: integrate trained router into ACO ModelCascadeRouter._route_learned()")