narcolepticchicken committed on
Commit
d3c4fd1
·
verified ·
1 Parent(s): e086ff8

Upload training/benchmark_with_feedback.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/benchmark_with_feedback.py +199 -0
training/benchmark_with_feedback.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Comprehensive benchmark with execution-feedback routing."""
import sys,json,random,math,pickle,time
sys.path.insert(0,"/app")
from collections import defaultdict

# Per-tier base success probability ("strength") used by the simulator;
# raised to a difficulty-dependent exponent in the route_* functions below.
TIER_STR={1:0.35,2:0.55,3:0.80,4:0.93,5:0.97}
# Per-attempt cost of each tier (arbitrary units; frontier tier 4 = 1.0).
TIER_COST={1:0.05,2:0.15,3:0.75,4:1.0,5:1.5}
# Minimum tier each task type may be routed to (safety floor).
TASK_FLOOR={"legal_regulated":4,"long_horizon":3,"research":3,"coding":3,
"unknown_ambiguous":3,"quick_answer":1,"document_drafting":2,
"tool_heavy":2,"retrieval_heavy":2}

# Project-local components; resolved via the sys.path insert above.
from aco.classifier import TaskCostClassifier
from aco.router import ModelCascadeRouter
from aco.execution_feedback import ExecutionFeedbackRouter, FeedbackSignal
16
+
17
# Benchmark prompts grouped by task type; each simulated trace samples a
# task type uniformly, then a prompt uniformly from that type's list.
TASKS = {
"quick_answer":["What is 2+2?","Explain quantum computing briefly.",
"What is the capital of France?","Convert 100F to Celsius.",
"Small clarification on this formula."],
"coding":["Write a Python function to reverse a linked list.",
"Fix a typo in the README.","Debug this critical production segfault NOW.",
"Just fix the typo in line 42.","Refactor auth module to JWT.",
"Implement LRU cache in Go."],
"research":["Research latest transformer advances.",
"Find sources comparing LoRA and full FT briefly.",
"Investigate data center climate impact."],
"document_drafting":["Draft project proposal for ML pipeline.",
"Write email to team about deployment."],
"legal_regulated":["Review this contract for liability clauses.",
"Check GDPR compliance for data pipeline urgently."],
"tool_heavy":["Search open issues and create summary.",
"Fetch API docs and generate client code."],
"retrieval_heavy":["Answer based on 50-page document.",
"Find all payment processing mentions."],
"long_horizon":["Plan 3-month roadmap.",
"Orchestrate complete multi-region deployment."],
"unknown_ambiguous":["Help me with this thing.",
"I need something about the server."],
}
41
+
42
+ print("="*80)
43
+ print("ACO v9 BENCHMARK: EXECUTION-FEEDBACK + DYNAMIC DIFFICULTY")
44
+ print("="*80)
45
+
46
+ classifier = TaskCostClassifier()
47
+ router = ModelCascadeRouter(model_path="/app/router_models/router_bundle_v8.pkl")
48
+ ef_router = ExecutionFeedbackRouter(tier_costs=TIER_COST, task_floors=TASK_FLOOR)
49
+
50
+ rng = random.Random(42)
51
+ N = 3000
52
+
53
def sim_logprobs(tier, difficulty, success, rng):
    """Generate a synthetic per-token logprob trace for one completion.

    Higher tiers yield higher (less negative) base logprobs; difficulty
    pushes the base down and widens the per-token noise.  Failed
    completions use a larger noise multiplier, mimicking the flatter
    token distribution of a wrong answer.
    """
    token_count = rng.randint(20, 200)
    tier_base = {1: -3.5, 2: -2.5, 3: -1.5, 4: -0.7, 5: -0.3}[tier]
    tier_base *= 1 + difficulty * 0.15
    # Success tightens the trace around the base; failure spreads it out.
    spread = 0.3 if success else 0.8
    return [
        tier_base + rng.gauss(0, 1.0 + difficulty * 0.3) * spread
        for _ in range(token_count)
    ]
66
+
67
def eval_method(name, route_fn):
    """Push N simulated requests through *route_fn* and aggregate metrics.

    Relies on the module-level globals N, TASKS, rng, and classifier.
    Returns overall success rate, mean cost, unsafe-routing rate, and a
    per-task-type breakdown.
    """
    total_success = 0
    total_cost = 0.0
    total_unsafe = 0
    breakdown = defaultdict(lambda: {"succ": 0, "cost": 0.0, "n": 0})
    for _ in range(N):
        task_type = rng.choice(list(TASKS.keys()))
        request = rng.choice(TASKS[task_type])
        prediction = classifier.classify(request)
        _tier, ok, spend, flagged = route_fn(request, task_type, prediction)
        total_success += 1 if ok else 0
        total_cost += spend
        total_unsafe += 1 if flagged else 0
        bucket = breakdown[task_type]
        bucket["succ"] += 1 if ok else 0
        bucket["cost"] += spend
        bucket["n"] += 1
    return {
        "name": name,
        "success": total_success / N,
        "avg_cost": total_cost / N,
        "unsafe": total_unsafe / N,
        "per_tt": dict(breakdown),
    }
82
+
83
# Method A: always frontier
def route_frontier(req, tt, pred):
    """Baseline: send every request to the frontier tier (4)."""
    tier = 4
    p_success = TIER_STR[tier] ** (pred["difficulty"] * 0.6)
    return tier, rng.random() < p_success, TIER_COST[tier], False
87
+
88
# Method B: heuristic
def route_heuristic(req, tt, pred):
    """Static heuristic: tier = difficulty + 1, clamped by the task floor.

    A run routed below the frontier band (tier < 4) that fails is counted
    as unsafe routing.
    """
    tier = min(pred["difficulty"] + 1, 5)
    tier = max(tier, TASK_FLOOR.get(tt, 2))
    p_success = TIER_STR[tier] ** (pred["difficulty"] * 0.6)
    # BUGFIX: draw the success outcome ONCE and reuse it.  The original
    # called rng.random() twice, so the "unsafe" flag was decided by an
    # independent coin flip and did not correspond to the actual simulated
    # outcome of this request.
    ok = rng.random() < p_success
    return tier, ok, TIER_COST[tier], (tier < 4 and not ok)
94
+
95
# Method C: v8 router (no feedback)
def route_v8(req, tt, pred):
    """Route via the learned v8 cascade router, without execution feedback."""
    decision = router.route(req, tt, pred["difficulty"], pred)
    p_success = TIER_STR[decision.tier] ** (decision.dynamic_difficulty * 0.6)
    ok = rng.random() < p_success
    return decision.tier, ok, TIER_COST[decision.tier], decision.escalated
100
+
101
# Method D: v9 = v8 router + execution feedback cascade
def route_v9(req, tt, pred):
    """v8 routing plus a one-step escalation driven by logprob feedback.

    The first attempt's simulated logprobs are scored by the execution-
    feedback router; if it flags the output, the request is retried one
    tier higher (capped at 5, never below the task floor) and BOTH
    attempts are billed.
    """
    decision = router.route(req, tt, pred["difficulty"], pred)
    first_tier = decision.tier
    p_first = TIER_STR[first_tier] ** (decision.dynamic_difficulty * 0.6)
    first_ok = rng.random() < p_first
    trace = sim_logprobs(first_tier, decision.dynamic_difficulty, first_ok, rng)
    signal = ef_router.analyze_output(trace, task_type=tt, current_tier=first_tier)
    # No escalation possible past tier 5; otherwise follow the signal.
    if not (signal.should_escalate and first_tier < 5):
        return first_tier, first_ok, TIER_COST[first_tier], False
    retry_tier = max(min(first_tier + 1, 5), TASK_FLOOR.get(tt, 1))
    p_retry = TIER_STR[retry_tier] ** (decision.dynamic_difficulty * 0.6)
    retry_ok = rng.random() < p_retry
    total = TIER_COST[first_tier] + TIER_COST[retry_tier]
    return retry_tier, retry_ok, total, False
118
+
119
# Method E: oracle
def route_oracle(req, tt, pred):
    """Upper bound: try tiers cheapest-first, stopping at the first success."""
    exponent = pred["difficulty"] * 0.6
    for tier in range(1, 6):
        if rng.random() < TIER_STR[tier] ** exponent:
            return tier, True, TIER_COST[tier], False
    # All five tiers failed; charge the frontier-plus tier and report failure.
    return 5, False, TIER_COST[5], False
126
+
127
# Method F: always cheap
def route_cheap(req, tt, pred):
    """Baseline: route everything to the cheapest tier (1).

    Any failure at tier 1 is counted as unsafe routing.
    """
    p_success = TIER_STR[1] ** (pred["difficulty"] * 0.6)
    # BUGFIX: single draw shared by success and unsafe.  The original drew
    # rng.random() twice, decoupling the unsafe flag from the simulated
    # outcome of the request.
    ok = rng.random() < p_success
    return 1, ok, TIER_COST[1], not ok
131
+
132
# Run all methods
print(f"\n[1] Running {N} simulated traces per method...")
results = {}
for name, fn in [("always_frontier",route_frontier),("always_cheap",route_cheap),
("heuristic",route_heuristic),("v8_router",route_v8),
("v9_feedback",route_v9),("oracle",route_oracle)]:
    # Fair comparison: snapshot the RNG state, reseed to 42 so every method
    # consumes an identical random stream, run the evaluation, then restore
    # the snapshot so the next iteration starts from the same point.
    rng_state = rng.getstate()
    rng.seed(42)
    results[name] = eval_method(name, fn)
    rng.setstate(rng_state)
142
+
143
# Print comparison
print(f"\n\n{'Method':<20} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Unsafe':>10}")
print("-"*60)
# fc (frontier cost) is the cost-reduction baseline; also read by the
# cost-quality frontier section further down.
fc = results["always_frontier"]["avg_cost"]
ordered = ["oracle","always_frontier","v9_feedback","v8_router","heuristic","always_cheap"]
for method in ordered:
    row = results[method]
    reduction = (1 - row["avg_cost"] / fc) * 100
    print(f"{method:<20} {row['success']:>10.3f} {row['avg_cost']:>10.4f} {reduction:>9.1f}% {row['unsafe']:>10.3f}")
151
+
152
# Per-task comparison
print(f"\n\n[2] Per-task success rate comparison:")
print(f"{'Task':<20} {'Frontier':>10} {'v8':>10} {'v9_feedback':>12} {'CostRed_v9':>12}")
print("-"*65)

def _tt_stats(method, task):
    """Per-task (success rate, mean cost) for one method; zeros if unseen."""
    row = results[method]["per_tt"].get(task, {"succ": 0, "n": 1, "cost": 0})
    denom = max(row["n"], 1)
    return row["succ"] / denom, row["cost"] / denom

for tt in sorted(set(k for r in results.values() for k in r["per_tt"])):
    f_sr, f_c = _tt_stats("always_frontier", tt)
    v8_sr, _unused_cost = _tt_stats("v8_router", tt)
    v9_sr, v9_c = _tt_stats("v9_feedback", tt)
    # Guard against a zero frontier cost (task type never sampled).
    cr = (1 - v9_c / f_c) * 100 if f_c > 0 else 0
    print(f"{tt:<20} {f_sr:>10.3f} {v8_sr:>10.3f} {v9_sr:>12.3f} {cr:>11.1f}%")
167
+
168
# Cost-quality frontier
print(f"\n\n[3] Cost-Quality Frontier:")
for method in ["always_cheap","v8_router","v9_feedback","heuristic","always_frontier","oracle"]:
    row = results[method]
    reduction = (1 - row["avg_cost"] / fc) * 100
    print(f" {method:<20} success={row['success']:.3f} cost={row['avg_cost']:.4f} costRed={reduction:.1f}%")
174
+
175
# Key metrics
feedback = results["v9_feedback"]
learned = results["v8_router"]
frontier = results["always_frontier"]
# Cost reduction is measured against the always-frontier baseline.
v9_cr = (1 - feedback["avg_cost"] / frontier["avg_cost"]) * 100
v8_cr = (1 - learned["avg_cost"] / frontier["avg_cost"]) * 100
quality_gap_v9 = frontier["success"] - feedback["success"]
quality_gap_v8 = frontier["success"] - learned["success"]

print(f"\n\n[4] KEY RESULTS:")
print(f" v8 success: {learned['success']:.3f} (gap vs frontier: {quality_gap_v8:.3f})")
print(f" v9 success: {feedback['success']:.3f} (gap vs frontier: {quality_gap_v9:.3f})")
print(f" v8 costRed: {v8_cr:.1f}%")
print(f" v9 costRed: {v9_cr:.1f}%")
print(f" Quality gain v9 vs v8: {feedback['success']-learned['success']:+.3f}")
print(f" Cost increase v9 vs v8: {feedback['avg_cost']-learned['avg_cost']:+.4f}")
191
+
192
# Save
# Persist only the headline metrics; the per-task breakdown is print-only.
summary = {
    method: {"success": r["success"], "avg_cost": r["avg_cost"], "unsafe": r["unsafe"]}
    for method, r in results.items()
}
with open("/app/benchmark_v9_results.json","w") as f:
    json.dump(summary, f, indent=2)
print(f"\nSaved to /app/benchmark_v9_results.json")
print("DONE!")