narcolepticchicken commited on
Commit
a22b742
·
verified ·
1 Parent(s): 6433f0f

Upload eval/eval_bert_partB.py

Browse files
Files changed (1) hide show
  1. eval/eval_bert_partB.py +44 -0
eval/eval_bert_partB.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # ── Load models ──
3
+ print("\n[2] Loading BERT router from Hub...")
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
+
7
+ REPO = "narcolepticchicken/agent-cost-optimizer"
8
+ tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder="router_models/bert_router")
9
+ bert_model = AutoModelForSequenceClassification.from_pretrained(REPO, subfolder="router_models/bert_router")
10
+ bert_model.eval()
11
+ print(f" BERT loaded, num_labels={bert_model.config.num_labels}")
12
+
13
+ print("\n[3] Loading v11 XGBoost router...")
14
+ from huggingface_hub import hf_hub_download
15
+ v11_path = hf_hub_download(REPO, "router_models/router_bundle_v11.pkl")
16
+ v11_bundle = pickle.load(open(v11_path, "rb"))
17
+ v11_tier_clfs = {int(k):v for k,v in v11_bundle["tier_clfs"].items()}
18
+ v11_tier_calibs = {int(k):v for k,v in v11_bundle["tier_calibrators"].items()}
19
+ v11_feat_keys = v11_bundle["feat_keys"]
20
+ print(f" v11 loaded, features={len(v11_feat_keys)}")
21
+
22
+ # ── Routing functions ──
23
+ def route_bert(problem_text):
24
+ inputs = tokenizer(problem_text, truncation=True, max_length=512, return_tensors="pt")
25
+ with torch.no_grad():
26
+ logits = bert_model(**inputs).logits
27
+ pred_class = torch.argmax(logits, dim=-1).item()
28
+ tier = pred_class + 1
29
+ probs = torch.softmax(logits, dim=-1)[0]
30
+ confidence = float(probs[pred_class])
31
+ return tier, confidence
32
+
33
+ def route_v11(problem_text):
34
+ feats = extract_features(problem_text)
35
+ feat_vec = np.array([float(feats.get(k, 0.0)) for k in v11_feat_keys], dtype=np.float32).reshape(1,-1)
36
+ tier_probs = {}
37
+ for t in range(1, 6):
38
+ p_raw = v11_tier_clfs[t].predict_proba(feat_vec)[0, 1]
39
+ p_cal = float(v11_tier_calibs[t].transform([p_raw])[0])
40
+ tier_probs[t] = p_cal
41
+ for t in range(1, 6):
42
+ if tier_probs[t] >= 0.65:
43
+ return t, tier_probs[t], tier_probs
44
+ return 5, tier_probs[5], tier_probs