narcolepticchicken committed on
Commit
3bed4ff
·
verified ·
1 Parent(s): e22f616

Upload eval/eval_bert_partB.py

Browse files
Files changed (1) hide show
  1. eval/eval_bert_partB.py +43 -19
eval/eval_bert_partB.py CHANGED
@@ -10,33 +10,57 @@ bert_model = AutoModelForSequenceClassification.from_pretrained(REPO, subfolder=
10
  bert_model.eval()
11
  print(f" BERT loaded, num_labels={bert_model.config.num_labels}")
12
 
13
- print("\n[3] Loading v11 XGBoost router...")
 
 
 
14
  from huggingface_hub import hf_hub_download
15
- v11_path = hf_hub_download(REPO, "router_models/router_bundle_v11.pkl")
16
- v11_bundle = pickle.load(open(v11_path, "rb"))
17
- v11_tier_clfs = {int(k):v for k,v in v11_bundle["tier_clfs"].items()}
18
- v11_tier_calibs = {int(k):v for k,v in v11_bundle["tier_calibrators"].items()}
19
- v11_feat_keys = v11_bundle["feat_keys"]
20
- print(f" v11 loaded, features={len(v11_feat_keys)}")
 
 
 
 
 
 
 
 
 
21
 
22
  # ── Routing functions ──
23
  def route_bert(problem_text):
24
- inputs = tokenizer(problem_text, truncation=True, max_length=512, return_tensors="pt")
25
- with torch.no_grad():
26
- logits = bert_model(**inputs).logits
27
- pred_class = torch.argmax(logits, dim=-1).item()
28
- tier = pred_class + 1
29
- probs = torch.softmax(logits, dim=-1)[0]
30
- confidence = float(probs[pred_class])
31
- return tier, confidence
 
 
 
 
 
 
 
 
 
32
 
33
- def route_v11(problem_text):
 
 
 
34
  feats = extract_features(problem_text)
35
- feat_vec = np.array([float(feats.get(k, 0.0)) for k in v11_feat_keys], dtype=np.float32).reshape(1,-1)
36
  tier_probs = {}
37
  for t in range(1, 6):
38
- p_raw = v11_tier_clfs[t].predict_proba(feat_vec)[0, 1]
39
- p_cal = float(v11_tier_calibs[t].transform([p_raw])[0])
40
  tier_probs[t] = p_cal
41
  for t in range(1, 6):
42
  if tier_probs[t] >= 0.65:
 
10
  bert_model.eval()
11
  print(f" BERT loaded, num_labels={bert_model.config.num_labels}")
12
 
13
+ # BERT is binary (success/fail) — we'll use it as a per-tier success predictor
14
+ # by prepending "[Tier X]" to the input text
15
+
16
+ print("\n[3] Loading v10 XGBoost router...")
17
  from huggingface_hub import hf_hub_download
18
+ import pickle
19
+ v10_path = hf_hub_download(REPO, "router_models/router_bundle_v10_fixed.pkl")
20
+ v10_bundle = pickle.load(open(v10_path, "rb"))
21
+ print(f" v10 bundle keys: {list(v10_bundle.keys())}")
22
+
23
+ # v10 may have different structure — inspect
24
+ if "tier_clfs" in v10_bundle:
25
+ v10_tier_clfs = {int(k):v for k,v in v10_bundle["tier_clfs"].items()}
26
+ v10_tier_calibs = {int(k):v for k,v in v10_bundle["tier_calibrators"].items()}
27
+ v10_feat_keys = v10_bundle["feat_keys"]
28
+ print(f" v10 loaded, features={len(v10_feat_keys)}")
29
+ HAS_V10 = True
30
+ else:
31
+ HAS_V10 = False
32
+ print(f" v10 bundle structure: {type(v10_bundle)}")
33
 
34
  # ── Routing functions ──
35
  def route_bert(problem_text):
36
+ """BERT binary classifier: predict success probability at each tier.
37
+ Route to cheapest tier where P(success) >= 0.5."""
38
+ tier_probs = {}
39
+ for tier in range(1, 6):
40
+ prompt = f"[Tier {tier}] {problem_text}"
41
+ inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors="pt")
42
+ with torch.no_grad():
43
+ logits = bert_model(**inputs).logits
44
+ probs = torch.softmax(logits, dim=-1)[0]
45
+ # Binary: class 1 = success
46
+ tier_probs[tier] = float(probs[1]) if bert_model.config.num_labels == 2 else float(probs[tier-1])
47
+
48
+ # Route to cheapest tier with P(success) >= 0.5
49
+ for t in range(1, 6):
50
+ if tier_probs[t] >= 0.5:
51
+ return t, tier_probs[t], tier_probs
52
+ return 5, tier_probs[5], tier_probs
53
 
54
+ def route_v10(problem_text):
55
+ """v10 XGBoost cascade router."""
56
+ if not HAS_V10:
57
+ return 4, 0.5, {t:0.5 for t in range(1,6)}
58
  feats = extract_features(problem_text)
59
+ feat_vec = np.array([float(feats.get(k, 0.0)) for k in v10_feat_keys], dtype=np.float32).reshape(1,-1)
60
  tier_probs = {}
61
  for t in range(1, 6):
62
+ p_raw = v10_tier_clfs[t].predict_proba(feat_vec)[0, 1]
63
+ p_cal = float(v10_tier_calibs[t].transform([p_raw])[0])
64
  tier_probs[t] = p_cal
65
  for t in range(1, 6):
66
  if tier_probs[t] >= 0.65: