Upload train_router.py
Browse files- train_router.py +77 -0
train_router.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Train a learned model router for Agent Cost Optimizer.
|
| 3 |
+
|
| 4 |
+
Architecture: CARROT-style plug-in router.
|
| 5 |
+
- Per-tier P(success|query) classifiers (XGBoost)
|
| 6 |
+
- Per-tier cost estimators
|
| 7 |
+
- Route to cheapest tier where P(success) > threshold
|
| 8 |
+
|
| 9 |
+
Ground truth: optimal_tier = min{tier : success(tier)} from execution traces.
|
| 10 |
+
"""
|
| 11 |
+
import json, os, sys, random, pickle, uuid
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime, timedelta
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
from typing import Dict, List, Tuple, Any, Optional
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from enum import Enum
|
| 18 |
+
|
| 19 |
+
# ─── Feature Extraction ───────────────────────────────────────────
|
| 20 |
+
# Task labels recognised by the router. The position of each label in this
# list is its integer id, so the order must stay stable.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]

# Reverse lookup: task label -> position in TASK_TYPES.
TASK_TYPE_TO_IDX = {label: idx for idx, label in enumerate(TASK_TYPES)}
|
| 26 |
+
|
| 27 |
+
# Keyword lists used for keyword-based features. Entries are lowercase and are
# matched as plain substrings of the lowercased request by extract_features.
CODE_KW = [
    "python", "javascript", "code", "function", "bug", "debug", "refactor",
    "implement", "test", "compile", "runtime", "class", "module", "package",
    "async", "thread", "queue", "stack", "heap", "pointer", "segfault", "linter",
]
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy", "policy",
    "regulatory", "liability", "clause", "indemnification", "tos",
]
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "study", "survey", "paper", "arxiv", "citation",
]
TOOL_KW = [
    "search", "fetch", "retrieve", "query", "api", "database", "scrape",
    "lookup", "download", "upload", "index", "aggregate",
]
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
    "pipeline", "end-to-end", "architecture", "workflow", "deploy",
]
MATH_KW = [
    "calculate", "compute", "solve", "equation", "formula", "optimize",
    "probability", "integral", "derivative", "matrix",
]
|
| 40 |
+
|
| 41 |
+
def extract_features(request: str, task_type: str, metadata: Optional[Dict] = None) -> Dict[str, Any]:
    """Build the routing feature dict for a single request.

    Args:
        request: Raw request text.
        task_type: Task label, normally one of ``TASK_TYPES``. Any other value
            gets the index of ``"unknown_ambiguous"`` and all-zero ``tt_*``
            indicators.
        metadata: Optional dict; only its ``"difficulty"`` entry is read.

    Returns:
        Dict of numeric features: size counts, keyword flags/counts,
        ``task_type_idx`` and one ``tt_<label>`` indicator per task type.
        ``"difficulty"`` (default 3) is added only when ``metadata`` is
        non-empty, so the key set can differ between calls.
    """
    text = request.lower()

    def count_hits(keywords: List[str]) -> int:
        # Plain substring test: "tos" also matches inside "photos", and a
        # request containing "debug" counts both "bug" and "debug".
        return sum(1 for kw in keywords if kw in text)

    # Scan each list once; the has_* flag is derived from the count.
    code_hits = count_hits(CODE_KW)
    legal_hits = count_hits(LEGAL_KW)
    research_hits = count_hits(RESEARCH_KW)
    tool_hits = count_hits(TOOL_KW)

    feats = {
        "request_length": len(request),
        "num_words": len(request.split()),
        # Counts terminator characters, not parsed sentences.
        "num_sentences": request.count(".") + request.count("!") + request.count("?"),
        "has_code_kw": int(code_hits > 0),
        "num_code_kw": code_hits,
        "has_legal_kw": int(legal_hits > 0),
        "num_legal_kw": legal_hits,
        "has_research_kw": int(research_hits > 0),
        "num_research_kw": research_hits,
        "has_tool_kw": int(tool_hits > 0),
        "num_tool_kw": tool_hits,
        "has_long_kw": int(any(kw in text for kw in LONG_KW)),
        "has_math_kw": int(any(kw in text for kw in MATH_KW)),
        # Unrecognised labels fall back to the "unknown_ambiguous" slot.
        "task_type_idx": TASK_TYPE_TO_IDX.get(
            task_type, TASK_TYPE_TO_IDX["unknown_ambiguous"]
        ),
    }

    # One-hot task type
    for label in TASK_TYPES:
        feats[f"tt_{label}"] = int(task_type == label)

    # None and {} both skip this branch, leaving "difficulty" out of the dict.
    if metadata:
        feats["difficulty"] = metadata.get("difficulty", 3)
    return feats
|
| 65 |
+
|
| 66 |
+
def feats_to_array(feats: Dict) -> List[float]:
    """Convert a feature dict to a list of floats ordered by sorted key name.

    The result has one entry per key in ``feats``, so dicts with different key
    sets yield arrays of different length.
    """
    return [float(feats[key]) for key in sorted(feats)]
|
| 70 |
+
|
| 71 |
+
# Sorted key layout shared by all feats_to_array_safe calls; None until the
# first call captures it.
FEAT_KEYS: Optional[List[str]] = None  # set on first call


def feats_to_array_safe(feats: Dict) -> List[float]:
    """Convert a feature dict to floats using a key layout fixed on first use.

    The first call stores the sorted keys of its ``feats`` in the module-level
    ``FEAT_KEYS``. Every call (including the first) then emits one value per
    stored key: keys missing from ``feats`` become 0.0 and keys not in the
    stored layout are ignored.
    """
    global FEAT_KEYS
    if FEAT_KEYS is None:
        FEAT_KEYS = sorted(feats)
    return [float(feats.get(key, 0.0)) for key in FEAT_KEYS]
|