Upload training/benchmark_final.py with huggingface_hub
Browse files- training/benchmark_final.py +143 -0
training/benchmark_final.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""FINAL v9 benchmark: v8 router + selective execution feedback.

Strategy: Use v8 router for initial tier selection. Then for tasks
where v8 chose tier < 4 (not frontier), run cheap model first and
use output confidence to decide whether to escalate. This catches
the ~2% of cases where cheap model fails but could be saved by
escalation, without running double-inference everywhere.
"""
import sys
import random
import math
import pickle
import json

# Make the /app package tree importable before pulling in the aco modules.
sys.path.insert(0, "/app")

from collections import defaultdict

from aco.classifier import TaskCostClassifier
from aco.router import ModelCascadeRouter
from aco.execution_feedback import ExecutionFeedbackRouter
|
| 16 |
+
|
| 17 |
+
# Simulated model tiers: base success strength and relative cost per tier.
# Tier 4 is "frontier"; tier 5 is a more expensive super-frontier option
# only reachable via escalation.
TIER_STR = {1: 0.35, 2: 0.55, 3: 0.80, 4: 0.93, 5: 0.97}
TIER_COST = {1: 0.05, 2: 0.15, 3: 0.75, 4: 1.0, 5: 1.5}

# Minimum tier allowed per task type; escalation never lands below this floor.
TASK_FLOOR = {
    "legal_regulated": 4, "long_horizon": 3, "research": 3, "coding": 3,
    "unknown_ambiguous": 3, "quick_answer": 1, "document_drafting": 2,
    "tool_heavy": 2, "retrieval_heavy": 2,
}

# Representative prompts per task type that drive the simulated workload.
TASKS = {
    "quick_answer": ["What is 2+2?", "Explain quantum computing briefly.",
                     "Convert 100F to Celsius."],
    "coding": ["Write Python function to reverse linked list.",
               "Fix typo in README.",
               "Debug critical production segfault NOW.",
               "Just fix typo in line 42."],
    "research": ["Research latest transformer advances."],
    "document_drafting": ["Draft project proposal for ML pipeline."],
    "legal_regulated": ["Review this contract for liability clauses."],
    "tool_heavy": ["Search open issues and create summary."],
    "retrieval_heavy": ["Answer based on 50-page document."],
    "long_horizon": ["Plan 3-month roadmap."],
    "unknown_ambiguous": ["Help me with this thing."],
}
|
| 35 |
+
|
| 36 |
+
# Shared pipeline components: the task classifier and the pre-trained v8
# cascade router bundle loaded from disk.
classifier = TaskCostClassifier()
router = ModelCascadeRouter(model_path="/app/router_models/router_bundle_v8.pkl")
|
| 38 |
+
|
| 39 |
+
def sim_lps_correlated(tier, diff, success, rng):
    """Simulate per-token logprobs whose confidence correlates with success.

    Successful outputs get a per-tier mean logprob near 0 with a tight
    spread; failed outputs get lower means and wider noise, so a
    confidence-based escalation signal has something real to detect.

    Args:
        tier: model tier (1-5) that produced the output.
        diff: task difficulty; inflates the noise spread.
        success: whether the simulated run succeeded.
        rng: random.Random instance (shared; call order matters for
            reproducibility across benchmark methods).

    Returns:
        List of simulated token logprobs (20-150 entries).
    """
    n = rng.randint(20, 150)  # simulated output length in tokens
    if success:
        # High confidence: mean logprob close to 0.
        base = {1: -1.5, 2: -1.0, 3: -0.6, 4: -0.3, 5: -0.2}[tier]
        noise_std = 0.3 + diff * 0.05
    else:
        # Low confidence: more uncertain, lower logprobs.
        base = {1: -4.0, 2: -3.0, 3: -2.0, 4: -1.5, 5: -1.0}[tier]
        noise_std = 0.8 + diff * 0.2
    return [base + rng.gauss(0, noise_std) for _ in range(n)]
|
| 51 |
+
|
| 52 |
+
N = 3000  # simulated requests per benchmark method
results = {}
rng = random.Random(42)

# Method 1 baseline: always route to the frontier tier (tier 4).
# rng is reseeded so every method sees the identical task sequence.
rng.seed(42)
succ = 0
cost = 0.0
for i in range(N):
    tt = rng.choice(list(TASKS.keys()))
    req = rng.choice(TASKS[tt])
    pred = classifier.classify(req)
    # Success probability decays with difficulty from the tier's base strength.
    ps = TIER_STR[4] ** (pred["difficulty"] * 0.6)
    if rng.random() < ps:
        succ += 1
    cost += TIER_COST[4]
results["frontier"] = {"success": succ / N, "cost": cost / N}
|
| 65 |
+
|
| 66 |
+
# Method 2: v8 router picks a tier per request; no execution feedback.
rng.seed(42)
succ = 0
cost = 0.0
for i in range(N):
    tt = rng.choice(list(TASKS.keys()))
    req = rng.choice(TASKS[tt])
    pred = classifier.classify(req)
    r = router.route(req, tt, pred["difficulty"], pred)
    # Success odds use the router's dynamic difficulty, not the raw prediction.
    ps = TIER_STR[r.tier] ** (r.dynamic_difficulty * 0.6)
    if rng.random() < ps:
        succ += 1
    cost += TIER_COST[r.tier]
results["v8"] = {"success": succ / N, "cost": cost / N}
|
| 76 |
+
|
| 77 |
+
# Method 3: v9 = v8 routing + execution feedback, applied only when v8
# selected a non-frontier tier (< 4). Grid-search the feedback thresholds.
for ent_thr in [2.0, 2.5, 3.0]:
    for lc_thr in [0.05, 0.10, 0.15]:
        ef = ExecutionFeedbackRouter(entropy_threshold=ent_thr,
                                     low_conf_ratio_threshold=lc_thr,
                                     tier_costs=TIER_COST)
        rng.seed(42)
        succ = 0
        cost = 0.0
        escalated_count = 0
        for i in range(N):
            tt = rng.choice(list(TASKS.keys()))
            req = rng.choice(TASKS[tt])
            pred = classifier.classify(req)
            r = router.route(req, tt, pred["difficulty"], pred)
            tier = r.tier
            diff = r.dynamic_difficulty
            if tier < 4:
                # Cheap tier chosen: run it, inspect simulated logprobs,
                # and escalate one tier on a low-confidence signal.
                ps = TIER_STR[tier] ** (diff * 0.6)
                initial_success = rng.random() < ps
                lps = sim_lps_correlated(tier, diff, initial_success, rng)
                signal = ef.analyze_output(lps, task_type=tt, current_tier=tier)
                if signal.should_escalate:
                    final_tier = min(tier + 1, 5)
                    # Never escalate to below the task-type floor.
                    final_tier = max(final_tier, TASK_FLOOR.get(tt, 1))
                    ps2 = TIER_STR[final_tier] ** (diff * 0.6)
                    # Pay for both the initial run and the escalated run.
                    c = TIER_COST[tier] + TIER_COST[final_tier]
                    if rng.random() < ps2:
                        succ += 1
                    escalated_count += 1
                else:
                    c = TIER_COST[tier]
                    if initial_success:
                        succ += 1
            else:
                # Frontier tier already selected: no feedback pass.
                ps = TIER_STR[tier] ** (diff * 0.6)
                if rng.random() < ps:
                    succ += 1
                c = TIER_COST[tier]
            cost += c
        name = f"v9_e{ent_thr}_lc{lc_thr}"
        results[name] = {"success": succ / N, "cost": cost / N,
                         "escalated": escalated_count,
                         "ent_thr": ent_thr, "lc_thr": lc_thr}
|
| 113 |
+
|
| 114 |
+
print("="*80)
|
| 115 |
+
print("FINAL v9 RESULTS: v8 ROUTER + SELECTIVE FEEDBACK")
|
| 116 |
+
print("="*80)
|
| 117 |
+
fc=results["frontier"]["cost"]
|
| 118 |
+
print(f"\n{'Method':<25} {'Success':>10} {'AvgCost':>10} {'CostRed':>10} {'Gap':>8}")
|
| 119 |
+
print("-"*65)
|
| 120 |
+
print(f"{'always_frontier':<25} {results['frontier']['success']:>10.3f} {results['frontier']['cost']:>10.4f} {'0.0%':>10} {0.901-results['frontier']['success']:>8.3f}")
|
| 121 |
+
print(f"{'v8_router':<25} {results['v8']['success']:>10.3f} {results['v8']['cost']:>10.4f} {(1-results['v8']['cost']/fc)*100:>9.1f}% {0.901-results['v8']['success']:>8.3f}")
|
| 122 |
+
|
| 123 |
+
best_v9=None
|
| 124 |
+
for name,r in sorted(results.items(),key=lambda x:-x[1]["success"]):
|
| 125 |
+
if name.startswith("v9"):
|
| 126 |
+
cr=(1-r["cost"]/fc)*100
|
| 127 |
+
gap=0.901-r["success"]
|
| 128 |
+
print(f"{name:<25} {r['success']:>10.3f} {r['cost']:>10.4f} {cr:>9.1f}% {gap:>8.3f}")
|
| 129 |
+
# Pick the one closest to frontier quality with most savings
|
| 130 |
+
score=r["success"]*20-r["cost"]*5
|
| 131 |
+
if best_v9 is None or score>best_v9[0]:
|
| 132 |
+
best_v9=(score,name,r)
|
| 133 |
+
|
| 134 |
+
print(f"\n\nBest v9: {best_v9[1]}")
|
| 135 |
+
print(f" success={best_v9[2]['success']:.3f}, cost={best_v9[2]['cost']:.4f}")
|
| 136 |
+
print(f" costRed={(1-best_v9[2]['cost']/fc)*100:.1f}%")
|
| 137 |
+
print(f" quality_gap={0.901-best_v9[2]['success']:.3f}")
|
| 138 |
+
print(f" v8 gap={0.901-results['v8']['success']:.3f}")
|
| 139 |
+
print(f" Improvement: {best_v9[2]['success']-results['v8']['success']:+.3f} success, {best_v9[2]['cost']-results['v8']['cost']:+.4f} cost")
|
| 140 |
+
|
| 141 |
+
# Persist every configuration's metrics for later comparison.
# default=str stringifies any non-JSON-serializable values rather than failing.
with open("/app/benchmark_v9_final.json", "w") as f:
    json.dump(results, f, indent=2, default=str)
print("DONE!")
|