"""Final Production Router v8: Dynamic difficulty + ML confirmation + safety floors.

This is the production router that replaces the heuristic in ACO.
"""
import json
import os
import pickle
import random
import sys
import uuid
from collections import defaultdict

import numpy as np
|
|
# Startup banner: title framed by two 80-character rules.
_BANNER_RULE = "=" * 80
for _banner_line in (
    _BANNER_RULE,
    "ACO PRODUCTION ROUTER v8: DYNAMIC DIFFICULTY + ML",
    _BANNER_RULE,
):
    print(_banner_line)
|
|
| |
# Load the v6 router bundle: per-tier classifiers, per-tier calibrators, the
# feature ordering, and the tier configuration (cost, strength, task floors).
#
# SECURITY: pickle.load executes arbitrary code embedded in the file. Only
# load bundles from a trusted, write-protected location.
MODEL_DIR = "/app/router_models"
# Context manager so the file handle is closed even if unpickling fails.
with open(f"{MODEL_DIR}/router_bundle_v6.pkl", "rb") as _bundle_file:
    bundle = pickle.load(_bundle_file)

feat_keys = bundle["feat_keys"]
# Tier keys are coerced to int so the rest of the script can index by tier number.
tier_clfs = {int(k): v for k, v in bundle["tier_clfs"].items()}
tier_calibs = {int(k): v for k, v in bundle["tier_calibrators"].items()}
TIER_COST = {int(k): v for k, v in bundle["tier_config"]["tier_cost"].items()}
TIER_STR = {int(k): v for k, v in bundle["tier_config"]["tier_str"].items()}
TASK_FLOOR = bundle["tier_config"]["task_floor"]
|
|
| |
# Keyword tables. The feature extractor and difficulty estimator test each
# entry with `kw in request.lower()`, i.e. as a case-insensitive substring.
CODE_KW = [
    "python", "javascript", "code", "function",
    "bug", "debug", "refactor", "implement",
    "test", "compile", "runtime", "segfault",
    "thread", "async", "class", "module",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy",
    "policy", "regulatory", "liability", "indemnification", "clause",
]
RESEARCH_KW = [
    "research", "find sources", "literature",
    "investigate", "compare", "analyze",
    "survey", "paper", "arxiv",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query",
    "api", "database", "scrape", "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap",
    "orchestrate", "multi-step", "migrate",
    "pipeline", "deploy", "architecture",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation",
    "formula", "optimize", "probability", "integral",
]
# Presence of any of these raises the estimated difficulty by one step.
CRITICAL_KW = [
    "critical", "production", "urgent",
    "now", "emergency", "live",
    "deployed", "safety", "security",
]
# Presence of any of these lowers the estimated difficulty by one step.
SIMPLE_KW = [
    "typo", "simple", "quick", "brief", "briefly", "just",
    "minor", "small", "easy", "trivial", "clarification",
]
# Task type -> integer index; position in the tuple is the index.
TT2IDX = {
    task_name: position
    for position, task_name in enumerate((
        "quick_answer",
        "coding",
        "research",
        "document_drafting",
        "legal_regulated",
        "tool_heavy",
        "retrieval_heavy",
        "long_horizon",
        "unknown_ambiguous",
    ))
}
|
|
def estimate_difficulty(request, task_type):
    """Estimate a difficulty score in the range 1-5 for a request.

    The score starts from a fixed per-task-type base, is raised by one
    (capped at 5) when the request contains any CRITICAL_KW entry, and is then
    lowered by one (floored at 1) when it contains any SIMPLE_KW entry. Both
    adjustments can apply to the same request.

    Args:
        request: Raw request text.
        task_type: Task type label. Labels missing from the base table use the
            "unknown_ambiguous" base instead of raising KeyError, matching the
            `.get(...)` fallbacks in the feature extractor and the router.

    Returns:
        Integer difficulty between 1 and 5 inclusive.
    """
    base_by_type = {
        "quick_answer": 1,
        "document_drafting": 2,
        "tool_heavy": 2,
        "retrieval_heavy": 2,
        "research": 3,
        "coding": 3,
        "unknown_ambiguous": 3,
        "long_horizon": 4,
        "legal_regulated": 5,
    }
    score = base_by_type.get(task_type, base_by_type["unknown_ambiguous"])

    text = request.lower()
    # NOTE(review): keywords match as substrings, so "now" also fires on
    # "know" and "live" on "deliver". Left unchanged because inflected
    # forms such as "urgently" rely on the same behaviour.
    if any(kw in text for kw in CRITICAL_KW):
        score = min(score + 1, 5)
    if any(kw in text for kw in SIMPLE_KW):
        score = max(score - 1, 1)
    return score
|
|
def extract_features(request, task_type, difficulty=3):
    """Build the feature dictionary for one request.

    Args:
        request: Raw request text.
        task_type: Task type label; labels missing from TT2IDX get index 8 and
            an all-zero one-hot block.
        difficulty: Difficulty score stored verbatim in the "difficulty"
            feature (default 3).

    Returns:
        Dict with length features ("req_len", "num_words"), keyword presence
        flags and hit counts for the code/legal/research/tool lists, presence
        flags for the long-horizon and math lists, "tt_idx", "difficulty", and
        one "tt_<task type>" 0/1 indicator per TT2IDX entry.
    """
    text = request.lower()

    def _hits(keywords):
        # Number of distinct keywords occurring as substrings of the request.
        return sum(1 for kw in keywords if kw in text)

    # Count once per list; the presence flag is derived from the count
    # instead of scanning the same list a second time.
    n_code = _hits(CODE_KW)
    n_legal = _hits(LEGAL_KW)
    n_research = _hits(RESEARCH_KW)
    n_tool = _hits(TOOL_KW)

    feats = {
        "req_len": len(request),
        "num_words": len(request.split()),
        "has_code": int(n_code > 0),
        "n_code": n_code,
        "has_legal": int(n_legal > 0),
        "n_legal": n_legal,
        "has_research": int(n_research > 0),
        "n_research": n_research,
        "has_tool": int(n_tool > 0),
        "n_tool": n_tool,
        "has_long": int(any(kw in text for kw in LONG_KW)),
        "has_math": int(any(kw in text for kw in MATH_KW)),
        "tt_idx": TT2IDX.get(task_type, 8),
        "difficulty": difficulty,
    }
    # One-hot encoding of the task type.
    for known_type in TT2IDX:
        feats[f"tt_{known_type}"] = int(task_type == known_type)
    return feats
|
|
def f2v(feats):
    """Convert a feature dict into a float32 vector ordered by `feat_keys`.

    Keys listed in `feat_keys` but absent from `feats` contribute 0.0; keys in
    `feats` that are not in `feat_keys` are ignored.
    """
    vector = np.zeros(len(feat_keys), dtype=np.float32)
    for position, key in enumerate(feat_keys):
        vector[position] = float(feats.get(key, 0.0))
    return vector
|
|
def get_calibrated_psuccess(x, tier):
    """Return the calibrated success probability for `x` at the given tier.

    Reads row 0, column 1 of the tier classifier's `predict_proba` output and
    passes that single value through the tier's calibrator.

    NOTE(review): column 1 is presumably the "success" class, and `x` is
    presumably a single-row 2-D array. Confirm against the training code.
    """
    classifier = tier_clfs[tier]
    calibrator = tier_calibs[tier]
    raw_probability = classifier.predict_proba(x)[0, 1]
    calibrated = calibrator.transform([raw_probability])
    return float(calibrated[0])
|
|
def route_production_v8(request, task_type, safety=0.30, downgrade=0.90):
    """Choose a tier for a request from dynamic difficulty plus ML confirmation.

    Steps:
      1. Starting tier = estimated difficulty + 1, capped at 5 and raised to
         the task-type floor (floor defaults to 2 for unlisted task types).
      2. If the calibrated success probability at that tier is below `safety`
         and the tier is below 5, move up exactly one tier.
      3. Otherwise, if the tier is still the starting tier and above the
         floor, move down one tier when the calibrated success probability
         there is at least `downgrade`.

    Args:
        request: Raw request text.
        task_type: Task type label.
        safety: Probability below which the router escalates one tier.
        downgrade: Probability the next lower tier must reach to be chosen.

    Returns:
        Tuple `(tier, calibrated_success_probability, difficulty)`.
    """
    diff = estimate_difficulty(request, task_type)
    floor = TASK_FLOOR.get(task_type, 2)
    start_tier = max(min(diff + 1, 5), floor)

    features = extract_features(request, task_type, diff)
    x = f2v(features).reshape(1, -1)

    chosen = start_tier
    p_success = get_calibrated_psuccess(x, chosen)

    # Safety escalation: a single step up, never past tier 5.
    if chosen < 5 and p_success < safety:
        chosen += 1
        p_success = get_calibrated_psuccess(x, chosen)

    # Downgrade is only considered when no escalation happened.
    if chosen == start_tier and chosen > floor:
        lower = chosen - 1
        p_lower = get_calibrated_psuccess(x, lower)
        if lower >= floor and p_lower >= downgrade:
            return lower, p_lower, diff

    return chosen, p_success, diff
|
|
| |
# Synthetic request templates per task type, sampled by the trace generator.
# Several entries deliberately carry "critical" or "simple" keywords so the
# dynamic difficulty estimate differs from the static per-type value.
TASK_TEMPLATES = {
    "quick_answer": [
        "What is the capital of France?",
        "Explain quantum computing briefly.",
        "What is 237*452?",
        "Briefly explain photosynthesis.",
        "Just tell me what 2+2 is.",
        "Small clarification on this formula.",
    ],
    "coding": [
        "Write a Python function to reverse a linked list.",
        "Fix the bug in this React component.",
        "Refactor auth module to JWT.",
        "Implement LRU cache in Go.",
        "Debug segfault in C++ thread pool.",
        "Fix a typo in the README.",
        "Debug this critical production segfault NOW.",
        "Just fix the typo in line 42.",
    ],
    "research": [
        "Research latest transformer advances.",
        "Find sources comparing LoRA and full FT.",
        "Investigate data center climate impact.",
        "Find sources comparing LoRA and full FT briefly.",
    ],
    "document_drafting": [
        "Draft project proposal for ML pipeline.",
        "Write email to team about deployment.",
        "Create technical report on performance.",
    ],
    "legal_regulated": [
        "Review this contract for liability clauses.",
        "Check GDPR compliance for data pipeline.",
        "Draft privacy policy section.",
        "Check GDPR compliance urgently.",
    ],
    "tool_heavy": [
        "Search open issues and create summary.",
        "Fetch API docs and generate client code.",
        "Query Q3 sales and produce chart.",
    ],
    "retrieval_heavy": [
        "Answer based on 50-page document.",
        "Find all payment processing mentions.",
    ],
    "long_horizon": [
        "Plan 3-month roadmap.",
        "Orchestrate multi-region deployment.",
        "Redesign data architecture end-to-end.",
        "Orchestrate complete multi-region deployment.",
    ],
    "unknown_ambiguous": [
        "Help me with this thing.",
        "I need something about the server.",
    ],
}
|
|
def tsp(tier, diff):
    """Return `TIER_STR[tier]` raised to the power `0.6 * diff`.

    The trace generator compares this value against a uniform random draw to
    decide whether a tier succeeds on a task of difficulty `diff`.
    """
    exponent = diff * 0.6
    strength = TIER_STR[tier]
    return strength ** exponent
|
|
# Stage 1: synthesize 2000 evaluation traces with a fixed seed so runs are
# reproducible. Each trace records the sampled request, its static and
# dynamic difficulty, a simulated success/failure outcome for every tier
# 1-5, and the lowest tier that succeeded (5 when none did).
print("\n[1] Generating 2K eval traces...")
rng = random.Random(999)

# Loop invariants hoisted out of the 2000-iteration loop.
_TRACE_TASK_TYPES = list(TASK_TEMPLATES.keys())
_STATIC_DIFFICULTY = {
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "long_horizon": 4,
    "legal_regulated": 5,
}

traces = []
for _ in range(2000):
    # RNG draw order (task type, template, then one draw per tier in
    # ascending order) must not change or the generated traces differ.
    tt = rng.choice(_TRACE_TASK_TYPES)
    req = rng.choice(TASK_TEMPLATES[tt])
    dyn_diff = estimate_difficulty(req, tt)

    tier_out = {t: rng.random() < tsp(t, dyn_diff) for t in range(1, 6)}
    # Cheapest-numbered tier that succeeded; falls back to 5 if all failed.
    opt = next((t for t in range(1, 6) if tier_out[t]), 5)

    traces.append({
        "tt": tt,
        "static_diff": _STATIC_DIFFICULTY[tt],
        "dyn_diff": dyn_diff,
        "opt": opt,
        "tier_out": tier_out,
        "req": req,
    })

print(f" Generated {len(traces)} traces")
|
|
| |
# Stage 2 header. `n` is the trace count used as the denominator for the
# aggregate rates reported by the evaluation.
print("\n[2] Evaluating all routers...")
n = len(traces)
|
|
def eval_router(name, route_fn):
    """Score a routing function over the module-level `traces`.

    For every trace the predicted tier is classified as exactly one of:
      * success    - the simulated outcome for that tier is True;
      * unsafe     - it failed and the tier is below the trace's optimal tier;
      * false done - it failed at a tier at or above the optimal tier.

    Args:
        name: Router label. Accepted for call-site readability; not used in
            the computation.
        route_fn: Callable taking a trace dict and returning a tier number.

    Returns:
        Dict with "success", "avg_cost", "unsafe_rate" and "false_done"
        (each averaged over all traces) and "tier_dist" (predicted tier ->
        count). All rates are 0.0 and "tier_dist" is empty when there are no
        traces, instead of raising ZeroDivisionError.
    """
    total = len(traces)
    if total == 0:
        return {"success": 0.0, "avg_cost": 0.0, "unsafe_rate": 0.0,
                "false_done": 0.0, "tier_dist": {}}

    successes = 0
    unsafe = 0
    false_done = 0
    total_cost = 0.0
    tier_dist = defaultdict(int)

    for trace in traces:
        pred = route_fn(trace)
        tier_dist[pred] += 1
        if trace["tier_out"].get(pred, False):
            successes += 1
        elif pred < trace["opt"]:
            unsafe += 1
        else:
            false_done += 1
        total_cost += TIER_COST[pred]

    return {
        "success": successes / total,
        "avg_cost": total_cost / total,
        "unsafe_rate": unsafe / total,
        "false_done": false_done / total,
        "tier_dist": dict(tier_dist),
    }
|
|
# Routers under comparison, evaluated in this order (the order is also the
# insertion order of `results`, which breaks ties in the later sort).
_ROUTERS_UNDER_TEST = [
    # Fixed-tier baselines.
    ("always_frontier", lambda t: 4),
    ("always_cheap", lambda t: 1),
    # Static per-type difficulty + 1, capped at 5, raised to the task floor.
    ("heuristic_static",
     lambda t: max(min(t["static_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
    # Lowest tier that succeeded in the simulated outcomes.
    ("oracle", lambda t: t["opt"]),
    # Full v8 router: dynamic difficulty with ML confirmation.
    ("v8_dynamic+ML", lambda t: route_production_v8(t["req"], t["tt"])[0]),
    # Dynamic difficulty only, without the ML confirmation step.
    ("v8_dynamic_only",
     lambda t: max(min(t["dyn_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
]

results = {}
for _router_name, _router_fn in _ROUTERS_UNDER_TEST:
    results[_router_name] = eval_router(_router_name, _router_fn)
|
|
| |
# Comparison table: routers ranked by success (descending), then by average
# cost (ascending). Cost reduction is relative to the always_frontier baseline.
def _ranking_key(item):
    stats = item[1]
    return (-stats["success"], stats["avg_cost"])


print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
print("-" * 75)
fc = results["always_frontier"]["avg_cost"]
for name, r in sorted(results.items(), key=_ranking_key):
    cr = 100 * (1 - r["avg_cost"] / fc)
    print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f} {r['false_done']:>10.3f}")
|
|
| |
# Stage 3: success rate and average cost per task type for a subset of
# routers. Cost reduction is relative to `fc`, the overall always_frontier
# average cost computed for the comparison table.
print("\n\n[3] Per-task breakdown...")

# Built once instead of on every task-type iteration.
_BREAKDOWN_ROUTERS = [
    ("frontier", lambda t: 4),
    ("heuristic", lambda t: max(min(t["static_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
    ("v8_dynamic", lambda t: max(min(t["dyn_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
    ("v8_full", lambda t: route_production_v8(t["req"], t["tt"])[0]),
    ("oracle", lambda t: t["opt"]),
]

for tt in sorted({t["tt"] for t in traces}):
    tt_r = [t for t in traces if t["tt"] == tt]
    n_tt = len(tt_r)
    print(f"\n {tt} (n={n_tt}):")
    for rname, rfn in _BREAKDOWN_ROUTERS:
        # Route each trace once and reuse the prediction for both metrics;
        # the ML-backed router previously ran inference twice per trace.
        preds = [rfn(t) for t in tt_r]
        succ = sum(1 for t, p in zip(tt_r, preds) if t["tier_out"].get(p, False))
        cost = sum(TIER_COST[p] for p in preds)
        sr = succ / n_tt
        ac = cost / n_tt
        cr = (1 - ac / fc) * 100
        print(f" {rname:<14} success={sr:.3f} cost={ac:.4f} costRed={cr:.1f}%")
|
|
| |
# Persist the aggregate results as JSON. `default=str` stringifies any value
# the json module cannot serialize natively.
_RESULTS_JSON_PATH = "/app/router_models/v8_final_results.json"
with open(_RESULTS_JSON_PATH, "w") as results_file:
    json.dump(results, results_file, indent=2, default=str)
|
|
| |
# Package and save the v8 router bundle.
#
# "tier_config" uses the same nested layout this script reads from the v6
# bundle ("tier_cost" / "tier_str" / "task_floor"), so a loader written like
# the one at the top of this file also works on the v8 bundle. The previous
# layout stored only a flat cost map there and dropped the tier strengths.
# "task_floor" is additionally kept at the top level for readers of that
# earlier v8 layout.
# NOTE(review): consumers that read "tier_config" as a flat tier->cost map
# must switch to tier_config["tier_cost"].
v8_bundle = {
    "tier_clfs": {str(k): v for k, v in tier_clfs.items()},
    "tier_calibrators": {str(k): v for k, v in tier_calibs.items()},
    "feat_keys": feat_keys,
    "tier_config": {
        "tier_cost": {str(k): v for k, v in TIER_COST.items()},
        "tier_str": {str(k): v for k, v in TIER_STR.items()},
        "task_floor": TASK_FLOOR,
    },
    "task_floor": TASK_FLOOR,
    "version": "8.0",
    "description": "ACO Production Router v8: dynamic difficulty + ML confirmation + safety floors",
    "dynamic_difficulty": True,
    "critical_keywords": CRITICAL_KW,
    "simple_keywords": SIMPLE_KW,
}
with open("/app/router_models/router_bundle_v8.pkl", "wb") as f:
    pickle.dump(v8_bundle, f)
|
|
# Final summary: the comparison table again (without the F-DONE column),
# followed by the size of the saved v8 bundle.
def _final_rank_key(item):
    stats = item[1]
    return (-stats["success"], stats["avg_cost"])


print(f"\n\n{'='*80}")
print("FINAL v8 RESULTS")
print(f"{'='*80}")
print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10}")
print("-" * 65)
for name, r in sorted(results.items(), key=_final_rank_key):
    cr = 100 * (1 - r["avg_cost"] / fc)
    print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f}")

print(f"\nSaved router_bundle_v8.pkl ({os.path.getsize('/app/router_models/router_bundle_v8.pkl')/1024:.0f} KB)")
print("DONE!")
|
|