| |
| """Train a learned model router for Agent Cost Optimizer. |
| |
| Architecture: CARROT-style plug-in router. |
| - Per-tier P(success|query) classifiers (XGBoost) |
| - Per-tier cost estimators |
| - Route to cheapest tier where P(success) > threshold |
| |
| Ground truth: optimal_tier = min{tier : success(tier)} from execution traces. |
| """ |
import json, os, sys, random, pickle, uuid
import re
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
|
|
| |
# Task categories the router distinguishes. A category's position in this
# list is its integer index (mirrored by TASK_TYPE_TO_IDX below).
TASK_TYPES = [
    "quick_answer",
    "coding",
    "research",
    "document_drafting",
    "legal_regulated",
    "tool_heavy",
    "retrieval_heavy",
    "long_horizon",
    "unknown_ambiguous",
]
# Reverse lookup: task-type name -> position in TASK_TYPES.
TASK_TYPE_TO_IDX = dict(zip(TASK_TYPES, range(len(TASK_TYPES))))
|
|
# Keyword vocabularies used by extract_features() to derive lexical signals
# from a request. All entries are lowercase; the request text is lowercased
# before it is compared against them.

# Software-engineering vocabulary.
CODE_KW = [
    "python", "javascript", "code", "function", "bug", "debug",
    "refactor", "implement", "test", "compile", "runtime", "class",
    "module", "package", "async", "thread", "queue", "stack", "heap",
    "pointer", "segfault", "linter",
]

# Legal / regulatory vocabulary.
LEGAL_KW = [
    "contract", "legal", "compliance", "gdpr", "privacy", "policy",
    "regulatory", "liability", "clause", "indemnification", "tos",
]

# Research / analysis vocabulary (includes one multi-word phrase).
RESEARCH_KW = [
    "research", "find sources", "literature", "investigate", "compare",
    "analyze", "study", "survey", "paper", "arxiv", "citation",
]

# Vocabulary suggesting tool use or data retrieval.
TOOL_KW = [
    "search", "fetch", "retrieve", "query", "api", "database", "scrape",
    "lookup", "download", "upload", "index", "aggregate",
]

# Vocabulary suggesting long-horizon, multi-stage work.
LONG_KW = [
    "plan", "project", "roadmap", "orchestrate", "multi-step", "migrate",
    "pipeline", "end-to-end", "architecture", "workflow", "deploy",
]

# Quantitative / mathematical vocabulary.
MATH_KW = [
    "calculate", "compute", "solve", "equation", "formula", "optimize",
    "probability", "integral", "derivative", "matrix",
]
|
|
def _keyword_hits(text: str, keywords: List[str]) -> int:
    """Return how many entries of *keywords* occur in *text* at a word start.

    A keyword counts only when it is not preceded by a word character, so
    "plan" does not fire inside "explain", "heap" inside "cheapest", "test"
    inside "latest", "search" inside "research", or "tos" inside "photos".
    Text may follow the keyword, so inflections such as "tests", "plans" or
    "deploying" still match. Each keyword counts at most once, however often
    it appears.
    """
    return sum(
        1 for kw in keywords if re.search(r"(?<!\w)" + re.escape(kw), text)
    )


def extract_features(
    request: str, task_type: str, metadata: Optional[Dict] = None
) -> Dict[str, Any]:
    """Compute the flat feature dict used to train and query the router.

    Args:
        request: Raw user request text. Keyword features are matched
            case-insensitively against it.
        task_type: Expected to be one of TASK_TYPES. Any other value gets
            the "unknown_ambiguous" index for ``task_type_idx``, while all of
            its one-hot ``tt_*`` columns are zero.
        metadata: Optional extra information. Only its "difficulty" entry
            is read.

    Returns:
        Dict mapping feature name to a numeric value. The key set is the
        same on every call (``difficulty`` is always present and defaults to
        3), so sorted-key vectorisation gives fixed-length vectors whether
        or not metadata was supplied.
    """
    text = request.lower()

    # Scan each vocabulary once and derive both has_* and num_* from the count.
    code_hits = _keyword_hits(text, CODE_KW)
    legal_hits = _keyword_hits(text, LEGAL_KW)
    research_hits = _keyword_hits(text, RESEARCH_KW)
    tool_hits = _keyword_hits(text, TOOL_KW)

    feats: Dict[str, Any] = {
        "request_length": len(request),
        "num_words": len(request.split()),
        # Rough proxy: counts terminal punctuation marks, not parsed sentences.
        "num_sentences": (
            request.count(".") + request.count("!") + request.count("?")
        ),
        "has_code_kw": int(code_hits > 0),
        "num_code_kw": code_hits,
        "has_legal_kw": int(legal_hits > 0),
        "num_legal_kw": legal_hits,
        "has_research_kw": int(research_hits > 0),
        "num_research_kw": research_hits,
        "has_tool_kw": int(tool_hits > 0),
        "num_tool_kw": tool_hits,
        "has_long_kw": int(_keyword_hits(text, LONG_KW) > 0),
        "has_math_kw": int(_keyword_hits(text, MATH_KW) > 0),
        # Unrecognised task types fall back to the "unknown_ambiguous" slot
        # instead of a hard-coded index that could drift if TASK_TYPES changes.
        "task_type_idx": TASK_TYPE_TO_IDX.get(
            task_type, TASK_TYPE_TO_IDX["unknown_ambiguous"]
        ),
    }

    # One-hot encoding of the task type.
    for tt in TASK_TYPES:
        feats[f"tt_{tt}"] = int(task_type == tt)

    # Always emit "difficulty" so every feature dict shares one schema. If the
    # key were present only when metadata is supplied, sorted-key
    # vectorisation would produce vectors of different lengths. 3 is the
    # fallback when no difficulty is given.
    feats["difficulty"] = (metadata or {}).get("difficulty", 3)
    return feats
|
|
def feats_to_array(feats: Dict) -> List[float]:
    """Flatten a feature dict into a list of floats ordered by key name.

    Keys are sorted lexicographically, so any two dicts with the same key
    set produce vectors in the same column order.

    Args:
        feats: Mapping of feature name to a value accepted by ``float()``.

    Returns:
        One float per key, in ascending key order.
    """
    # Sort on the key alone; values are never compared.
    ordered_items = sorted(feats.items(), key=lambda pair: pair[0])
    return [float(value) for _name, value in ordered_items]
|
|
# Process-wide feature-key order. It stays None until the first call to
# feats_to_array_safe(), which sets it and never changes it afterwards.
FEAT_KEYS: Optional[List[str]] = None


def feats_to_array_safe(feats: Dict) -> List[float]:
    """Vectorise *feats* using a key order fixed by the first call.

    On the first call, the sorted keys of *feats* become FEAT_KEYS. Every
    later call emits exactly one float per key in FEAT_KEYS:
      * keys missing from *feats* become 0.0;
      * keys in *feats* but not in FEAT_KEYS are silently dropped.
    Whatever the first call sees therefore defines the schema for the rest
    of the process.

    Args:
        feats: Mapping of feature name to a value accepted by ``float()``.

    Returns:
        List of floats aligned with FEAT_KEYS.
    """
    global FEAT_KEYS
    order = FEAT_KEYS
    if order is None:
        order = FEAT_KEYS = sorted(feats)
    return [float(feats[name]) if name in feats else 0.0 for name in order]
|
|