agent-cost-optimizer / eval /eval_bert_partB.py
narcolepticchicken's picture
Upload eval/eval_bert_partB.py
3bed4ff verified
# ── Load models ──
print("\n[2] Loading BERT router from Hub...")
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

REPO = "narcolepticchicken/agent-cost-optimizer"
# Tokenizer and classifier weights live in the same subfolder of the repo.
_BERT_SUBFOLDER = "router_models/bert_router"
tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder=_BERT_SUBFOLDER)
bert_model = AutoModelForSequenceClassification.from_pretrained(
    REPO, subfolder=_BERT_SUBFOLDER
)
# Inference only: switch the module to eval mode.
bert_model.eval()
print(f" BERT loaded, num_labels={bert_model.config.num_labels}")
# BERT is binary (success/fail) — it is used as a per-tier success predictor
# by prefixing the input text with a "[Tier X] " tag (see route_bert).
print("\n[3] Loading v10 XGBoost router...")
from huggingface_hub import hf_hub_download
import pickle

v10_path = hf_hub_download(REPO, "router_models/router_bundle_v10_fixed.pkl")
# NOTE(security): unpickling can execute arbitrary code embedded in the file.
# Only load bundles from a Hub repository you trust.
# The context manager closes the file handle deterministically instead of
# leaving it open until garbage collection.
with open(v10_path, "rb") as _bundle_file:
    v10_bundle = pickle.load(_bundle_file)
print(f" v10 bundle keys: {list(v10_bundle.keys())}")
# v10 may have a different structure — inspect before relying on it.
if "tier_clfs" in v10_bundle:
    # Keys are coerced to int so the per-tier models can be indexed by tier number.
    v10_tier_clfs = {int(k): v for k, v in v10_bundle["tier_clfs"].items()}
    v10_tier_calibs = {int(k): v for k, v in v10_bundle["tier_calibrators"].items()}
    v10_feat_keys = v10_bundle["feat_keys"]
    print(f" v10 loaded, features={len(v10_feat_keys)}")
    HAS_V10 = True
else:
    # Unknown layout: the v10_* names stay undefined and HAS_V10 gates their use.
    HAS_V10 = False
    print(f" v10 bundle structure: {type(v10_bundle)}")
# ── Routing functions ──
def route_bert(problem_text, threshold=0.5):
    """Route a problem using the BERT classifier as a per-tier success predictor.

    The model is run once per tier (1 through 5) on the text
    ``"[Tier {tier}] {problem_text}"``. With a 2-label model the softmax
    probability of class 1 is taken as P(success) for that tier; otherwise
    the probability at index ``tier - 1`` is used.

    Args:
        problem_text: Raw problem text to route.
        threshold: Minimum P(success) a tier must reach (inclusive, ``>=``)
            to be selected. Defaults to 0.5.

    Returns:
        A tuple ``(tier, prob, tier_probs)``: the lowest-numbered tier whose
        probability is at least ``threshold`` (tier 5 if none qualifies), that
        tier's probability, and the dict of all per-tier probabilities.
    """
    # The label count does not change between tiers; read it once.
    is_binary = bert_model.config.num_labels == 2

    tier_probs = {}
    for tier in range(1, 6):
        prompt = f"[Tier {tier}] {problem_text}"
        inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            logits = bert_model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0]
        # Binary: class 1 = success
        tier_probs[tier] = float(probs[1]) if is_binary else float(probs[tier - 1])

    # Pick the lowest-numbered tier that clears the threshold.
    for tier in range(1, 6):
        if tier_probs[tier] >= threshold:
            return tier, tier_probs[tier], tier_probs
    # No tier qualified: fall back to the highest tier.
    return 5, tier_probs[5], tier_probs
def route_v10(problem_text, threshold=0.65):
    """Route a problem with the v10 XGBoost cascade router.

    Features are extracted from the text, ordered by ``v10_feat_keys``
    (missing features default to 0.0), scored by each tier's classifier and
    passed through that tier's calibrator.

    Args:
        problem_text: Raw problem text to route.
        threshold: Minimum calibrated probability a tier must reach
            (inclusive, ``>=``) to be selected. Defaults to 0.65.

    Returns:
        A tuple ``(tier, prob, tier_probs)``: the lowest-numbered tier whose
        calibrated probability is at least ``threshold`` (tier 5 if none
        qualifies), that tier's probability, and the dict of all per-tier
        probabilities. When the v10 bundle is unavailable, returns tier 4
        with a probability of 0.5 for every tier.
    """
    if not HAS_V10:
        # NOTE(review): this fallback picks tier 4 while the no-tier-qualifies
        # path below picks tier 5 — confirm the difference is intentional.
        return 4, 0.5, {t: 0.5 for t in range(1, 6)}

    feats = extract_features(problem_text)
    # Single-row matrix in the feature order recorded in the bundle.
    feat_vec = np.array(
        [float(feats.get(k, 0.0)) for k in v10_feat_keys], dtype=np.float32
    ).reshape(1, -1)

    tier_probs = {}
    for t in range(1, 6):
        p_raw = v10_tier_clfs[t].predict_proba(feat_vec)[0, 1]
        tier_probs[t] = float(v10_tier_calibs[t].transform([p_raw])[0])

    # Pick the lowest-numbered tier that clears the threshold.
    for t in range(1, 6):
        if tier_probs[t] >= threshold:
            return t, tier_probs[t], tier_probs
    # No tier qualified: fall back to the highest tier.
    return 5, tier_probs[5], tier_probs