agent-cost-optimizer / training /router_v8_final.py
narcolepticchicken's picture
Upload training/router_v8_final.py with huggingface_hub
12ce4ea verified
#!/usr/bin/env python3
"""Final Production Router v8: Dynamic difficulty + ML confirmation + safety floors.
This is the production router that replaces the heuristic in ACO.
"""
import json, os, sys, random, uuid, pickle
import numpy as np
from collections import defaultdict
print("="*80)
print("ACO PRODUCTION ROUTER v8: DYNAMIC DIFFICULTY + ML")
print("="*80)
# ─── Load Models ──────────────────────────────────────────────────────
MODEL_DIR = "/app/router_models"
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# bundles produced by this project's own training scripts.
# The context manager guarantees the file handle is closed after loading.
with open(f"{MODEL_DIR}/router_bundle_v6.pkl", "rb") as fh:
    bundle = pickle.load(fh)
feat_keys = bundle["feat_keys"]
# Tier ids are normalised with int() so lookups such as tier_clfs[3] work
# whether the bundle stored the keys as ints or as strings.
tier_clfs = {int(k): v for k, v in bundle["tier_clfs"].items()}
tier_calibs = {int(k): v for k, v in bundle["tier_calibrators"].items()}
TIER_COST = {int(k): v for k, v in bundle["tier_config"]["tier_cost"].items()}
TIER_STR = {int(k): v for k, v in bundle["tier_config"]["tier_str"].items()}
TASK_FLOOR = bundle["tier_config"]["task_floor"]
# ─── Feature Extraction ────────────────────────────────────────────────
# Keyword vocabularies. The feature/difficulty code lowercases the request and
# tests each entry with `kw in text`, i.e. as a raw substring (so, for example,
# "now" also fires on "know" and "api" on "capital").
CODE_KW = [
    "python", "javascript", "code", "function", "bug", "debug", "refactor",
    "implement", "test", "compile", "runtime", "segfault", "thread", "async",
    "class", "module",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy", "policy",
    "regulatory", "liability", "indemnification", "clause",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "survey", "paper", "arxiv",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query", "api", "database", "scrape",
    "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
    "pipeline", "deploy", "architecture",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation", "formula", "optimize",
    "probability", "integral",
]
# Raise estimated difficulty by one level when any of these appears.
CRITICAL_KW = [
    "critical", "production", "urgent", "now", "emergency", "live",
    "deployed", "safety", "security",
]
# Lower estimated difficulty by one level when any of these appears.
SIMPLE_KW = [
    "typo", "simple", "quick", "brief", "briefly", "just", "minor", "small",
    "easy", "trivial", "clarification",
]
# Task type -> integer index, used as the ``tt_idx`` feature.
TT2IDX = {
    "quick_answer": 0,
    "coding": 1,
    "research": 2,
    "document_drafting": 3,
    "legal_regulated": 4,
    "tool_heavy": 5,
    "retrieval_heavy": 6,
    "long_horizon": 7,
    "unknown_ambiguous": 8,
}
# Base difficulty (1 = easiest, 5 = hardest) per task type, before the
# keyword adjustments applied in estimate_difficulty().
_BASE_DIFFICULTY = {"quick_answer": 1, "document_drafting": 2, "tool_heavy": 2,
                    "retrieval_heavy": 2, "research": 3, "coding": 3,
                    "unknown_ambiguous": 3, "long_horizon": 4, "legal_regulated": 5}


def estimate_difficulty(request, task_type):
    """Estimate task difficulty on a 1-5 scale from task type and request text.

    Starts from the per-task-type base in ``_BASE_DIFFICULTY``, adds 1 (capped
    at 5) if any CRITICAL_KW keyword occurs in the lowercased request, then
    subtracts 1 (floored at 1) if any SIMPLE_KW keyword occurs. Keywords are
    matched as raw substrings.

    Args:
        request: Free-text user request.
        task_type: Task-type label (a TT2IDX key).

    Returns:
        Integer difficulty in [1, 5].
    """
    text = request.lower()
    # Unknown task types fall back to the "unknown_ambiguous" base instead of
    # raising KeyError, matching extract_features (TT2IDX.get(task_type, 8))
    # and the router's TASK_FLOOR.get(task_type, 2) fallback.
    level = _BASE_DIFFICULTY.get(task_type, _BASE_DIFFICULTY["unknown_ambiguous"])
    if any(kw in text for kw in CRITICAL_KW):
        level = min(level + 1, 5)
    if any(kw in text for kw in SIMPLE_KW):
        level = max(level - 1, 1)
    return level
def extract_features(request, task_type, difficulty=3):
    """Build the flat feature dict consumed (via f2v) by the tier classifiers.

    Features, in insertion order:
      * ``req_len`` (characters) and ``num_words`` (whitespace-split tokens);
      * ``has_*`` (0/1) and ``n_*`` (hit count) for the code, legal, research
        and tool vocabularies;
      * ``has_long`` / ``has_math`` presence flags;
      * ``tt_idx`` (TT2IDX index, 8 for unknown task types) and ``difficulty``;
      * one-hot ``tt_<task_type>`` flags, one per TT2IDX key.

    Keywords are matched as substrings of the lowercased request.
    """
    text = request.lower()

    def hits(keywords):
        # Count vocabulary entries that occur as substrings of the request.
        return sum(1 for kw in keywords if kw in text)

    features = {"req_len": len(request), "num_words": len(request.split())}
    for suffix, vocab in (("code", CODE_KW), ("legal", LEGAL_KW),
                          ("research", RESEARCH_KW), ("tool", TOOL_KW)):
        count = hits(vocab)
        features[f"has_{suffix}"] = int(count > 0)
        features[f"n_{suffix}"] = count
    features["has_long"] = int(hits(LONG_KW) > 0)
    features["has_math"] = int(hits(MATH_KW) > 0)
    features["tt_idx"] = TT2IDX.get(task_type, 8)
    features["difficulty"] = difficulty
    features.update({f"tt_{label}": int(task_type == label) for label in TT2IDX})
    return features
def f2v(feats):
    """Flatten a feature dict into a 1-D float32 vector ordered by ``feat_keys``.

    Keys listed in ``feat_keys`` but missing from ``feats`` become 0.0; extra
    keys in ``feats`` are ignored.
    """
    ordered = (float(feats.get(key, 0.0)) for key in feat_keys)
    return np.fromiter(ordered, dtype=np.float32, count=len(feat_keys))
def get_calibrated_psuccess(x, tier):
    """Return the calibrated success probability of ``tier`` for feature row ``x``.

    Takes column 1 of the first row of the tier classifier's ``predict_proba``
    output and passes it through that tier's calibrator.
    NOTE(review): assumes a binary classifier whose class 1 means "success"
    and a calibrator with an isotonic-style ``transform(list)`` API -- confirm
    against the training script.
    """
    classifier = tier_clfs[tier]
    raw_prob = classifier.predict_proba(x)[0, 1]
    calibrator = tier_calibs[tier]
    calibrated = calibrator.transform([raw_prob])
    return float(calibrated[0])
def route_production_v8(request, task_type, safety=0.30, downgrade=0.90):
    """Pick a model tier for a request: difficulty heuristic + ML confirmation.

    1. Base tier = min(difficulty + 1, 5), raised to the task's floor
       (TASK_FLOOR, default 2).
    2. Safety net: if the calibrated success probability at the base tier is
       below ``safety`` and the base is under 5, escalate by exactly one tier.
    3. Cost saver (only when no escalation happened): if the base is above the
       floor and one tier cheaper reaches ``downgrade`` probability, use it.

    Returns:
        ``(tier, p_success_of_chosen_tier, difficulty)``.
    """
    diff = estimate_difficulty(request, task_type)
    floor = TASK_FLOOR.get(task_type, 2)
    base = max(min(diff + 1, 5), floor)
    x = f2v(extract_features(request, task_type, diff)).reshape(1, -1)

    ps = get_calibrated_psuccess(x, base)

    # Safety net: a low-confidence base tier is bumped up one level and the
    # cost saver is skipped entirely.
    if ps < safety and base < 5:
        bumped = base + 1
        return bumped, get_calibrated_psuccess(x, bumped), diff

    # Cost saver: accept the next-cheaper tier only when it is highly likely
    # to succeed and still respects the task floor.
    if base > floor:
        cheaper = base - 1
        p_cheaper = get_calibrated_psuccess(x, cheaper)
        if p_cheaper >= downgrade and cheaper >= floor:
            return cheaper, p_cheaper, diff

    return base, ps, diff
# ─── Generate Eval Traces ────────────────────────────────────────────
# Canned requests per task type used to build the synthetic eval set.
# Key order matters: trace generation draws task types with rng.choice over
# this dict's keys, so reordering keys changes the seeded eval traces.
TASK_TEMPLATES = dict(
    quick_answer=[
        "What is the capital of France?",
        "Explain quantum computing briefly.",
        "What is 237*452?",
        "Briefly explain photosynthesis.",
        "Just tell me what 2+2 is.",
        "Small clarification on this formula.",
    ],
    coding=[
        "Write a Python function to reverse a linked list.",
        "Fix the bug in this React component.",
        "Refactor auth module to JWT.",
        "Implement LRU cache in Go.",
        "Debug segfault in C++ thread pool.",
        "Fix a typo in the README.",
        "Debug this critical production segfault NOW.",
        "Just fix the typo in line 42.",
    ],
    research=[
        "Research latest transformer advances.",
        "Find sources comparing LoRA and full FT.",
        "Investigate data center climate impact.",
        "Find sources comparing LoRA and full FT briefly.",
    ],
    document_drafting=[
        "Draft project proposal for ML pipeline.",
        "Write email to team about deployment.",
        "Create technical report on performance.",
    ],
    legal_regulated=[
        "Review this contract for liability clauses.",
        "Check GDPR compliance for data pipeline.",
        "Draft privacy policy section.",
        "Check GDPR compliance urgently.",
    ],
    tool_heavy=[
        "Search open issues and create summary.",
        "Fetch API docs and generate client code.",
        "Query Q3 sales and produce chart.",
    ],
    retrieval_heavy=[
        "Answer based on 50-page document.",
        "Find all payment processing mentions.",
    ],
    long_horizon=[
        "Plan 3-month roadmap.",
        "Orchestrate multi-region deployment.",
        "Redesign data architecture end-to-end.",
        "Orchestrate complete multi-region deployment.",
    ],
    unknown_ambiguous=[
        "Help me with this thing.",
        "I need something about the server.",
    ],
)
def tsp(tier, diff):
    """Synthetic success probability of ``tier`` on a task of difficulty ``diff``.

    Computed as TIER_STR[tier] raised to (0.6 * difficulty), so harder tasks
    shrink the probability. Presumably TIER_STR values lie in (0, 1]; verify
    against the bundle.
    """
    exponent = diff * 0.6
    return pow(TIER_STR[tier], exponent)
print("\n[1] Generating 2K eval traces...")
rng = random.Random(999)  # fixed seed so every run evaluates the same traces
# Per-task difficulty that ignores the request text. It is stored on each
# trace as "static_diff" and consumed only by the static-heuristic baselines.
STATIC_DIFF = {"quick_answer": 1, "document_drafting": 2, "tool_heavy": 2,
               "retrieval_heavy": 2, "research": 3, "coding": 3,
               "unknown_ambiguous": 3, "long_horizon": 4, "legal_regulated": 5}
# Built once (loop-invariant). Its order plus the seed determines the draws.
task_types = list(TASK_TEMPLATES.keys())
traces = []
for _ in range(2000):
    tt = rng.choice(task_types)
    static_diff = STATIC_DIFF[tt]
    req = rng.choice(TASK_TEMPLATES[tt])
    # Difficulty estimated from the request text.
    # NOTE(review): the ground-truth outcomes below are sampled from this
    # dyn_diff, i.e. from the same estimate_difficulty() that
    # route_production_v8 routes with, so the v8 router is scored against
    # labels generated by its own difficulty model.
    dyn_diff = estimate_difficulty(req, tt)
    # One independent draw per tier: success iff rng.random() < tsp(tier, dyn_diff).
    tier_out = {t: rng.random() < tsp(t, dyn_diff) for t in range(1, 6)}
    # Cheapest tier whose draw succeeded; 5 when no tier succeeded.
    opt = next((t for t in range(1, 6) if tier_out[t]), 5)
    traces.append({"tt": tt, "static_diff": static_diff, "dyn_diff": dyn_diff,
                   "opt": opt, "tier_out": tier_out, "req": req})
print(f" Generated {len(traces)} traces")
# ─── Evaluate ──────────────────────────────────────────────────────────
print("\n[2] Evaluating all routers...")
n = len(traces)
def eval_router(name, route_fn, trace_list=None, tier_cost=None):
    """Score a routing policy against a set of synthetic eval traces.

    Args:
        name: Label for the router. It is not used in the computation and is
            kept for call-site readability.
        route_fn: Callable mapping a trace dict to the tier it routes to.
        trace_list: Traces to score; defaults to the module-level ``traces``.
        tier_cost: Tier -> cost mapping; defaults to the module-level
            ``TIER_COST``.

    Returns:
        Dict with ``success``, ``avg_cost``, ``unsafe_rate`` and
        ``false_done`` (each averaged over the traces) and ``tier_dist``
        (tier -> number of traces routed there). All averages are 0.0 for an
        empty trace set instead of raising ZeroDivisionError.
    """
    if trace_list is None:
        trace_list = traces
    if tier_cost is None:
        tier_cost = TIER_COST
    # Count derived from the traces actually scored, not from a separately
    # maintained global, so the averages always match the loop below.
    total = len(trace_list)
    succ = 0
    cost = 0.0
    unsafe = 0
    fd = 0
    td = defaultdict(int)
    for t in trace_list:
        pred = route_fn(t)
        td[pred] += 1
        if t["tier_out"].get(pred, False):
            succ += 1
        elif pred < t["opt"]:
            # Failed while routed below the trace's "opt" tier.
            unsafe += 1
        else:
            # Failed while routed at or above the "opt" tier.
            fd += 1
        cost += tier_cost[pred]
    denom = total or 1  # empty trace set -> all averages are 0.0
    return {"success": succ / denom, "avg_cost": cost / denom,
            "unsafe_rate": unsafe / denom, "false_done": fd / denom,
            "tier_dist": dict(td)}
# Routers to compare, evaluated in this exact order. The insertion order of
# `results` is what json.dump writes, and sorted() is stable, so it also
# breaks ties in the ranked tables.
_ROUTERS = [
    ("always_frontier", lambda t: 4),
    ("always_cheap", lambda t: 1),
    ("heuristic_static",
     lambda t: max(min(t["static_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
    ("oracle", lambda t: t["opt"]),
    # v8 production router
    ("v8_dynamic+ML", lambda t: route_production_v8(t["req"], t["tt"])[0]),
    # v8 without ML (just dynamic difficulty)
    ("v8_dynamic_only",
     lambda t: max(min(t["dyn_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
]
results = {label: eval_router(label, policy) for label, policy in _ROUTERS}
# Print the ranked summary table.
print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10} {'F-DONE':>10}")
print("-" * 75)
# Cost of the always-frontier baseline; "CostRed" is relative to it.
fc = results["always_frontier"]["avg_cost"]


def _rank_key(entry):
    # Highest success first; lower average cost breaks ties.
    _, metrics = entry
    return (-metrics["success"], metrics["avg_cost"])


for label, metrics in sorted(results.items(), key=_rank_key):
    reduction = (1 - metrics["avg_cost"] / fc) * 100
    print(f"{label:<25} {metrics['success']:>10.3f} {metrics['avg_cost']:>10.4f} {reduction:>9.1f}% {metrics['unsafe_rate']:>10.3f} {metrics['false_done']:>10.3f}")
# Per-task breakdown: success / cost of each router restricted to one task type.
print("\n\n[3] Per-task breakdown...")
# Built once rather than once per task type.
_TASK_ROUTERS = [
    ("frontier", lambda t: 4),
    ("heuristic", lambda t: max(min(t["static_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
    ("v8_dynamic", lambda t: max(min(t["dyn_diff"] + 1, 5), TASK_FLOOR.get(t["tt"], 2))),
    ("v8_full", lambda t: route_production_v8(t["req"], t["tt"])[0]),
    ("oracle", lambda t: t["opt"]),
]
for tt in sorted({t["tt"] for t in traces}):
    tt_r = [t for t in traces if t["tt"] == tt]
    n_tt = len(tt_r)
    print(f"\n {tt} (n={n_tt}):")
    for rname, rfn in _TASK_ROUTERS:
        # Route every trace exactly once and reuse the prediction for both the
        # success and the cost tally; v8_full runs ML inference per call.
        preds = [rfn(t) for t in tt_r]
        succ = sum(1 for t, p in zip(tt_r, preds) if t["tier_out"].get(p, False))
        cost = sum(TIER_COST[p] for p in preds)
        sr = succ / n_tt
        ac = cost / n_tt
        cr = (1 - ac / fc) * 100
        print(f" {rname:<14} success={sr:.3f} cost={ac:.4f} costRed={cr:.1f}%")
# Save evaluation metrics (paths derived from MODEL_DIR, the load directory).
with open(f"{MODEL_DIR}/v8_final_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, default=str)
# Save v8 bundle
v8_bundle = {
    "tier_clfs": {str(k): v for k, v in tier_clfs.items()},
    "tier_calibrators": {str(k): v for k, v in tier_calibs.items()},
    "feat_keys": feat_keys,
    # Same nested layout as the v6 bundle this script loads
    # (tier_config -> tier_cost / tier_str / task_floor), so the existing
    # loader (which int()-converts tier keys) reads it unchanged. TIER_STR is
    # now included instead of being dropped.
    "tier_config": {
        "tier_cost": {str(k): v for k, v in TIER_COST.items()},
        "tier_str": {str(k): v for k, v in TIER_STR.items()},
        "task_floor": TASK_FLOOR,
    },
    # Also kept at the top level, where the earlier v8 layout stored it.
    "task_floor": TASK_FLOOR,
    "version": "8.0",
    "description": "ACO Production Router v8: dynamic difficulty + ML confirmation + safety floors",
    "dynamic_difficulty": True,
    "critical_keywords": CRITICAL_KW,
    "simple_keywords": SIMPLE_KW,
}
with open(f"{MODEL_DIR}/router_bundle_v8.pkl", "wb") as f:
    pickle.dump(v8_bundle, f)
print("\n\n" + "=" * 80)
print("FINAL v8 RESULTS")
print("=" * 80)
print(f"\n{'Router':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10}")
print("-" * 65)
# Same ranking as the first table: highest success, then lowest average cost.
for name, r in sorted(results.items(), key=lambda x: (-x[1]["success"], x[1]["avg_cost"])):
    cr = (1 - r["avg_cost"] / fc) * 100
    print(f"{name:<25} {r['success']:>10.3f} {r['avg_cost']:>10.4f} {cr:>9.1f}% {r['unsafe_rate']:>10.3f}")
# Derived from MODEL_DIR rather than repeating the hard-coded directory.
bundle_path = f"{MODEL_DIR}/router_bundle_v8.pkl"
print(f"\nSaved router_bundle_v8.pkl ({os.path.getsize(bundle_path)/1024:.0f} KB)")
print("DONE!")