"""Trained Router v2: safety-first CARROT with a tuned mu and per-task safety floors.

Key insights from v1:
- Per-tier P(success) classifiers work well individually.
- CARROT routing with mu=0.6 beats the heuristic router on both quality and cost.
- But its success rate drops, because pure CARROT routes hard tasks to cheap tiers.

Solution: add SAFETY FLOORS per task type:
- legal_regulated: never below tier 4.
- coding / research / long_horizon: never below tier 3.
- Use P(success) > threshold as a gate, with a CARROT-score fallback above the floor.
- Ambiguous tasks get a floor of tier 3 (medium).
"""
import json
import os
import random
from collections import defaultdict

import numpy as np


TASK_TYPES = ["quick_answer", "coding", "research", "document_drafting",
              "legal_regulated", "tool_heavy", "retrieval_heavy",
              "long_horizon", "unknown_ambiguous"]
TT2IDX = {t: i for i, t in enumerate(TASK_TYPES)}

# Keyword lists, matched as substrings against the lowercased request.
CODE_KW = ["python", "javascript", "code", "function", "bug", "debug", "refactor",
           "implement", "test", "compile", "runtime", "class", "module", "async", "thread"]
LEGAL_KW = ["contract", "legal", "compliance", "gdpr", "privacy", "policy", "regulatory", "liability"]
RESEARCH_KW = ["research", "find sources", "literature", "investigate", "compare", "analyze", "survey"]
TOOL_KW = ["search", "fetch", "retrieve", "query", "api", "database", "scrape", "aggregate"]
LONG_KW = ["plan", "project", "roadmap", "orchestrate", "multi-step", "migrate", "pipeline", "deploy"]
MATH_KW = ["calculate", "compute", "solve", "equation", "formula", "optimize", "probability"]

# Base per-tier success strength (see tsp() below) and relative cost per tier.
TIER_STR = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
TIER_COST = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}


TASK_TEMPLATES = {
    "quick_answer": ["What is the capital of France?", "Explain quantum computing briefly.",
                     "What is 237*452?", "Define photosynthesis.", "Who wrote Hamlet?",
                     "What is the speed of light?", "List the primary colors.", "What is GDP?"],
    "coding": ["Write a Python function to reverse a linked list.",
               "Fix the bug in this React component.", "Refactor auth module to JWT.",
               "Implement LRU cache in Go.", "Debug segfault in C++ thread pool.",
               "Add unit tests for the payment module.", "Optimize this SQL query.",
               "Create a REST API for user management.", "Implement binary search in Rust."],
    "research": ["Research latest transformer advances.",
                 "Find sources comparing LoRA and full FT.",
                 "Investigate data center climate impact.",
                 "Survey privacy-preserving ML techniques.",
                 "Compare reinforcement learning algorithms for robotics."],
    "document_drafting": ["Draft project proposal for ML pipeline.",
                          "Write email to team about deployment.", "Create technical report on performance."],
    "legal_regulated": ["Review this contract for liability clauses.",
                        "Check GDPR compliance for data pipeline.", "Draft privacy policy section.",
                        "Verify regulatory compliance for medical device software."],
    "tool_heavy": ["Search open issues and create summary.",
                   "Fetch API docs and generate client code.", "Query Q3 sales and produce chart."],
    "retrieval_heavy": ["Answer based on 50-page document.",
                        "Find all payment processing mentions.", "Retrieve relevant cases for legal query."],
    "long_horizon": ["Plan 3-month roadmap.", "Orchestrate multi-region deployment.",
                     "Redesign data architecture end-to-end.", "Migrate monolith to microservices."],
    "unknown_ambiguous": ["Help me with this thing.",
                          "I need something about the server.", "Can you look into that issue?"],
}


# Minimum allowed tier per task type: the safety floor.
TASK_FLOOR = {
    "legal_regulated": 4,
    "long_horizon": 3,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
}


def tsp(tier, diff):
    """Tier success probability: base tier strength decayed by task difficulty."""
    return TIER_STR[tier] ** (diff * 0.6)
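
# Sanity check on the curve (illustrative, using only values defined above):
# a difficulty-5 task succeeds roughly 4% of the time on tier 1 and roughly
# 91% of the time on tier 5.
assert abs(tsp(1, 5) - 0.35 ** 3.0) < 1e-9
assert abs(tsp(5, 5) - 0.97 ** 3.0) < 1e-9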


def extract_features(request, task_type, difficulty=3):
    """Cheap lexical features plus a one-hot of the (assumed known) task type."""
    r = request.lower()
    f = {
        "req_len": len(request),
        "num_words": len(request.split()),
        "has_code": int(any(k in r for k in CODE_KW)),
        "n_code": sum(1 for k in CODE_KW if k in r),
        "has_legal": int(any(k in r for k in LEGAL_KW)),
        "n_legal": sum(1 for k in LEGAL_KW if k in r),
        "has_research": int(any(k in r for k in RESEARCH_KW)),
        "n_research": sum(1 for k in RESEARCH_KW if k in r),
        "has_tool": int(any(k in r for k in TOOL_KW)),
        "n_tool": sum(1 for k in TOOL_KW if k in r),
        "has_long": int(any(k in r for k in LONG_KW)),
        "has_math": int(any(k in r for k in MATH_KW)),
        "tt_idx": TT2IDX.get(task_type, 8),
        "difficulty": difficulty,
    }
    for tt in TASK_TYPES:
        f[f"tt_{tt}"] = int(task_type == tt)
    return f
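
# Illustrative check (not part of training): the legal detectors are what
# ultimately let the router respect regulatory floors.
_demo = extract_features("Check GDPR compliance for data pipeline.", "legal_regulated")
assert _demo["has_legal"] == 1 and _demo["n_legal"] == 2  # "gdpr" + "compliance"
assert _demo["tt_legal_regulated"] == 1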


def gen_trace(idx, rng):
    """Simulate one routing trace: a task, per-tier outcomes, and the tier
    that was 'actually' used under a difficulty-skewed historical policy."""
    tt = rng.choice(list(TASK_TEMPLATES.keys()))
    diff = {"quick_answer": 1, "document_drafting": 2, "tool_heavy": 2, "retrieval_heavy": 2,
            "research": 3, "coding": 3, "unknown_ambiguous": 3, "long_horizon": 4, "legal_regulated": 5}[tt]
    # Independent Bernoulli success per tier.
    tier_out = {t: rng.random() < tsp(t, diff) for t in range(1, 6)}
    # Optimal tier = cheapest tier that succeeded (5 if none did).
    opt = 5
    for t in range(1, 6):
        if tier_out[t]:
            opt = t
            break
    # The historical tier choice skews toward the task's difficulty.
    if diff <= 2:
        actual = rng.choices([1, 2, 3, 4, 5], weights=[3, 4, 2, 1, 0.5])[0]
    elif diff == 3:
        actual = rng.choices([1, 2, 3, 4, 5], weights=[1, 2, 4, 2, 1])[0]
    elif diff == 4:
        actual = rng.choices([1, 2, 3, 4, 5], weights=[0.5, 1, 2, 4, 2])[0]
    else:
        actual = rng.choices([1, 2, 3, 4, 5], weights=[0.2, 0.5, 1, 3, 4])[0]
    outcome = "success" if tier_out[actual] else "failure"
    req = rng.choice(TASK_TEMPLATES[tt])
    feats = extract_features(req, tt, diff)
    return {"feats": feats, "opt": opt, "actual": actual, "outcome": outcome,
            "tier_out": tier_out, "tt": tt, "diff": diff, "req": req}


print("=" * 80)
print("AGENT COST OPTIMIZER - TRAINED ROUTER v2 (Safety-First CARROT)")
print("=" * 80)

print("\n[1] Generating 50K training traces...")
rng = random.Random(42)
traces = [gen_trace(i, rng) for i in range(50000)]
print(f"    Generated {len(traces)} traces")

# Fixed feature ordering, shared between training and serving.
FEAT_KEYS = sorted(traces[0]["feats"].keys())
NUM_FEATURES = len(FEAT_KEYS)


def f2v(feats):
    return np.array([float(feats.get(k, 0.0)) for k in FEAT_KEYS], dtype=np.float32)
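
# f2v emits features in FEAT_KEYS order, so serve-time vectors stay aligned
# with training as long as feat_keys.json is loaded alongside the models.
assert f2v(traces[0]["feats"]).shape == (NUM_FEATURES,)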


X_all = np.array([f2v(t["feats"]) for t in traces])
y_opt = np.array([t["opt"] for t in traces])

# Binary label per tier: did that tier succeed on this trace?
per_tier_labels = {}
for tier in range(1, 6):
    per_tier_labels[tier] = np.array([1 if t["tier_out"].get(tier, False) else 0 for t in traces])

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

# Split on an index array (np.arange, not a bare range) so the resulting
# idx_train/idx_test support numpy fancy indexing below.
X_train, X_test, idx_train, idx_test = train_test_split(
    X_all, np.arange(len(traces)), test_size=0.2, random_state=42, stratify=y_opt
)
print(f"    Train: {len(X_train)}, Test: {len(X_test)}")

print("\n[2] Training per-tier P(success) XGBoost classifiers...")
import xgboost as xgb

tier_clfs = {}
for tier in range(1, 6):
    y_tr = per_tier_labels[tier][idx_train]
    y_te = per_tier_labels[tier][idx_test]

    # Balance classes with scale_pos_weight = neg/pos, capped to avoid
    # over-correcting the heavily imbalanced cheap tiers.
    neg = (y_tr == 0).sum()
    pos = (y_tr == 1).sum()
    spw = neg / max(pos, 1)

    clf = xgb.XGBClassifier(
        n_estimators=150, max_depth=5, learning_rate=0.1,
        subsample=0.8, colsample_bytree=0.8,
        scale_pos_weight=min(spw, 5.0),
        objective="binary:logistic", eval_metric="logloss",
        random_state=42, verbosity=0,
    )
    clf.fit(X_train, y_tr)

    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_te, y_pred)
    f1 = f1_score(y_te, y_pred, zero_division=0)
    tier_clfs[tier] = clf
    print(f"    Tier {tier}: acc={acc:.3f}, f1={f1:.3f}, spw={spw:.2f}")


print("\n[3] Building safety-first CARROT router...")


def route_safe_carrot(features_vec, tier_clfs, task_type, mu=0.7,
                      success_threshold=0.5, safety_floor=None):
    """Route with safety floors.

    1. Compute P(success | tier) for each tier.
    2. Apply the per-task-type safety floor.
    3. Pick the cheapest tier at or above the floor with P(success) >= threshold.
    4. If none qualifies, retry above the floor at a relaxed (0.8x) threshold,
       then fall back to the mu-weighted CARROT score over tiers >= floor.
    """
    if features_vec.ndim == 1:
        features_vec = features_vec.reshape(1, -1)

    floor = safety_floor if safety_floor is not None else TASK_FLOOR.get(task_type, 2)

    # Step 1: per-tier success probabilities.
    p_success = {}
    for tier in range(1, 6):
        p_success[tier] = tier_clfs[tier].predict_proba(features_vec)[0, 1]

    # Step 3: cheapest qualifying tier at or above the floor.
    for tier in range(floor, 6):
        if p_success[tier] >= success_threshold:
            return tier, p_success

    # Step 4a: escalate past the floor at a relaxed threshold.
    for tier in range(floor + 1, 6):
        if p_success[tier] >= success_threshold * 0.8:
            return tier, p_success

    # Step 4b: CARROT fallback, restricted to tiers at or above the floor.
    best_tier = floor
    best_score = float("inf")
    for tier in range(floor, 6):
        cost_norm = TIER_COST[tier] / TIER_COST[5]
        score = mu * (1.0 - p_success[tier]) + (1.0 - mu) * cost_norm
        if score < best_score:
            best_score = score
            best_tier = tier

    return best_tier, p_success
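
# Smoke test (illustrative): a regulated request can never land below its
# tier-4 floor, regardless of what the P(success) models say.
_x = f2v(extract_features("Review this contract for liability clauses.",
                          "legal_regulated", difficulty=5))
_tier, _ = route_safe_carrot(_x, tier_clfs, "legal_regulated")
assert _tier >= TASK_FLOOR["legal_regulated"]
print(f"    smoke test: legal request routed to tier {_tier}")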


print("\n[4] Evaluating all routers on test set...")

n_test = len(idx_test)
results = {}


def eval_router(name, route_fn):
    """Replay held-out traces through a router and record quality/cost/safety.

    unsafe     = routed below the optimal tier and the task failed.
    false_done = routed at or above the optimal tier, yet the task still failed.
    """
    succ = 0; cost = 0.0; unsafe = 0; false_done = 0
    tier_dist = defaultdict(int)
    for i in idx_test:
        t = traces[i]
        x = f2v(t["feats"]).reshape(1, -1)
        pred, _ = route_fn(x, t)
        tier_dist[pred] += 1
        if t["tier_out"].get(pred, False):
            succ += 1
        else:
            if pred < t["opt"]:
                unsafe += 1
            else:
                false_done += 1
        cost += TIER_COST[pred]
    results[name] = {
        "success": succ / n_test, "avg_cost": cost / n_test,
        "unsafe_rate": unsafe / n_test, "false_done": false_done / n_test,
        "tier_dist": dict(tier_dist),
    }


# Baselines: the frontier tier (the cost-reduction reference) and the cheapest tier.
eval_router("always_frontier", lambda x, t: (4, {}))
eval_router("always_cheap", lambda x, t: (1, {}))

# Heuristics (note: these read the ground-truth difficulty / task type).
eval_router("heuristic_diff+1", lambda x, t: (min(t["diff"] + 1, 5), {}))
eval_router("heuristic_floor", lambda x, t: (TASK_FLOOR.get(t["tt"], 3), {}))


# v1 reference: plain CARROT at mu=0.6, no safety floors.
def carrot_v1(x, t):
    ps = {tier: tier_clfs[tier].predict_proba(x)[0, 1] for tier in range(1, 6)}
    best = 3; best_s = float("inf")
    for tier in range(1, 6):
        s = 0.6 * (1 - ps[tier]) + 0.4 * (TIER_COST[tier] / TIER_COST[5])
        if s < best_s:
            best_s = s; best = tier
    return best, ps
eval_router("CARROT_v1_mu0.6", carrot_v1)

# Safety-first CARROT at three gate thresholds.
for thr in (0.5, 0.6, 0.65):
    eval_router(f"safe_CARROT_t{thr:.2f}",
                lambda x, t, thr=thr: route_safe_carrot(x, tier_clfs, t["tt"],
                                                        mu=0.7, success_threshold=thr))

# Upper bound: route to the true optimal tier.
eval_router("oracle", lambda x, t: (t["opt"], {}))


print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
print("-" * 80)
frontier_cost = results["always_frontier"]["avg_cost"]
for name, r in sorted(results.items(), key=lambda x: -x[1]["success"]):
    cr = (1 - r["avg_cost"] / frontier_cost) * 100
    print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")


print("\n\n[5] Training improved direct classifier (0-indexed)...")

# XGBoost multiclass expects 0-indexed labels.
y_train_direct = y_opt[idx_train] - 1

# Uniform sample weights for now; per-sample cost weighting is left as a knob.
sample_weights = np.ones(len(idx_train), dtype=np.float32)

direct_clf = xgb.XGBClassifier(
    n_estimators=300, max_depth=6, learning_rate=0.05,
    subsample=0.8, colsample_bytree=0.8,
    objective="multi:softmax", num_class=5,
    eval_metric="mlogloss", random_state=42, verbosity=0,
)
direct_clf.fit(X_train, y_train_direct, sample_weight=sample_weights)

y_pred_direct = direct_clf.predict(X_test) + 1
acc = accuracy_score(y_opt[idx_test], y_pred_direct)
print(f"    Direct classifier accuracy: {acc:.3f}")


# Direct prediction, clamped to the task's safety floor.
def direct_safe(x, t):
    pred = int(direct_clf.predict(x)[0]) + 1
    floor = TASK_FLOOR.get(t["tt"], 2)
    return max(pred, floor), {}


eval_router("direct_safe_xgb", direct_safe)


print("\n\n[6] Feature importance (from direct classifier)...")
imp = direct_clf.feature_importances_
for feat, score in sorted(zip(FEAT_KEYS, imp), key=lambda x: -x[1])[:10]:
    print(f"    {feat:<25}: {score:.4f}")


print("\n\n[7] Saving models...")
os.makedirs("/app/router_models", exist_ok=True)
for tier, clf in tier_clfs.items():
    clf.save_model(f"/app/router_models/tier_{tier}_success.json")
direct_clf.save_model("/app/router_models/direct_optimal_tier.json")
with open("/app/router_models/feat_keys.json", "w") as f:
    json.dump(FEAT_KEYS, f)
with open("/app/router_models/tier_config.json", "w") as f:
    json.dump({"tier_cost": TIER_COST, "tier_str": TIER_STR, "task_floor": TASK_FLOOR}, f, indent=2)
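

# Sketch of how a serving process might reload these artifacts (assumption:
# inference runs elsewhere; xgboost restores a sklearn-wrapper model via
# load_model on a fresh estimator). load_router_models is illustrative only.
def load_router_models(model_dir="/app/router_models"):
    clfs = {}
    for tier in range(1, 6):
        c = xgb.XGBClassifier()
        c.load_model(os.path.join(model_dir, f"tier_{tier}_success.json"))
        clfs[tier] = c
    with open(os.path.join(model_dir, "feat_keys.json")) as fh:
        feat_keys = json.load(fh)
    return clfs, feat_keys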


print(f"\n\n{'=' * 80}")
print("FINAL COMPARISON (ALL ROUTERS)")
print(f"{'=' * 80}")
print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
print("-" * 80)
frontier_cost = results["always_frontier"]["avg_cost"]
for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
    cr = (1 - r["avg_cost"] / frontier_cost) * 100
    print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")

print("\n\nModels saved to /app/router_models/")


print("\n\n[8] Training RouteLLM-style binary classifiers...")
print("    (one binary classifier per tier boundary: does the cheaper side suffice?)")

# For boundary b in 2..5, label 1 means the optimal tier is below b (the
# cheaper side suffices); label 0 means tier b or higher is needed.
boundary_clfs = {}
for boundary in range(2, 6):
    y_boundary = (y_opt < boundary).astype(int)

    y_tr = y_boundary[idx_train]
    y_te = y_boundary[idx_test]

    neg = (y_tr == 0).sum()
    pos = (y_tr == 1).sum()
    spw = neg / max(pos, 1)

    clf = xgb.XGBClassifier(
        n_estimators=150, max_depth=5, learning_rate=0.1,
        subsample=0.8, colsample_bytree=0.8,
        scale_pos_weight=min(spw, 3.0),
        objective="binary:logistic", eval_metric="logloss",
        random_state=42, verbosity=0,
    )
    clf.fit(X_train, y_tr)

    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_te, y_pred)
    f1 = f1_score(y_te, y_pred, zero_division=0)

    boundary_clfs[boundary] = clf
    needs_higher = (y_tr == 0).mean()
    print(f"    Boundary {boundary}: acc={acc:.3f}, f1={f1:.3f}, needs_higher={needs_higher:.3f}")
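
# Illustrative read of one boundary model (class 1 = the cheaper side
# suffices; class 0 = the optimal tier is at or above the boundary).
_p = boundary_clfs[2].predict_proba(X_test[:1])[0]
print(f"    example: P(tier 1 suffices)={_p[1]:.3f}, P(needs >= tier 2)={_p[0]:.3f}")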


def make_cascade_router(threshold):
    """RouteLLM-style cascade: start at the task's safety floor and keep
    escalating while the boundary classifier says the higher tier is needed
    with probability above `threshold`; stop at the first boundary that clears."""
    def route(x, t):
        if x.ndim == 1:
            x = x.reshape(1, -1)
        floor = TASK_FLOOR.get(t["tt"], 2)
        current_tier = floor
        for boundary in range(floor + 1, 6):
            # Class 0 = "optimal tier >= boundary", i.e. the cheaper side fails.
            p_need_higher = boundary_clfs[boundary].predict_proba(x)[0, 0]
            if p_need_higher > threshold:
                current_tier = boundary
            else:
                break
        return current_tier, {}
    return route


eval_router("cascade_binary_t0.4", make_cascade_router(0.4))
eval_router("cascade_binary_t0.5", make_cascade_router(0.5))
eval_router("cascade_binary_t0.3", make_cascade_router(0.3))


for boundary, clf in boundary_clfs.items():
    clf.save_model(f"/app/router_models/boundary_{boundary}.json")
    print(f"    Saved boundary_{boundary}.json")


print(f"\n\n{'=' * 80}")
print("FINAL COMPARISON v2 (WITH BINARY CASCADE ROUTER)")
print(f"{'=' * 80}")
print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
print("-" * 80)
frontier_cost = results["always_frontier"]["avg_cost"]
for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
    cr = (1 - r["avg_cost"] / frontier_cost) * 100
    # Show a representative subset so the table stays readable.
    if name in ("oracle", "always_frontier", "heuristic_diff+1", "safe_CARROT_t0.60",
                "cascade_binary_t0.4", "cascade_binary_t0.5", "cascade_binary_t0.3",
                "always_cheap"):
        print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")


print("\n\nPARETO FRONTIER:")
pareto = []
for name, r in results.items():
    if name == "always_cheap":
        continue  # degenerate point: minimal cost, unacceptable success rate
    dominated = False
    for name2, r2 in results.items():
        if name == name2:
            continue
        # name2 dominates name if it is at least as good on both axes and
        # strictly better on at least one.
        if r2["success"] >= r["success"] and r2["avg_cost"] <= r["avg_cost"]:
            if r2["success"] > r["success"] or r2["avg_cost"] < r["avg_cost"]:
                dominated = True
                break
    if not dominated:
        pareto.append((name, r))
        cr = (1 - r["avg_cost"] / frontier_cost) * 100
        print(f"    {name:<25} success={r['success']:.3f} cost={r['avg_cost']:.4f} costRed={cr:.1f}%")


with open("/app/router_models/eval_results.json", "w") as f:
    json.dump(results, f, indent=2, default=str)
print("\n    Saved eval_results.json")
print("\nDONE!")