# agent-cost-optimizer / train_router.py
# narcolepticchicken's picture
# Upload train_router.py
# 8275b23 verified
#!/usr/bin/env python3
"""Train a learned model router for Agent Cost Optimizer.
Architecture: CARROT-style plug-in router.
- Per-tier P(success|query) classifiers (XGBoost)
- Per-tier cost estimators
- Route to cheapest tier where P(success) > threshold
Ground truth: optimal_tier = min{tier : success(tier)} from execution traces.
"""
import json, os, sys, random, pickle, uuid
import numpy as np
from datetime import datetime, timedelta
from collections import defaultdict
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass
from enum import Enum
# ─── Feature Extraction ───────────────────────────────────────────
# Task labels understood by the feature extractor; a label's position in this
# list is the integer stored in TASK_TYPE_TO_IDX.
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]

# Reverse lookup: task label -> position in TASK_TYPES.
TASK_TYPE_TO_IDX = dict(zip(TASK_TYPES, range(len(TASK_TYPES))))

# Keyword vocabularies, one list per request category.
# Programming / debugging vocabulary.
CODE_KW = [
    "python", "javascript", "code", "function", "bug", "debug",
    "refactor", "implement", "test", "compile", "runtime", "class",
    "module", "package", "async", "thread", "queue", "stack", "heap",
    "pointer", "segfault", "linter",
]

# Legal / compliance vocabulary.
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy", "policy",
    "regulatory", "liability", "clause", "indemnification", "tos",
]

# Research / literature vocabulary (note the two-word phrase).
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "study", "survey", "paper", "arxiv", "citation",
]

# Tool use / data access vocabulary.
TOOL_KW = [
    "search", "fetch", "retrieve", "query", "api", "database",
    "scrape", "lookup", "download", "upload", "index", "aggregate",
]

# Long-horizon / multi-step planning vocabulary.
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
    "pipeline", "end-to-end", "architecture", "workflow", "deploy",
]

# Math / numeric computation vocabulary.
MATH_KW = [
    "calculate", "compute", "solve", "equation", "formula", "optimize",
    "probability", "integral", "derivative", "matrix",
]
def extract_features(
    request: str,
    task_type: str,
    metadata: Optional[Dict] = None,
) -> Dict[str, Any]:
    """Build the flat feature dict for a single request.

    Args:
        request: Raw request text.
        task_type: Task label. A label that is not in ``TASK_TYPES`` gets the
            index of "unknown_ambiguous" and an all-zero one-hot encoding.
        metadata: Optional extra information about the request. Only the
            "difficulty" entry is read.

    Returns:
        A dict holding size statistics, keyword flags/counts, the task-type
        index, one ``tt_<label>`` one-hot entry per label in ``TASK_TYPES``
        and, only when ``metadata`` is truthy, a "difficulty" entry
        (defaulting to 3 if the key is absent from ``metadata``).

    Note:
        Keyword checks are plain substring tests on the lowercased request,
        so short keywords also match inside longer words (e.g. "api" inside
        "rapid", "tos" inside "photos").
    """
    text = request.lower()

    def count_hits(keywords: List[str]) -> int:
        # Number of distinct keywords from the list that occur in the text.
        return sum(1 for kw in keywords if kw in text)

    # Scan each vocabulary once; the has_* flag is derived from the count.
    code_hits = count_hits(CODE_KW)
    legal_hits = count_hits(LEGAL_KW)
    research_hits = count_hits(RESEARCH_KW)
    tool_hits = count_hits(TOOL_KW)

    # Fallback index is looked up by name so it cannot drift out of sync
    # with the ordering of TASK_TYPES.
    unknown_idx = TASK_TYPE_TO_IDX["unknown_ambiguous"]

    feats = {
        "request_length": len(request),
        "num_words": len(request.split()),
        # Counts terminator characters, not parsed sentences: "3.14" or
        # "..." contribute to this number as well.
        "num_sentences": request.count(".") + request.count("!") + request.count("?"),
        "has_code_kw": int(code_hits > 0),
        "num_code_kw": code_hits,
        "has_legal_kw": int(legal_hits > 0),
        "num_legal_kw": legal_hits,
        "has_research_kw": int(research_hits > 0),
        "num_research_kw": research_hits,
        "has_tool_kw": int(tool_hits > 0),
        "num_tool_kw": tool_hits,
        # Only a flag is emitted for these two categories, so short-circuit.
        "has_long_kw": int(any(kw in text for kw in LONG_KW)),
        "has_math_kw": int(any(kw in text for kw in MATH_KW)),
        "task_type_idx": TASK_TYPE_TO_IDX.get(task_type, unknown_idx),
    }

    # One-hot encoding of the task type.
    for label in TASK_TYPES:
        feats[f"tt_{label}"] = int(task_type == label)

    # "difficulty" is added only when metadata is supplied and non-empty, so
    # the key set of the result differs between the two cases.
    if metadata:
        feats["difficulty"] = metadata.get("difficulty", 3)

    return feats
def feats_to_array(feats: Dict) -> List[float]:
    """Flatten a feature dict into a list of floats ordered by sorted key.

    The ordering depends only on the keys present in ``feats``; every value
    is passed through ``float()``.
    """
    ordered_names = sorted(feats)
    return [float(feats[name]) for name in ordered_names]
# Module-wide feature schema; stays None until feats_to_array_safe() runs once.
FEAT_KEYS = None


def feats_to_array_safe(feats: Dict) -> List[float]:
    """Flatten a feature dict using a key order frozen on the first call.

    The first invocation stores the sorted keys of its argument in the
    module-level ``FEAT_KEYS``. Every call (including the first) then emits
    one float per stored key: keys missing from ``feats`` become 0.0, and
    keys of ``feats`` that are not in the stored schema are ignored.

    NOTE: if the first call receives an empty dict, the stored schema is an
    empty list and all later calls return an empty list.
    """
    global FEAT_KEYS
    if FEAT_KEYS is None:
        FEAT_KEYS = sorted(feats)
    schema = FEAT_KEYS
    return [float(feats[name]) if name in feats else 0.0 for name in schema]