narcolepticchicken committed on
Commit
4c6ae13
·
verified ·
1 Parent(s): 12ce4ea

Upload training/aco_eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/aco_eval.py +181 -0
training/aco_eval.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""ACO Benchmark Evaluation: Full system test with simulated agent traces."""
import sys
import json
import random
import pickle  # NOTE(review): unused in this script — kept so callers/imports don't break
import time    # NOTE(review): unused in this script — kept so callers/imports don't break

# Make the ACO package importable when this script runs inside the container.
sys.path.insert(0, "/app")

from collections import defaultdict  # NOTE(review): unused here — confirm before removing

# Simulated base success probability per model tier (tier 1 = cheapest/weakest).
TIER_STR = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
# Simulated per-call model cost per tier (arbitrary cost units).
TIER_COST = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}
# Minimum acceptable model tier per task type (floor used by the heuristic baseline).
TASK_FLOOR = {
    "legal_regulated": 4,
    "long_horizon": 3,
    "research": 3,
    "coding": 3,
    "unknown_ambiguous": 3,
    "quick_answer": 1,
    "document_drafting": 2,
    "tool_heavy": 2,
    "retrieval_heavy": 2,
}

# Keyword lists — NOTE(review): defined but never referenced anywhere in this
# script; presumably mirrored from the classifier. Confirm before removing.
CODE_KW = ["python", "javascript", "code", "function", "bug", "debug", "refactor", "implement", "test"]
CRITICAL_KW = ["critical", "production", "urgent", "now", "emergency", "live", "deployed", "safety", "security"]
SIMPLE_KW = ["typo", "simple", "quick", "brief", "briefly", "just", "minor", "small", "easy", "trivial", "clarification"]
16
+
17
+ from aco.classifier import TaskCostClassifier
18
+ from aco.router import ModelCascadeRouter
19
+ from aco.context_budgeter import ContextBudgeter
20
+ from aco.tool_gate import ToolCostGate
21
+ from aco.verifier_budgeter import VerifierBudgeter
22
+ from aco.retry_optimizer import RetryOptimizer
23
+ from aco.meta_tool_miner import MetaToolMiner
24
+ from aco.doom_detector import DoomDetector
25
+
26
# Representative benchmark prompts grouped by task type; each simulated run
# samples one task type and one prompt uniformly at random.
TASKS={
    "quick_answer":["What is 2+2?","Explain quantum computing briefly.","Just tell me what 2+2 is."],
    "coding":["Write a Python function to reverse a linked list.","Fix a typo in the README.","Debug this critical production segfault NOW.","Just fix the typo in line 42."],
    "research":["Research latest transformer advances.","Find sources comparing LoRA and full FT briefly."],
    "document_drafting":["Draft project proposal for ML pipeline.","Write email to team about deployment."],
    "legal_regulated":["Review this contract for liability clauses.","Check GDPR compliance for data pipeline urgently."],
    "tool_heavy":["Search open issues and create summary.","Fetch API docs and generate client code."],
    "retrieval_heavy":["Answer based on 50-page document.","Find all payment processing mentions."],
    "long_horizon":["Plan 3-month roadmap.","Orchestrate complete multi-region deployment."],
    "unknown_ambiguous":["Help me with this thing.","I need something about the server."],
}

# Tools offered to the gate on every run, and their simulated per-call costs.
TOOL_LIST=["web_search","code_search","file_read","file_write","code_execute","verify"]
TOOL_COST_ESTIMATES={"web_search":{"cost":0.01},"code_search":{"cost":0.005},"file_read":{"cost":0.001},"file_write":{"cost":0.001},"code_execute":{"cost":0.01},"verify":{"cost":0.02}}
# Flat cost charged whenever the verifier budgeter decides to verify a run.
VERIFIER_COST=0.02
41
+
42
+ print("="*80)
43
+ print("ACO FULL SYSTEM BENCHMARK EVALUATION")
44
+ print("="*80)
45
+
46
+ # Initialize modules
47
+ classifier=TaskCostClassifier()
48
+ router=ModelCascadeRouter(model_path="/app/router_models/router_bundle_v8.pkl")
49
+ context_budgeter=ContextBudgeter()
50
+ tool_gate=ToolCostGate()
51
+ verifier_budgeter=VerifierBudgeter()
52
+ retry_optimizer=RetryOptimizer()
53
+ meta_tool_miner=MetaToolMiner()
54
+ doom_detector=DoomDetector()
55
+
56
+ # Simulate 2000 agent runs
57
+ rng=random.Random(42)
58
+ N=2000
59
+ results_aco=[]
60
+ results_frontier=[]
61
+ results_heuristic=[]
62
+ results_cheap=[]
63
+
64
for i in range(N):
    # Sample a task type, then a concrete request of that type.
    tt=rng.choice(list(TASKS.keys()))
    req=rng.choice(TASKS[tt])

    # Classify
    pred=classifier.classify(req)
    # Route
    routing=router.route(req, pred["task_type"], pred["difficulty"], pred)
    # Context budget
    budget=context_budgeter.budget(pred["task_type"],pred["difficulty"],pred["needs_retrieval"],pred["needs_tools"])
    # Tool decisions: consult the gate for every tool when the classifier says
    # tools are needed or the task type is inherently tool-using.
    tool_decisions={}
    for tool in TOOL_LIST:
        if pred["needs_tools"] or tt in ("coding","tool_heavy","retrieval_heavy","research"):
            td=tool_gate.gate(tool,{"query":req},tt,1,5,routing.confidence)
            tool_decisions[tool]=td
    # Verifier
    vd=verifier_budgeter.should_verify(tt,pred["risk"],routing.confidence,False,False,routing.tier)
    # Simulate success: per-tier base probability compounded by difficulty.
    # Assumes pred["difficulty"] is numeric (used as an exponent scale) — TODO confirm.
    ps=TIER_STR[routing.tier]**(pred["difficulty"]*0.6)
    success=rng.random()<ps
    # Compute cost
    model_cost=TIER_COST[routing.tier]
    # Unknown tools fall back to a 0.02 cost estimate.
    tool_cost=sum(TOOL_COST_ESTIMATES.get(t,{}).get("cost",0.02) for t,td in tool_decisions.items() if td.action=="use")
    ver_cost=VERIFIER_COST if vd.should_verify else 0
    total_cost=model_cost+tool_cost+ver_cost

    results_aco.append({"tt":tt,"tier":routing.tier,"success":success,"cost":total_cost,
        "model_cost":model_cost,"tool_cost":tool_cost,"ver_cost":ver_cost,
        "context_tokens":budget.total_tokens,"verified":vd.should_verify,
        "tools_used":sum(1 for td in tool_decisions.values() if td.action=="use"),
        "escalated":routing.escalated,"downgraded":routing.downgraded})

    # Baseline: always frontier
    ps_f=TIER_STR[4]**(pred["difficulty"]*0.6)
    s_f=rng.random()<ps_f
    # NOTE(review): reuses ACO's tool_cost for this run and always pays the
    # verifier cost — presumably an intentional simplification; confirm.
    results_frontier.append({"tt":tt,"tier":4,"success":s_f,"cost":1.0+tool_cost+VERIFIER_COST})

    # Baseline: heuristic
    h_tier=min(pred["difficulty"]+1,5)
    h_tier=max(h_tier,TASK_FLOOR.get(tt,2))
    ps_h=TIER_STR[h_tier]**(pred["difficulty"]*0.6)
    s_h=rng.random()<ps_h
    # NOTE(review): reuses ACO's tool_cost and ver_cost for this run — confirm intended.
    results_heuristic.append({"tt":tt,"tier":h_tier,"success":s_h,"cost":TIER_COST[h_tier]+tool_cost+ver_cost})

    # Baseline: always cheap
    ps_c=TIER_STR[1]**(pred["difficulty"]*0.6)
    s_c=rng.random()<ps_c
    results_cheap.append({"tt":tt,"tier":1,"success":s_c,"cost":0.05+tool_cost})

    # Reset per-run verifier budget state before the next simulated run.
    verifier_budgeter.reset_run()
115
+
116
+
117
+ # Compute metrics
118
def compute_metrics(results, name):
    """Aggregate per-run benchmark records into one summary dict.

    Parameters
    ----------
    results : list[dict]
        Per-run records. Only "success" and "cost" are required; every other
        field falls back to a default (missing "model_cost" falls back to the
        run's full "cost", missing "context_tokens" counts as 8000, missing
        "verified" counts as verified).
    name : str
        Label stored under the "name" key of the returned summary.

    Returns
    -------
    dict
        Keys: name, success_rate, avg_cost, model_cost, tool_cost, ver_cost,
        avg_context_tokens, verifications, avg_tools, escalations, downgrades.

    Raises
    ------
    ValueError
        If ``results`` is empty (previously a bare ZeroDivisionError).
    """
    if not results:
        raise ValueError("compute_metrics() requires a non-empty results list")
    n = len(results)
    succ = sum(1 for r in results if r["success"])
    cost = sum(r["cost"] for r in results)
    # Baseline records carry only "cost"; treat that entirely as model cost.
    model_cost = sum(r.get("model_cost", r["cost"]) for r in results)
    tool_cost = sum(r.get("tool_cost", 0) for r in results)
    ver_cost = sum(r.get("ver_cost", 0) for r in results)
    ctx = sum(r.get("context_tokens", 8000) for r in results) / n
    # NOTE(review): the default True counts every record lacking a "verified"
    # field (i.e. all baseline records) as verified — confirm this is intended.
    verified = sum(1 for r in results if r.get("verified", True))
    tools = sum(r.get("tools_used", 0) for r in results) / n
    escalations = sum(1 for r in results if r.get("escalated", False))
    downgrades = sum(1 for r in results if r.get("downgraded", False))
    return {"name": name, "success_rate": succ / n, "avg_cost": cost / n,
            "model_cost": model_cost / n, "tool_cost": tool_cost / n, "ver_cost": ver_cost / n,
            "avg_context_tokens": ctx, "verifications": verified,
            "avg_tools": tools, "escalations": escalations, "downgrades": downgrades}
134
+
135
# Aggregate summaries for the ACO pipeline and the three baselines.
m = compute_metrics(results_aco, "aco_v8")
m_f = compute_metrics(results_frontier, "always_frontier")
m_h = compute_metrics(results_heuristic, "heuristic")
m_c = compute_metrics(results_cheap, "always_cheap")

# Headline table: one row per router; cost reduction is relative to the
# always-frontier baseline.
print(f"\n{'Router':<20} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'ModelCost':>10} {'ToolCost':>10} {'VerCost':>10} {'Context':>10} {'Verifs':>8}")
print("-" * 100)
for row in (m_f, m_h, m, m_c):
    reduction = (1 - row["avg_cost"] / m_f["avg_cost"]) * 100
    print(f"{row['name']:<20} {row['success_rate']:>10.3f} {row['avg_cost']:>10.4f} {reduction:>9.1f}% {row['model_cost']:>10.4f} {row['tool_cost']:>10.4f} {row['ver_cost']:>10.4f} {row['avg_context_tokens']:>10.0f} {row['verifications']:>8d}")

# Per-task breakdown: ACO vs frontier success and cost for every task type.
print(f"\n\nPer-task breakdown:")
for tt in sorted({rec["tt"] for rec in results_aco}):
    aco_runs = [rec for rec in results_aco if rec["tt"] == tt]
    frontier_runs = [rec for rec in results_frontier if rec["tt"] == tt]
    count = len(aco_runs)
    aco_success = sum(1 for rec in aco_runs if rec["success"]) / count
    aco_cost = sum(rec["cost"] for rec in aco_runs) / count
    frontier_cost = sum(rec["cost"] for rec in frontier_runs) / count
    frontier_success = sum(1 for rec in frontier_runs if rec["success"]) / count
    reduction = (1 - aco_cost / frontier_cost) * 100
    print(f" {tt:<20} n={count:>4} aco_success={aco_success:.3f} frontier_success={frontier_success:.3f} aco_cost={aco_cost:.4f} costRed={reduction:.1f}%")

# Cost-quality frontier: all four routers sorted by average cost.
print(f"\n\nCost-Quality Frontier:")
frontier_points = [(s["avg_cost"], s["success_rate"], s["name"]) for s in (m_c, m_h, m, m_f)]
frontier_points.sort(key=lambda point: point[0])
for cost, succ, name in frontier_points:
    print(f" {name:<20} cost={cost:.4f} success={succ:.3f}")

# Key findings
print(f"\n\nKEY FINDINGS:")
print(f" ACO v8 success rate: {m['success_rate']:.3f}")
print(f" ACO v8 cost reduction: {(1-m['avg_cost']/m_f['avg_cost'])*100:.1f}%")
print(f" ACO v8 avg context: {m['avg_context_tokens']:.0f} tokens")
print(f" ACO v8 verifications: {m['verifications']}/{N}")
print(f" Escalations: {m['escalations']} ({m['escalations']/N*100:.1f}%)")
print(f" Downgrades: {m['downgrades']} ({m['downgrades']/N*100:.1f}%)")

# Persist the summaries for downstream analysis.
with open("/app/aco_benchmark_results.json", "w") as fp:
    json.dump({"aco_v8": m, "frontier": m_f, "heuristic": m_h, "cheap": m_c}, fp, indent=2)
print(f"\nSaved to /app/aco_benchmark_results.json")
print("DONE!")