narcolepticchicken commited on
Commit
7a843a8
·
verified ·
1 Parent(s): 44e6a9c

Upload eval/eval_bert_part3.py

Browse files
Files changed (1) hide show
  1. eval/eval_bert_part3.py +140 -0
eval/eval_bert_part3.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # ── Part 3: Run evaluations ──
3
+
4
+ def route_bert(problem_text):
5
+ """Route using BERT classifier."""
6
+ inputs = tokenizer(problem_text, truncation=True, max_length=512, return_tensors="pt")
7
+ with torch.no_grad():
8
+ logits = bert_model(**inputs).logits
9
+ pred_class = torch.argmax(logits, dim=-1).item()
10
+ tier = pred_class + 1 # classes are 0-4 → tiers 1-5
11
+ probs = torch.softmax(logits, dim=-1)[0]
12
+ confidence = float(probs[pred_class])
13
+ return tier, confidence
14
+
15
+ def route_v11(problem_text):
16
+ """Route using v11 XGBoost + isotonic calibration."""
17
+ feats = extract_features(problem_text)
18
+ feat_vec = np.array([float(feats.get(k, 0.0)) for k in v11_feat_keys], dtype=np.float32).reshape(1,-1)
19
+ tier_probs = {}
20
+ for t in range(1, 6):
21
+ p_raw = v11_tier_clfs[t].predict_proba(feat_vec)[0, 1]
22
+ p_cal = float(v11_tier_calibs[t].transform([p_raw])[0])
23
+ tier_probs[t] = p_cal
24
+ # Find cheapest tier with P(success) >= 0.65
25
+ for t in range(1, 6):
26
+ if tier_probs[t] >= 0.65:
27
+ return t, tier_probs[t], tier_probs
28
+ return 5, tier_probs[5], tier_probs
29
+
30
+ TASK_FLOOR = {"quick_answer":1,"coding":3,"research":3,"document_drafting":2,
31
+ "legal_regulated":4,"tool_heavy":2,"retrieval_heavy":2,"long_horizon":3,"unknown_ambiguous":3}
32
+
33
+ def classify_task(text):
34
+ r = text.lower()
35
+ if any(k in r for k in ["contract","legal","compliance","gdpr","privacy"]): return "legal_regulated"
36
+ if any(k in r for k in ["debug","fix","bug","implement","refactor","code","function"]): return "coding"
37
+ if any(k in r for k in ["research","find sources","literature","investigate"]): return "research"
38
+ if any(k in r for k in ["search","fetch","api","query","database"]): return "tool_heavy"
39
+ if any(k in r for k in ["plan","roadmap","orchestrate","migrate","deploy"]): return "long_horizon"
40
+ if any(k in r for k in ["draft","write","compose","document"]): return "document_drafting"
41
+ if any(k in r for k in ["what is","explain","define","briefly"]): return "quick_answer"
42
+ return "unknown_ambiguous"
43
+
44
+ policies = defaultdict(lambda: {"success":0,"cost":0.0,"n":0})
45
+
46
+ print("\n[4] Evaluating all policies on SWE-Router...")
47
+ for iid, model_results in traces.items():
48
+ problem = next(iter(model_results.values()))['problem']
49
+ task_type = classify_task(problem)
50
+ floor = TASK_FLOOR.get(task_type, 2)
51
+
52
+ # Oracle
53
+ resolved = [(m, r) for m, r in model_results.items() if r['resolved']]
54
+ if resolved:
55
+ cheapest = min(resolved, key=lambda x: TIER_COST.get(MODEL_TIER[x[0]], 1.0))
56
+ policies['oracle']['success'] += 1
57
+ policies['oracle']['cost'] += cheapest[1]['cost']
58
+ else:
59
+ policies['oracle']['cost'] += min(r['cost'] for r in model_results.values())
60
+ policies['oracle']['n'] += 1
61
+
62
+ # Always frontier
63
+ f_model = 'claude-opus-4.7'
64
+ if f_model in model_results:
65
+ policies['frontier']['success'] += int(model_results[f_model]['resolved'])
66
+ policies['frontier']['cost'] += model_results[f_model]['cost']
67
+ policies['frontier']['n'] += 1
68
+
69
+ # BERT router
70
+ bert_tier, bert_conf = route_bert(problem)
71
+ bert_tier = max(bert_tier, floor) # enforce safety floor
72
+ m_bert = TIER_TO_SWE.get(bert_tier, 'claude-opus-4.7')
73
+ if m_bert in model_results:
74
+ policies['bert']['success'] += int(model_results[m_bert]['resolved'])
75
+ policies['bert']['cost'] += model_results[m_bert]['cost']
76
+ else:
77
+ policies['bert']['success'] += int(model_results.get('claude-opus-4.7',{}).get('resolved',0))
78
+ policies['bert']['cost'] += model_results.get('claude-opus-4.7',{}).get('cost',0.3)
79
+ policies['bert']['n'] += 1
80
+
81
+ # v11 XGBoost
82
+ v11_tier, v11_conf, v11_probs = route_v11(problem)
83
+ v11_tier = max(v11_tier, floor) # enforce safety floor
84
+ m_v11 = TIER_TO_SWE.get(v11_tier, 'claude-opus-4.7')
85
+ if m_v11 in model_results:
86
+ policies['v11_xgboost']['success'] += int(model_results[m_v11]['resolved'])
87
+ policies['v11_xgboost']['cost'] += model_results[m_v11]['cost']
88
+ else:
89
+ policies['v11_xgboost']['success'] += int(model_results.get('claude-opus-4.7',{}).get('resolved',0))
90
+ policies['v11_xgboost']['cost'] += model_results.get('claude-opus-4.7',{}).get('cost',0.3)
91
+ policies['v11_xgboost']['n'] += 1
92
+
93
+ # BERT + feedback (escalate on failure)
94
+ if m_bert in model_results and model_results[m_bert]['resolved']:
95
+ policies['bert_feedback']['success'] += 1
96
+ policies['bert_feedback']['cost'] += model_results[m_bert]['cost']
97
+ else:
98
+ # Escalate one tier
99
+ up_tier = min(bert_tier + 1, 5)
100
+ m_up = TIER_TO_SWE.get(up_tier, 'claude-opus-4.7')
101
+ if m_up in model_results and model_results[m_up]['resolved']:
102
+ policies['bert_feedback']['success'] += 1
103
+ policies['bert_feedback']['cost'] += model_results.get(m_bert,{}).get('cost',0.01)
104
+ policies['bert_feedback']['cost'] += model_results[m_up]['cost']
105
+ else:
106
+ # Last resort: frontier
107
+ if f_model in model_results and model_results[f_model]['resolved']:
108
+ policies['bert_feedback']['success'] += 1
109
+ policies['bert_feedback']['cost'] += model_results.get(m_bert,{}).get('cost',0.01)
110
+ policies['bert_feedback']['cost'] += model_results[f_model]['cost']
111
+ else:
112
+ policies['bert_feedback']['cost'] += model_results.get(m_bert,{}).get('cost',0.01)
113
+ policies['bert_feedback']['n'] += 1
114
+
115
+ # v11 + feedback (escalate on failure)
116
+ if m_v11 in model_results and model_results[m_v11]['resolved']:
117
+ policies['v11_feedback']['success'] += 1
118
+ policies['v11_feedback']['cost'] += model_results[m_v11]['cost']
119
+ else:
120
+ up_tier = min(v11_tier + 1, 5)
121
+ m_up = TIER_TO_SWE.get(up_tier, 'claude-opus-4.7')
122
+ if m_up in model_results and model_results[m_up]['resolved']:
123
+ policies['v11_feedback']['success'] += 1
124
+ policies['v11_feedback']['cost'] += model_results.get(m_v11,{}).get('cost',0.01)
125
+ policies['v11_feedback']['cost'] += model_results[m_up]['cost']
126
+ else:
127
+ if f_model in model_results and model_results[f_model]['resolved']:
128
+ policies['v11_feedback']['success'] += 1
129
+ policies['v11_feedback']['cost'] += model_results.get(m_v11,{}).get('cost',0.01)
130
+ policies['v11_feedback']['cost'] += model_results[f_model]['cost']
131
+ else:
132
+ policies['v11_feedback']['cost'] += model_results.get(m_v11,{}).get('cost',0.01)
133
+ policies['v11_feedback']['n'] += 1
134
+
135
+ # Always cheap
136
+ c_model = 'deepseek-v4-flash'
137
+ if c_model in model_results:
138
+ policies['always_cheap']['success'] += int(model_results[c_model]['resolved'])
139
+ policies['always_cheap']['cost'] += model_results[c_model]['cost']
140
+ policies['always_cheap']['n'] += 1