Upload training/aco_eval.py with huggingface_hub
Browse files — training/aco_eval.py: +181 lines added, 0 removed
training/aco_eval.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""ACO Benchmark Evaluation: Full system test with simulated agent traces."""
import sys
import json
import random
import pickle
import time

# Make the /app checkout importable before pulling in the aco package.
sys.path.insert(0, "/app")

from collections import defaultdict

# Simulated per-tier base success probability and per-call model cost.
TIER_STR = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
TIER_COST = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}

# Minimum model tier per task type (used by the heuristic baseline below).
TASK_FLOOR = {
    "legal_regulated": 4,
    "long_horizon": 3,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
}

# NOTE(review): these keyword lists are never referenced in this script —
# presumably they mirror the classifier's vocabulary; confirm before removing.
CODE_KW = ["python", "javascript", "code", "function", "bug", "debug",
           "refactor", "implement", "test"]
CRITICAL_KW = ["critical", "production", "urgent", "now", "emergency", "live",
               "deployed", "safety", "security"]
SIMPLE_KW = ["typo", "simple", "quick", "brief", "briefly", "just", "minor",
             "small", "easy", "trivial", "clarification"]
|
| 16 |
+
|
| 17 |
+
from aco.classifier import TaskCostClassifier
|
| 18 |
+
from aco.router import ModelCascadeRouter
|
| 19 |
+
from aco.context_budgeter import ContextBudgeter
|
| 20 |
+
from aco.tool_gate import ToolCostGate
|
| 21 |
+
from aco.verifier_budgeter import VerifierBudgeter
|
| 22 |
+
from aco.retry_optimizer import RetryOptimizer
|
| 23 |
+
from aco.meta_tool_miner import MetaToolMiner
|
| 24 |
+
from aco.doom_detector import DoomDetector
|
| 25 |
+
|
| 26 |
+
# Canned benchmark requests per task type. Several prompts deliberately carry
# difficulty-signalling words ("critical", "just", "briefly") to exercise the
# classifier's keyword handling.
TASKS = {
    "quick_answer": [
        "What is 2+2?",
        "Explain quantum computing briefly.",
        "Just tell me what 2+2 is.",
    ],
    "coding": [
        "Write a Python function to reverse a linked list.",
        "Fix a typo in the README.",
        "Debug this critical production segfault NOW.",
        "Just fix the typo in line 42.",
    ],
    "research": [
        "Research latest transformer advances.",
        "Find sources comparing LoRA and full FT briefly.",
    ],
    "document_drafting": [
        "Draft project proposal for ML pipeline.",
        "Write email to team about deployment.",
    ],
    "legal_regulated": [
        "Review this contract for liability clauses.",
        "Check GDPR compliance for data pipeline urgently.",
    ],
    "tool_heavy": [
        "Search open issues and create summary.",
        "Fetch API docs and generate client code.",
    ],
    "retrieval_heavy": [
        "Answer based on 50-page document.",
        "Find all payment processing mentions.",
    ],
    "long_horizon": [
        "Plan 3-month roadmap.",
        "Orchestrate complete multi-region deployment.",
    ],
    "unknown_ambiguous": [
        "Help me with this thing.",
        "I need something about the server.",
    ],
}

# Tools the simulated agent may request, with rough per-call dollar costs.
TOOL_LIST = ["web_search", "code_search", "file_read", "file_write",
             "code_execute", "verify"]
TOOL_COST_ESTIMATES = {
    "web_search": {"cost": 0.01},
    "code_search": {"cost": 0.005},
    "file_read": {"cost": 0.001},
    "file_write": {"cost": 0.001},
    "code_execute": {"cost": 0.01},
    "verify": {"cost": 0.02},
}
# Flat cost charged whenever a verifier pass runs.
VERIFIER_COST = 0.02
|
| 41 |
+
|
| 42 |
+
# --- Banner -----------------------------------------------------------------
rule = "=" * 80
print(rule)
print("ACO FULL SYSTEM BENCHMARK EVALUATION")
print(rule)

# --- Initialize modules -----------------------------------------------------
# Each ACO subsystem is constructed once and shared across all simulated runs.
classifier = TaskCostClassifier()
router = ModelCascadeRouter(model_path="/app/router_models/router_bundle_v8.pkl")
context_budgeter = ContextBudgeter()
tool_gate = ToolCostGate()
verifier_budgeter = VerifierBudgeter()
# NOTE(review): the three modules below are instantiated but never exercised
# anywhere else in this script — confirm whether they belong in this benchmark.
retry_optimizer = RetryOptimizer()
meta_tool_miner = MetaToolMiner()
doom_detector = DoomDetector()
|
| 55 |
+
|
| 56 |
+
# Simulate 2000 agent runs
# Seeded PRNG so every invocation replays the identical benchmark.
rng=random.Random(42)
N=2000
# One result list per routing strategy being compared.
results_aco=[]
results_frontier=[]
results_heuristic=[]
results_cheap=[]

for i in range(N):
    # Draw a random task type and one of its canned requests.
    tt=rng.choice(list(TASKS.keys()))
    req=rng.choice(TASKS[tt])

    # Classify
    pred=classifier.classify(req)
    # Route
    routing=router.route(req, pred["task_type"], pred["difficulty"], pred)
    # Context budget
    budget=context_budgeter.budget(pred["task_type"],pred["difficulty"],pred["needs_retrieval"],pred["needs_tools"])
    # Tool decisions
    # Gate every tool whenever the classifier wants tools or the task type is
    # inherently tool-leaning; otherwise no tools are considered at all.
    tool_decisions={}
    for tool in TOOL_LIST:
        if pred["needs_tools"] or tt in ("coding","tool_heavy","retrieval_heavy","research"):
            td=tool_gate.gate(tool,{"query":req},tt,1,5,routing.confidence)
            tool_decisions[tool]=td
    # Verifier
    vd=verifier_budgeter.should_verify(tt,pred["risk"],routing.confidence,False,False,routing.tier)
    # Simulate success
    # Success probability decays with difficulty; higher tiers decay slower.
    # NOTE(review): assumes pred["difficulty"] is a small non-negative numeric
    # scale — confirm against TaskCostClassifier.
    ps=TIER_STR[routing.tier]**(pred["difficulty"]*0.6)
    success=rng.random()<ps
    # Compute cost
    model_cost=TIER_COST[routing.tier]
    # Unknown tools default to a 0.02 estimate; only "use" decisions are billed.
    tool_cost=sum(TOOL_COST_ESTIMATES.get(t,{}).get("cost",0.02) for t,td in tool_decisions.items() if td.action=="use")
    ver_cost=VERIFIER_COST if vd.should_verify else 0
    total_cost=model_cost+tool_cost+ver_cost

    results_aco.append({"tt":tt,"tier":routing.tier,"success":success,"cost":total_cost,
        "model_cost":model_cost,"tool_cost":tool_cost,"ver_cost":ver_cost,
        "context_tokens":budget.total_tokens,"verified":vd.should_verify,
        "tools_used":sum(1 for td in tool_decisions.values() if td.action=="use"),
        "escalated":routing.escalated,"downgraded":routing.downgraded})

    # Baseline: always frontier
    # Tier 4 every time; reuses this run's tool_cost and always pays the
    # verifier. NOTE(review): sharing the ACO run's tool_cost looks intentional
    # (keeps tool spend comparable across strategies) — confirm.
    ps_f=TIER_STR[4]**(pred["difficulty"]*0.6)
    s_f=rng.random()<ps_f
    results_frontier.append({"tt":tt,"tier":4,"success":s_f,"cost":1.0+tool_cost+VERIFIER_COST})

    # Baseline: heuristic
    # Tier = difficulty + 1, capped at 5 and raised to the per-task floor.
    h_tier=min(pred["difficulty"]+1,5)
    h_tier=max(h_tier,TASK_FLOOR.get(tt,2))
    ps_h=TIER_STR[h_tier]**(pred["difficulty"]*0.6)
    s_h=rng.random()<ps_h
    # NOTE(review): reuses the ACO run's ver_cost, i.e. the heuristic verifies
    # exactly when ACO did — presumably deliberate; confirm.
    results_heuristic.append({"tt":tt,"tier":h_tier,"success":s_h,"cost":TIER_COST[h_tier]+tool_cost+ver_cost})

    # Baseline: always cheap
    # Tier 1 every time, never verifies.
    ps_c=TIER_STR[1]**(pred["difficulty"]*0.6)
    s_c=rng.random()<ps_c
    results_cheap.append({"tt":tt,"tier":1,"success":s_c,"cost":0.05+tool_cost})

    # Verifier budget is tracked per run; reset before the next simulation.
    verifier_budgeter.reset_run()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Compute metrics
def compute_metrics(results, name):
    """Aggregate per-run result dicts into one summary row.

    ACO records carry a full cost breakdown; baseline records may omit keys,
    in which case defaults apply: ``model_cost`` falls back to the record's
    total cost, ``context_tokens`` to 8000, and ``verified`` to True
    (baselines are assumed to always verify).

    Args:
        results: list of per-run dicts (must at least have "success", "cost").
        name: label for this strategy, echoed back in the summary.

    Returns:
        dict with success rate, per-run average costs, context, verification
        and escalation/downgrade counts. An empty ``results`` yields an
        all-zero summary instead of raising ZeroDivisionError (fix: the
        original divided by ``len(results)`` unconditionally).
    """
    n = len(results)
    if n == 0:
        return {"name": name, "success_rate": 0.0, "avg_cost": 0.0,
                "model_cost": 0.0, "tool_cost": 0.0, "ver_cost": 0.0,
                "avg_context_tokens": 0.0, "verifications": 0,
                "avg_tools": 0.0, "escalations": 0, "downgrades": 0}
    succ = sum(1 for r in results if r["success"])
    cost = sum(r["cost"] for r in results)
    # Baselines lack a model/tool split, so their whole cost counts as model.
    model_cost = sum(r.get("model_cost", r["cost"]) for r in results)
    tool_cost = sum(r.get("tool_cost", 0) for r in results)
    ver_cost = sum(r.get("ver_cost", 0) for r in results)
    ctx = sum(r.get("context_tokens", 8000) for r in results) / n
    verified = sum(1 for r in results if r.get("verified", True))
    tools = sum(r.get("tools_used", 0) for r in results) / n
    escalations = sum(1 for r in results if r.get("escalated", False))
    downgrades = sum(1 for r in results if r.get("downgraded", False))
    return {"name": name, "success_rate": succ / n, "avg_cost": cost / n,
            "model_cost": model_cost / n, "tool_cost": tool_cost / n,
            "ver_cost": ver_cost / n,
            "avg_context_tokens": ctx, "verifications": verified,
            "avg_tools": tools, "escalations": escalations,
            "downgrades": downgrades}
|
| 134 |
+
|
| 135 |
+
# Summarize each strategy; cost reduction in the table is measured relative
# to the always-frontier baseline (m_f).
m = compute_metrics(results_aco, "aco_v8")
m_f = compute_metrics(results_frontier, "always_frontier")
m_h = compute_metrics(results_heuristic, "heuristic")
m_c = compute_metrics(results_cheap, "always_cheap")

print(f"\n{'Router':<20} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'ModelCost':>10} {'ToolCost':>10} {'VerCost':>10} {'Context':>10} {'Verifs':>8}")
print("-" * 100)
for summary in (m_f, m_h, m, m_c):
    reduction = (1 - summary["avg_cost"] / m_f["avg_cost"]) * 100
    print(f"{summary['name']:<20} {summary['success_rate']:>10.3f} {summary['avg_cost']:>10.4f} {reduction:>9.1f}% {summary['model_cost']:>10.4f} {summary['tool_cost']:>10.4f} {summary['ver_cost']:>10.4f} {summary['avg_context_tokens']:>10.0f} {summary['verifications']:>8d}")
|
| 145 |
+
|
| 146 |
+
# Per-task comparison: ACO success/cost against the always-frontier baseline.
print(f"\n\nPer-task breakdown:")
for tt in sorted({record["tt"] for record in results_aco}):
    aco_rows = [record for record in results_aco if record["tt"] == tt]
    frontier_rows = [record for record in results_frontier if record["tt"] == tt]
    count = len(aco_rows)
    aco_success = sum(1 for record in aco_rows if record["success"]) / count
    aco_cost = sum(record["cost"] for record in aco_rows) / count
    frontier_cost = sum(record["cost"] for record in frontier_rows) / count
    frontier_success = sum(1 for record in frontier_rows if record["success"]) / count
    reduction = (1 - aco_cost / frontier_cost) * 100
    print(f" {tt:<20} n={count:>4} aco_success={aco_success:.3f} frontier_success={frontier_success:.3f} aco_cost={aco_cost:.4f} costRed={reduction:.1f}%")
|
| 158 |
+
|
| 159 |
+
# Cost-quality frontier: one (cost, success) point per strategy, cheapest first.
print(f"\n\nCost-Quality Frontier:")
frontier_points = sorted(
    ((summary["avg_cost"], summary["success_rate"], summary["name"])
     for summary in (m_c, m_h, m, m_f)),
    key=lambda point: point[0],
)
for cost, succ, name in frontier_points:
    print(f" {name:<20} cost={cost:.4f} success={succ:.3f}")

# Headline numbers for the ACO configuration.
print(f"\n\nKEY FINDINGS:")
print(f" ACO v8 success rate: {m['success_rate']:.3f}")
print(f" ACO v8 cost reduction: {(1-m['avg_cost']/m_f['avg_cost'])*100:.1f}%")
print(f" ACO v8 avg context: {m['avg_context_tokens']:.0f} tokens")
print(f" ACO v8 verifications: {m['verifications']}/{N}")
print(f" Escalations: {m['escalations']} ({m['escalations']/N*100:.1f}%)")
print(f" Downgrades: {m['downgrades']} ({m['downgrades']/N*100:.1f}%)")

# Persist all four summaries for downstream analysis.
with open("/app/aco_benchmark_results.json", "w") as out_file:
    json.dump({"aco_v8": m, "frontier": m_f, "heuristic": m_h, "cheap": m_c},
              out_file, indent=2)
print(f"\nSaved to /app/aco_benchmark_results.json")
print("DONE!")
|