# ── Load models ──
print("\n[2] Loading BERT router from Hub...")
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

REPO = "narcolepticchicken/agent-cost-optimizer"
tokenizer = AutoTokenizer.from_pretrained(REPO, subfolder="router_models/bert_router")
bert_model = AutoModelForSequenceClassification.from_pretrained(REPO, subfolder="router_models/bert_router")
bert_model.eval()
print(f"  BERT loaded, num_labels={bert_model.config.num_labels}")

# BERT is a binary (success/fail) classifier; we use it as a per-tier success
# predictor by prepending a "[Tier N]" prefix to the input text
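# e.g. the same problem is scored five times, once per prefix:
#   "[Tier 1] <problem>", "[Tier 2] <problem>", ..., "[Tier 5] <problem>"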

print("\n[3] Loading v10 XGBoost router...")
from huggingface_hub import hf_hub_download
import pickle
v10_path = hf_hub_download(REPO, "router_models/router_bundle_v10_fixed.pkl")
with open(v10_path, "rb") as f:
    v10_bundle = pickle.load(f)
print(f"  v10 bundle keys: {list(v10_bundle.keys())}")

# The v10 bundle may have a different structure; inspect before unpacking
if "tier_clfs" in v10_bundle:
    v10_tier_clfs = {int(k): v for k, v in v10_bundle["tier_clfs"].items()}
    v10_tier_calibs = {int(k): v for k, v in v10_bundle["tier_calibrators"].items()}
    v10_feat_keys = v10_bundle["feat_keys"]
    print(f"  v10 loaded, features={len(v10_feat_keys)}")
    HAS_V10 = True
else:
    HAS_V10 = False
    print(f"  v10 bundle structure: {type(v10_bundle)}")

# ── Routing functions ──
def route_bert(problem_text):
    """BERT binary classifier: predict success probability at each tier.
    Route to cheapest tier where P(success) > 0.5."""
    tier_probs = {}
    for tier in range(1, 6):
        prompt = f"[Tier {tier}] {problem_text}"
        inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            logits = bert_model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0]
        # Binary head: class 1 = success; otherwise treat the heads as per-tier scores
        if bert_model.config.num_labels == 2:
            tier_probs[tier] = float(probs[1])
        else:
            tier_probs[tier] = float(probs[tier - 1])
    
    # Route to cheapest tier with P(success) >= 0.5
    for t in range(1, 6):
        if tier_probs[t] >= 0.5:
            return t, tier_probs[t], tier_probs
    return 5, tier_probs[5], tier_probs
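
# Quick demo on a hypothetical prompt (not from any dataset): route_bert picks
# the cheapest tier whose predicted success probability clears 0.5.
_tier, _p, _ = route_bert("Compute the determinant of [[1, 2], [3, 4]].")
print(f"  BERT demo: tier={_tier}, P(success)={_p:.3f}")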

def route_v10(problem_text):
    """v10 XGBoost cascade router: route to the cheapest tier whose
    calibrated P(success) >= 0.65; fall back to tier 4 if the bundle
    did not match the expected schema."""
    if not HAS_V10:
        return 4, 0.5, {t: 0.5 for t in range(1, 6)}
    # extract_features() and np (numpy) are defined earlier in the script (not shown here)
    feats = extract_features(problem_text)
    feat_vec = np.array([float(feats.get(k, 0.0)) for k in v10_feat_keys],
                        dtype=np.float32).reshape(1, -1)
    tier_probs = {}
    for t in range(1, 6):
        p_raw = v10_tier_clfs[t].predict_proba(feat_vec)[0, 1]
        # map the raw XGBoost score through the per-tier calibrator
        p_cal = float(v10_tier_calibs[t].transform([p_raw])[0])
        tier_probs[t] = p_cal
    for t in range(1, 6):
        if tier_probs[t] >= 0.65:
            return t, tier_probs[t], tier_probs
    return 5, tier_probs[5], tier_probs
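
# Same hypothetical prompt through the v10 cascade (guarded by HAS_V10, and
# assuming extract_features from earlier in the script is available; otherwise
# route_v10 returns its tier-4 fallback).
if HAS_V10:
    _tier, _p, _ = route_v10("Compute the determinant of [[1, 2], [3, 4]].")
    print(f"  v10 demo: tier={_tier}, P(success)={_p:.3f}")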