#!/usr/bin/env python3
"""Train a learned model router for Agent Cost Optimizer.
Architecture: CARROT-style plug-in router.
- Per-tier P(success|query) classifiers (XGBoost)
- Per-tier cost estimators
- Route to cheapest tier where P(success) > threshold
Ground truth: optimal_tier = min{tier : success(tier)} from execution traces.
"""
import json, os, sys, random, pickle, uuid
import numpy as np
from datetime import datetime, timedelta
from collections import defaultdict
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass
from enum import Enum
# ─── Feature Extraction ───────────────────────────────────────────
# Task categories the router distinguishes. A type's position in this list is
# the value extract_features() stores under "task_type_idx".
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]
# Reverse lookup: task type name -> its index in TASK_TYPES.
TASK_TYPE_TO_IDX = dict(zip(TASK_TYPES, range(len(TASK_TYPES))))

# Topic keyword lists, matched against the lower-cased request text.
# Entries without internal spaces are written as whitespace-separated strings.
CODE_KW = (
    "python javascript code function bug debug refactor "
    "implement test compile runtime class module package "
    "async thread queue stack heap pointer segfault linter"
).split()
LEGAL_KW = (
    "contract legal compliance gdpr privacy policy "
    "regulatory liability clause indemnification tos"
).split()
# Spelled out as a list because "find sources" contains a space.
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "study", "survey", "paper", "arxiv", "citation",
]
TOOL_KW = (
    "search fetch retrieve query api database scrape "
    "lookup download upload index aggregate"
).split()
LONG_KW = (
    "plan project roadmap orchestrate multi-step migrate "
    "pipeline end-to-end architecture workflow deploy"
).split()
MATH_KW = (
    "calculate compute solve equation formula optimize "
    "probability integral derivative matrix"
).split()
def extract_features(
    request: str,
    task_type: str,
    metadata: Optional[Dict] = None,
) -> Dict[str, Any]:
    """Build the flat feature dict consumed by the router's classifiers.

    The features are cheap lexical statistics of ``request``:
    - length, word count and sentence count;
    - keyword presence flags and hit counts per topic list.
    They also include the integer index of ``task_type`` in ``TASK_TYPES``,
    a one-hot encoding of ``task_type``, and a ``difficulty`` score.

    Args:
        request: Raw request text. Keyword matching is case-insensitive.
        task_type: Expected to be one of ``TASK_TYPES``. An unrecognised
            value gets the index of ``"unknown_ambiguous"`` and an all-zero
            one-hot encoding.
        metadata: Optional extra information. Only ``metadata["difficulty"]``
            is read. When metadata is missing or has no such key, difficulty
            defaults to 3.

    Returns:
        Mapping of feature name -> value. The key set is the same for every
        call, so arrays built from it always have the same length and
        column order.
    """
    text = request.lower()

    # NOTE: keywords match as plain substrings, so e.g. "heap" also fires on
    # "cheap" and "bug" on "debug". Switching to word-boundary matching would
    # shift the feature distribution any existing model was trained on.
    def hit_count(keywords: List[str]) -> int:
        return sum(1 for kw in keywords if kw in text)

    # Scan each counted list once and derive both the flag and the count.
    code_hits = hit_count(CODE_KW)
    legal_hits = hit_count(LEGAL_KW)
    research_hits = hit_count(RESEARCH_KW)
    tool_hits = hit_count(TOOL_KW)

    feats: Dict[str, Any] = {
        "request_length": len(request),
        "num_words": len(request.split()),
        # Crude sentence count: every '.', '!' or '?' counts once.
        "num_sentences": sum(request.count(mark) for mark in ".!?"),
        "has_code_kw": int(code_hits > 0),
        "num_code_kw": code_hits,
        "has_legal_kw": int(legal_hits > 0),
        "num_legal_kw": legal_hits,
        "has_research_kw": int(research_hits > 0),
        "num_research_kw": research_hits,
        "has_tool_kw": int(tool_hits > 0),
        "num_tool_kw": tool_hits,
        "has_long_kw": int(any(kw in text for kw in LONG_KW)),
        "has_math_kw": int(any(kw in text for kw in MATH_KW)),
        "task_type_idx": TASK_TYPE_TO_IDX.get(
            task_type, TASK_TYPE_TO_IDX["unknown_ambiguous"]
        ),
    }
    # One-hot task type (all zeros when task_type is not in TASK_TYPES).
    for tt in TASK_TYPES:
        feats[f"tt_{tt}"] = int(task_type == tt)
    # Always emit difficulty so the key set never depends on whether metadata
    # was supplied. A varying key set would change feats_to_array() lengths
    # and could make feats_to_array_safe() latch columns without this feature.
    feats["difficulty"] = (metadata or {}).get("difficulty", 3)
    return feats
def feats_to_array(feats: Dict) -> List[float]:
    """Flatten a feature dict into floats ordered alphabetically by feature name.

    The output length equals ``len(feats)``. Every value is passed through
    ``float()``.
    """
    # Keys are unique, so sorting (name, value) pairs orders purely by name.
    ordered_pairs = sorted(feats.items())
    return [float(value) for _name, value in ordered_pairs]
FEAT_KEYS = None # set on first call
def feats_to_array_safe(feats: Dict) -> List[float]:
global FEAT_KEYS
if FEAT_KEYS is None:
FEAT_KEYS = sorted(feats.keys())
return [float(feats.get(k, 0.0)) for k in FEAT_KEYS]